Alibrown commited on
Commit
f431b5d
Β·
verified Β·
1 Parent(s): a51fcf9

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +128 -21
train.py CHANGED
@@ -8,7 +8,7 @@
8
  # Usage:
9
  # python train.py --mode export β†’ export HF dataset to training format
10
  # python train.py --mode validate β†’ validate ADI weights against dataset
11
- # python train.py --mode finetune β†’ finetune SmolLM2 on collected data (future)
12
  # =============================================================================
13
  import os
14
  import argparse
@@ -24,6 +24,7 @@ _TMP = Path("/tmp") if os.getenv("SPACE_ID") else Path(".")
24
 
25
  TRAIN_DATA = _TMP / "train_data.jsonl"
26
  VALID_RESULT = _TMP / "validation_results.json"
 
27
 
28
  import model as model_module
29
  from adi import DumpindexAnalyzer
@@ -39,7 +40,9 @@ logger = logging.getLogger("train")
39
  def export_dataset(output_path: str = None):
40
  """
41
  Export HF dataset logs to JSONL format for training.
42
- Filters: only HIGH_PRIORITY and MEDIUM_PRIORITY entries with actual responses.
 
 
43
  """
44
  output = Path(output_path) if output_path else TRAIN_DATA
45
 
@@ -51,26 +54,31 @@ def export_dataset(output_path: str = None):
51
  return
52
 
53
  count = 0
 
54
  with open(output, "w") as f:
55
  for entry in entries:
56
- # Only export entries where SmolLM2 actually responded
57
  if entry.get("adi_decision") == "REJECT":
 
58
  continue
59
  if not entry.get("response"):
 
60
  continue
61
 
62
  # Format as instruction tuning pair
 
63
  record = {
64
- "instruction": entry.get("system_prompt", "You are a helpful assistant."),
65
- "input": entry.get("prompt", ""),
66
- "output": entry.get("response", ""),
67
- "adi_score": entry.get("adi_score"),
68
  "adi_decision": entry.get("adi_decision"),
 
69
  }
70
  f.write(json.dumps(record) + "\n")
71
  count += 1
72
 
73
- logger.info(f"Exported {count}/{len(entries)} entries β†’ {output}")
74
 
75
 
76
  # =============================================================================
@@ -107,13 +115,14 @@ def validate_adi():
107
 
108
 
109
  # =============================================================================
110
- # Mode 3 β€” Finetune placeholder
111
  # =============================================================================
112
 
113
  def finetune():
114
  """
115
- Finetune SmolLM2 on collected dataset.
116
- Requires export first + enough data (>500 samples recommended).
 
117
  """
118
  if not TRAIN_DATA.exists():
119
  logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β€” run export first")
@@ -122,17 +131,115 @@ def finetune():
122
  lines = TRAIN_DATA.read_text().strip().splitlines()
123
  logger.info(f"Training samples available: {len(lines)}")
124
 
125
- if len(lines) < 100:
 
 
 
 
126
  logger.warning(f"Only {len(lines)} samples β€” recommend 500+ for meaningful finetuning")
127
 
