# =============================================================================
# train.py
# Dataset Preparation + Finetuning Entry Point
# SmolLM2 Service Space
# Copyright 2026 - Volkan Kücükbudak
# Apache License V2 + ESOL 1.1
# =============================================================================
# Usage:
#   python train.py --mode export   → export HF dataset to training format
#   python train.py --mode validate → validate ADI weights against dataset
#   python train.py --mode finetune → finetune SmolLM2 on exported data
# =============================================================================
import os
import argparse
import json
import logging
from datetime import datetime, timezone
from pathlib import Path

# ── Path Resolution ───────────────────────────────────────────────────────────
# HF Spaces: /tmp (the only writable path on the otherwise read-only filesystem)
# Local dev: current directory
_TMP = Path("/tmp") if os.getenv("SPACE_ID") else Path(".")

TRAIN_DATA   = _TMP / "train_data.jsonl"
VALID_RESULT = _TMP / "validation_results.json"
MODEL_OUTPUT = _TMP / "finetuned_model"

import model as model_module
from adi import DumpindexAnalyzer

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("train")


# =============================================================================
# Mode 1 β€” Export dataset to training format
# =============================================================================

def export_dataset(output_path: str | None = None):
    """
    Export HF dataset logs to JSONL format for training.
    Includes HIGH_PRIORITY, MEDIUM_PRIORITY and BLOCKED entries.
    BLOCKED entries teach the model what to reject.
    REJECT entries (ADI noise/quality fail) are skipped - no response logged.
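
    Example output line (illustrative values):
      {"instruction": "You are a helpful assistant.", "input": "...",
       "output": "...", "adi_score": 0.82, "adi_decision": "HIGH_PRIORITY",
       "is_safe": true}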
    """
    output = Path(output_path) if output_path else TRAIN_DATA

    logger.info("Loading dataset from HF...")
    entries = model_module.load_logs()

    if not entries:
        logger.warning("Dataset empty β€” nothing to export")
        return

    count = 0
    skipped = 0
    with open(output, "w") as f:
        for entry in entries:
            # Skip ADI-rejected entries - no meaningful response logged
            if entry.get("adi_decision") == "REJECT":
                skipped += 1
                continue
            if not entry.get("response"):
                skipped += 1
                continue

            # Format as instruction tuning pair
            # BLOCKED entries are included - model learns what to refuse
            record = {
                "instruction":  entry.get("system_prompt", "You are a helpful assistant."),
                "input":        entry.get("prompt", ""),
                "output":       entry.get("response", ""),
                "adi_score":    entry.get("adi_score"),
                "adi_decision": entry.get("adi_decision"),
                "is_safe":      entry.get("adi_decision") != "BLOCKED",
            }
            f.write(json.dumps(record) + "\n")
            count += 1

    logger.info(f"Exported {count}/{len(entries)} entries β†’ {output} (skipped: {skipped})")


# =============================================================================
# Mode 2 β€” Validate ADI weights against collected data
# =============================================================================

def validate_adi():
    """
    Run ADI weight validation against dataset.
    Uses entries that have human_label field (manually labeled).
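
    Example labeled entry (illustrative):
      {"prompt": "...", "human_label": "HIGH_PRIORITY", ...}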
    """
    logger.info("Loading dataset for ADI validation...")
    entries = model_module.load_logs()

    labeled = [(e["prompt"], e["human_label"]) for e in entries if e.get("human_label")]

    if not labeled:
        logger.warning("No labeled entries found β€” add 'human_label' field to dataset entries")
        logger.info("Expected labels: REJECT | MEDIUM_PRIORITY | HIGH_PRIORITY")
        return

    analyzer = DumpindexAnalyzer()
    accuracy = analyzer.validate_weights(labeled)
    logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")

    result = {
        "timestamp": datetime.utcnow().isoformat(),
        "accuracy":  accuracy,
        "samples":   len(labeled),
        "weights":   analyzer.weights,
    }
    VALID_RESULT.write_text(json.dumps(result, indent=2))
    logger.info(f"Results saved β†’ {VALID_RESULT}")


# =============================================================================
# Mode 3 β€” Finetune SmolLM2 with TRL SFTTrainer
# =============================================================================

def finetune():
    """
    Finetune SmolLM2 on exported dataset using TRL SFTTrainer.
    Requires export first + enough data (500+ samples recommended).
    On completion: pushes finetuned weights to private HF model repo.
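
    Typical sequence:
      python train.py --mode export
      python train.py --mode finetune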
    """
    if not TRAIN_DATA.exists():
        logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β€” run export first")
        return

    lines = TRAIN_DATA.read_text().strip().splitlines()
    logger.info(f"Training samples available: {len(lines)}")

    if len(lines) < 10:
        logger.error(f"Too few samples ({len(lines)}) β€” aborting finetune")
        return

    if len(lines) < 500:
        logger.warning(f"Only {len(lines)} samples β€” recommend 500+ for meaningful finetuning")

    # ── Imports ───────────────────────────────────────────────────────────────
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import SFTTrainer, SFTConfig
        from datasets import Dataset
        import torch
    except ImportError as e:
        logger.error(f"Missing dependency: {e} β€” run: pip install trl transformers datasets torch")
        return

    # ── Load dataset ──────────────────────────────────────────────────────────
    logger.info("Loading training data...")
    records = [json.loads(line) for line in lines]

    def format_record(record):
        """Format record into chat template string."""
        instruction = record.get("instruction", "You are a helpful assistant.")
        user_input  = record.get("input", "")
        output      = record.get("output", "")
        return {
            "text": f"<|system|>\n{instruction}\n<|user|>\n{user_input}\n<|assistant|>\n{output}"
        }

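    # NOTE: the <|system|>/<|user|>/<|assistant|> tags above are an ad-hoc
    # template (an assumption, not necessarily the base model's trained chat
    # format); tokenizer.apply_chat_template would guarantee a match.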
    formatted = [format_record(r) for r in records]
    dataset   = Dataset.from_list(formatted)
    logger.info(f"Dataset ready: {len(dataset)} samples")

    # ── Load model + tokenizer ────────────────────────────────────────────────
    model_id = model_module.get_model_id()
    kwargs   = model_module.get_model_kwargs()
    device   = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Loading base model: {model_id} on {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
    model     = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

    # Ensure pad token exists
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ── Training config ───────────────────────────────────────────────────────
    # Conservative settings for CPU / low RAM (2-8GB)
    sft_config = SFTConfig(
        output_dir=str(MODEL_OUTPUT),
        num_train_epochs=3,
        per_device_train_batch_size=1,      # CPU friendly
        gradient_accumulation_steps=4,      # effective batch size = 4
        learning_rate=2e-5,
        warmup_steps=10,
        logging_steps=10,
        save_steps=50,
        save_total_limit=2,
        fp16=False,                         # no GPU, no fp16
        bf16=False,
        dataloader_num_workers=0,           # HF Spaces: no multiprocessing
        report_to="none",                   # no wandb/tensorboard
        max_seq_length=512,                 # truncation length; keeps CPU memory low
        dataset_text_field="text",
    )

    # ── SFTTrainer ────────────────────────────────────────────────────────────
    logger.info("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,         # recent TRL renamed the old tokenizer= argument
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    logger.info("Starting finetuning...")
    start = datetime.now(timezone.utc)
    trainer.train()
    duration = (datetime.now(timezone.utc) - start).total_seconds()
    logger.info(f"Training complete in {duration:.0f}s")

    # ── Save locally ──────────────────────────────────────────────────────────
    trainer.save_model(str(MODEL_OUTPUT))
    tokenizer.save_pretrained(str(MODEL_OUTPUT))
    logger.info(f"Model saved β†’ {MODEL_OUTPUT}")

    # ── Push to HF private repo ───────────────────────────────────────────────
    token        = model_module.TOKEN
    private_repo = model_module.PRIVATE_MODEL

    if token and private_repo:
        logger.info(f"Pushing to HF: {private_repo}...")
        try:
            model.push_to_hub(private_repo, token=token, private=True)
            tokenizer.push_to_hub(private_repo, token=token, private=True)
            model_module.push_model_card({
                "model_id":       model_id,
                "samples":        len(dataset),
                "epochs":         3,
                "duration_sec":   int(duration),
                "finetuned_from": model_id,
            })
            logger.info(f"Model pushed β†’ {private_repo}")
        except Exception as e:
            logger.error(f"Push failed: {type(e).__name__}: {e}")
    else:
        logger.warning("No token or private repo configured β€” skipping HF push")


# =============================================================================
# CLI
# =============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SmolLM2 Training Utilities")
    parser.add_argument(
        "--mode",
        choices=["export", "validate", "finetune"],
        required=True,
        help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model"
    )
    parser.add_argument("--output", default=None, help="Output file for export mode (default: auto)")
    args = parser.parse_args()

    if args.mode == "export":
        export_dataset(args.output)
    elif args.mode == "validate":
        validate_adi()
    elif args.mode == "finetune":
        finetune()