128
- # TODO: implement finetuning with transformers Trainer
129
- # Rough plan:
130
- # 1. Load base model via model.get_model_id()
131
- # 2. Tokenize TRAIN_DATA
132
- # 3. TrainingArguments + Trainer (or TRL SFTTrainer)
133
- # 4. Save to PRIVATE_MODEL repo via model.push_model_card()
134
- logger.info("Finetune placeholder β€” not yet implemented")
135
- logger.info("Next step: implement with transformers.Trainer or TRL SFTTrainer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
 
138
  # =============================================================================
@@ -155,4 +262,4 @@ if __name__ == "__main__":
155
  elif args.mode == "validate":
156
  validate_adi()
157
  elif args.mode == "finetune":
158
- finetune()
 
8
  # Usage:
9
  # python train.py --mode export β†’ export HF dataset to training format
10
  # python train.py --mode validate β†’ validate ADI weights against dataset
11
+ # python train.py --mode finetune β†’ finetune SmolLM2 on exported data
12
  # =============================================================================
13
  import os
14
  import argparse
 
24
 
25
  TRAIN_DATA = _TMP / "train_data.jsonl"
26
  VALID_RESULT = _TMP / "validation_results.json"
27
+ MODEL_OUTPUT = _TMP / "finetuned_model"
28
 
29
  import model as model_module
30
  from adi import DumpindexAnalyzer
 
40
  def export_dataset(output_path: str = None):
41
  """
42
  Export HF dataset logs to JSONL format for training.
43
+ Includes HIGH_PRIORITY, MEDIUM_PRIORITY and BLOCKED entries.
44
+ BLOCKED entries teach the model what to reject.
45
+ REJECT entries (ADI noise/quality fail) are skipped β€” no response logged.
46
  """
47
  output = Path(output_path) if output_path else TRAIN_DATA
48
 
 
54
  return
55
 
56
  count = 0
57
+ skipped = 0
58
  with open(output, "w") as f:
59
  for entry in entries:
60
+ # Skip ADI-rejected entries β€” no meaningful response logged
61
  if entry.get("adi_decision") == "REJECT":
62
+ skipped += 1
63
  continue
64
  if not entry.get("response"):
65
+ skipped += 1
66
  continue
67
 
68
  # Format as instruction tuning pair
69
+ # BLOCKED entries are included β€” model learns what to refuse
70
  record = {
71
+ "instruction": entry.get("system_prompt", "You are a helpful assistant."),
72
+ "input": entry.get("prompt", ""),
73
+ "output": entry.get("response", ""),
74
+ "adi_score": entry.get("adi_score"),
75
  "adi_decision": entry.get("adi_decision"),
76
+ "is_safe": entry.get("adi_decision") != "BLOCKED",
77
  }
78
  f.write(json.dumps(record) + "\n")
79
  count += 1
80
 
81
+ logger.info(f"Exported {count}/{len(entries)} entries β†’ {output} (skipped: {skipped})")
82
 
83
 
84
  # =============================================================================
 
115
 
116
 
117
  # =============================================================================
118
+ # Mode 3 β€” Finetune SmolLM2 with TRL SFTTrainer
119
  # =============================================================================
120
 
121
  def finetune():
122
  """
123
+ Finetune SmolLM2 on exported dataset using TRL SFTTrainer.
124
+ Requires export first + enough data (500+ samples recommended).
125
+ On completion: pushes finetuned weights to private HF model repo.
126
  """
127
  if not TRAIN_DATA.exists():
128
  logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β€” run export first")
 
131
  lines = TRAIN_DATA.read_text().strip().splitlines()
132
  logger.info(f"Training samples available: {len(lines)}")
133
 
134
+ if len(lines) < 10:
135
+ logger.error(f"Too few samples ({len(lines)}) β€” aborting finetune")
136
+ return
137
+
138
+ if len(lines) < 500:
139
  logger.warning(f"Only {len(lines)} samples β€” recommend 500+ for meaningful finetuning")
140
 
141
+ # ── Imports ───────────────────────────────────────────────────────────────
142
+ try:
143
+ from transformers import AutoModelForCausalLM, AutoTokenizer
144
+ from trl import SFTTrainer, SFTConfig
145
+ from datasets import Dataset
146
+ import torch
147
+ except ImportError as e:
148
+ logger.error(f"Missing dependency: {e} β€” run: pip install trl transformers datasets torch")
149
+ return
150
+
151
+ # ── Load dataset ──────────────────────────────────────────────────────────
152
+ logger.info("Loading training data...")
153
+ records = [json.loads(l) for l in lines]
154
+
155
+ def format_record(record):
156
+ """Format record into chat template string."""
157
+ instruction = record.get("instruction", "You are a helpful assistant.")
158
+ user_input = record.get("input", "")
159
+ output = record.get("output", "")
160
+ return {
161
+ "text": f"<|system|>\n{instruction}\n<|user|>\n{user_input}\n<|assistant|>\n{output}"
162
+ }
163
+
164
+ formatted = [format_record(r) for r in records]
165
+ dataset = Dataset.from_list(formatted)
166
+ logger.info(f"Dataset ready: {len(dataset)} samples")
167
+
168
+ # ── Load model + tokenizer ────────────────────────────────────────────────
169
+ model_id = model_module.get_model_id()
170
+ kwargs = model_module.get_model_kwargs()
171
+ device = "cuda" if torch.cuda.is_available() else "cpu"
172
+
173
+ logger.info(f"Loading base model: {model_id} on {device}...")
174
+ tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
175
+ model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
176
+
177
+ # Ensure pad token exists
178
+ if tokenizer.pad_token is None:
179
+ tokenizer.pad_token = tokenizer.eos_token
180
+
181
+ # ── Training config ───────────────────────────────────────────────────────
182
+ # Conservative settings for CPU / low RAM (2-8GB)
183
+ sft_config = SFTConfig(
184
+ output_dir=str(MODEL_OUTPUT),
185
+ num_train_epochs=3,
186
+ per_device_train_batch_size=1, # CPU friendly
187
+ gradient_accumulation_steps=4, # effective batch size = 4
188
+ learning_rate=2e-5,
189
+ warmup_steps=10,
190
+ logging_steps=10,
191
+ save_steps=50,
192
+ save_total_limit=2,
193
+ fp16=False, # no GPU, no fp16
194
+ bf16=False,
195
+ dataloader_num_workers=0, # HF Spaces: no multiprocessing
196
+ report_to="none", # no wandb/tensorboard
197
+ max_seq_length=512, # SmolLM2 context limit
198
+ dataset_text_field="text",
199
+ )
200
+
201
+ # ── SFTTrainer ────────────────────────────────────────────────────────────
202
+ logger.info("Initializing SFTTrainer...")
203
+ trainer = SFTTrainer(
204
+ model=model,
205
+ args=sft_config,
206
+ train_dataset=dataset,
207
+ tokenizer=tokenizer,
208
+ )
209
+
210
+ # ── Train ─────────────────────────────────────────────────────────────────
211
+ logger.info("Starting finetuning...")
212
+ start = datetime.utcnow()
213
+ trainer.train()
214
+ duration = (datetime.utcnow() - start).total_seconds()
215
+ logger.info(f"Training complete in {duration:.0f}s")
216
+
217
+ # ── Save locally ──────────────────────────────────────────────────────────
218
+ trainer.save_model(str(MODEL_OUTPUT))
219
+ tokenizer.save_pretrained(str(MODEL_OUTPUT))
220
+ logger.info(f"Model saved β†’ {MODEL_OUTPUT}")
221
+
222
+ # ── Push to HF private repo ───────────────────────────────────────────────
223
+ token = model_module.TOKEN
224
+ private_repo = model_module.PRIVATE_MODEL
225
+
226
+ if token and private_repo:
227
+ logger.info(f"Pushing to HF: {private_repo}...")
228
+ try:
229
+ model.push_to_hub(private_repo, token=token, private=True)
230
+ tokenizer.push_to_hub(private_repo, token=token, private=True)
231
+ model_module.push_model_card({
232
+ "model_id": model_id,
233
+ "samples": len(dataset),
234
+ "epochs": 3,
235
+ "duration_sec": int(duration),
236
+ "finetuned_from": model_id,
237
+ })
238
+ logger.info(f"Model pushed β†’ {private_repo}")
239
+ except Exception as e:
240
+ logger.error(f"Push failed: {type(e).__name__}: {e}")
241
+ else:
242
+ logger.warning("No token or private repo configured β€” skipping HF push")
243
 
244
 
245
  # =============================================================================
 
262
  elif args.mode == "validate":
263
  validate_adi()
264
  elif args.mode == "finetune":
265
+ finetune()