Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
|
@@ -8,7 +8,7 @@
|
|
| 8 |
# Usage:
|
| 9 |
# python train.py --mode export β export HF dataset to training format
|
| 10 |
# python train.py --mode validate β validate ADI weights against dataset
|
| 11 |
-
# python train.py --mode finetune β finetune SmolLM2 on
|
| 12 |
# =============================================================================
|
| 13 |
import os
|
| 14 |
import argparse
|
|
@@ -24,6 +24,7 @@ _TMP = Path("/tmp") if os.getenv("SPACE_ID") else Path(".")
|
|
| 24 |
|
| 25 |
TRAIN_DATA = _TMP / "train_data.jsonl"
|
| 26 |
VALID_RESULT = _TMP / "validation_results.json"
|
|
|
|
| 27 |
|
| 28 |
import model as model_module
|
| 29 |
from adi import DumpindexAnalyzer
|
|
@@ -39,7 +40,9 @@ logger = logging.getLogger("train")
|
|
| 39 |
def export_dataset(output_path: str = None):
|
| 40 |
"""
|
| 41 |
Export HF dataset logs to JSONL format for training.
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
output = Path(output_path) if output_path else TRAIN_DATA
|
| 45 |
|
|
@@ -51,26 +54,31 @@ def export_dataset(output_path: str = None):
|
|
| 51 |
return
|
| 52 |
|
| 53 |
count = 0
|
|
|
|
| 54 |
with open(output, "w") as f:
|
| 55 |
for entry in entries:
|
| 56 |
-
#
|
| 57 |
if entry.get("adi_decision") == "REJECT":
|
|
|
|
| 58 |
continue
|
| 59 |
if not entry.get("response"):
|
|
|
|
| 60 |
continue
|
| 61 |
|
| 62 |
# Format as instruction tuning pair
|
|
|
|
| 63 |
record = {
|
| 64 |
-
"instruction":
|
| 65 |
-
"input":
|
| 66 |
-
"output":
|
| 67 |
-
"adi_score":
|
| 68 |
"adi_decision": entry.get("adi_decision"),
|
|
|
|
| 69 |
}
|
| 70 |
f.write(json.dumps(record) + "\n")
|
| 71 |
count += 1
|
| 72 |
|
| 73 |
-
logger.info(f"Exported {count}/{len(entries)} entries β {output}")
|
| 74 |
|
| 75 |
|
| 76 |
# =============================================================================
|
|
@@ -107,13 +115,14 @@ def validate_adi():
|
|
| 107 |
|
| 108 |
|
| 109 |
# =============================================================================
|
| 110 |
-
# Mode 3 β Finetune
|
| 111 |
# =============================================================================
|
| 112 |
|
| 113 |
def finetune():
|
| 114 |
"""
|
| 115 |
-
Finetune SmolLM2 on
|
| 116 |
-
Requires export first + enough data (
|
|
|
|
| 117 |
"""
|
| 118 |
if not TRAIN_DATA.exists():
|
| 119 |
logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β run export first")
|
|
@@ -122,17 +131,115 @@ def finetune():
|
|
| 122 |
lines = TRAIN_DATA.read_text().strip().splitlines()
|
| 123 |
logger.info(f"Training samples available: {len(lines)}")
|
| 124 |
|
| 125 |
-
if len(lines) <
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
logger.warning(f"Only {len(lines)} samples β recommend 500+ for meaningful finetuning")
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
# =============================================================================
|
|
@@ -155,4 +262,4 @@ if __name__ == "__main__":
|
|
| 155 |
elif args.mode == "validate":
|
| 156 |
validate_adi()
|
| 157 |
elif args.mode == "finetune":
|
| 158 |
-
finetune()
|
|
|
|
| 8 |
# Usage:
|
| 9 |
# python train.py --mode export β export HF dataset to training format
|
| 10 |
# python train.py --mode validate β validate ADI weights against dataset
|
| 11 |
+
# python train.py --mode finetune β finetune SmolLM2 on exported data
|
| 12 |
# =============================================================================
|
| 13 |
import os
|
| 14 |
import argparse
|
|
|
|
| 24 |
|
| 25 |
TRAIN_DATA = _TMP / "train_data.jsonl"
|
| 26 |
VALID_RESULT = _TMP / "validation_results.json"
|
| 27 |
+
MODEL_OUTPUT = _TMP / "finetuned_model"
|
| 28 |
|
| 29 |
import model as model_module
|
| 30 |
from adi import DumpindexAnalyzer
|
|
|
|
| 40 |
def export_dataset(output_path: str = None):
|
| 41 |
"""
|
| 42 |
Export HF dataset logs to JSONL format for training.
|
| 43 |
+
Includes HIGH_PRIORITY, MEDIUM_PRIORITY and BLOCKED entries.
|
| 44 |
+
BLOCKED entries teach the model what to reject.
|
| 45 |
+
REJECT entries (ADI noise/quality fail) are skipped β no response logged.
|
| 46 |
"""
|
| 47 |
output = Path(output_path) if output_path else TRAIN_DATA
|
| 48 |
|
|
|
|
| 54 |
return
|
| 55 |
|
| 56 |
count = 0
|
| 57 |
+
skipped = 0
|
| 58 |
with open(output, "w") as f:
|
| 59 |
for entry in entries:
|
| 60 |
+
# Skip ADI-rejected entries β no meaningful response logged
|
| 61 |
if entry.get("adi_decision") == "REJECT":
|
| 62 |
+
skipped += 1
|
| 63 |
continue
|
| 64 |
if not entry.get("response"):
|
| 65 |
+
skipped += 1
|
| 66 |
continue
|
| 67 |
|
| 68 |
# Format as instruction tuning pair
|
| 69 |
+
# BLOCKED entries are included β model learns what to refuse
|
| 70 |
record = {
|
| 71 |
+
"instruction": entry.get("system_prompt", "You are a helpful assistant."),
|
| 72 |
+
"input": entry.get("prompt", ""),
|
| 73 |
+
"output": entry.get("response", ""),
|
| 74 |
+
"adi_score": entry.get("adi_score"),
|
| 75 |
"adi_decision": entry.get("adi_decision"),
|
| 76 |
+
"is_safe": entry.get("adi_decision") != "BLOCKED",
|
| 77 |
}
|
| 78 |
f.write(json.dumps(record) + "\n")
|
| 79 |
count += 1
|
| 80 |
|
| 81 |
+
logger.info(f"Exported {count}/{len(entries)} entries β {output} (skipped: {skipped})")
|
| 82 |
|
| 83 |
|
| 84 |
# =============================================================================
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
# =============================================================================
|
| 118 |
+
# Mode 3 β Finetune SmolLM2 with TRL SFTTrainer
|
| 119 |
# =============================================================================
|
| 120 |
|
| 121 |
def finetune():
|
| 122 |
"""
|
| 123 |
+
Finetune SmolLM2 on exported dataset using TRL SFTTrainer.
|
| 124 |
+
Requires export first + enough data (500+ samples recommended).
|
| 125 |
+
On completion: pushes finetuned weights to private HF model repo.
|
| 126 |
"""
|
| 127 |
if not TRAIN_DATA.exists():
|
| 128 |
logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β run export first")
|
|
|
|
| 131 |
lines = TRAIN_DATA.read_text().strip().splitlines()
|
| 132 |
logger.info(f"Training samples available: {len(lines)}")
|
| 133 |
|
| 134 |
+
if len(lines) < 10:
|
| 135 |
+
logger.error(f"Too few samples ({len(lines)}) β aborting finetune")
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
if len(lines) < 500:
|
| 139 |
logger.warning(f"Only {len(lines)} samples β recommend 500+ for meaningful finetuning")
|
| 140 |
|
| 141 |
+
# ββ Imports βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
try:
|
| 143 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 144 |
+
from trl import SFTTrainer, SFTConfig
|
| 145 |
+
from datasets import Dataset
|
| 146 |
+
import torch
|
| 147 |
+
except ImportError as e:
|
| 148 |
+
logger.error(f"Missing dependency: {e} β run: pip install trl transformers datasets torch")
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
# ββ Load dataset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 152 |
+
logger.info("Loading training data...")
|
| 153 |
+
records = [json.loads(l) for l in lines]
|
| 154 |
+
|
| 155 |
+
def format_record(record):
|
| 156 |
+
"""Format record into chat template string."""
|
| 157 |
+
instruction = record.get("instruction", "You are a helpful assistant.")
|
| 158 |
+
user_input = record.get("input", "")
|
| 159 |
+
output = record.get("output", "")
|
| 160 |
+
return {
|
| 161 |
+
"text": f"<|system|>\n{instruction}\n<|user|>\n{user_input}\n<|assistant|>\n{output}"
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
formatted = [format_record(r) for r in records]
|
| 165 |
+
dataset = Dataset.from_list(formatted)
|
| 166 |
+
logger.info(f"Dataset ready: {len(dataset)} samples")
|
| 167 |
+
|
| 168 |
+
# ββ Load model + tokenizer ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
model_id = model_module.get_model_id()
|
| 170 |
+
kwargs = model_module.get_model_kwargs()
|
| 171 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 172 |
+
|
| 173 |
+
logger.info(f"Loading base model: {model_id} on {device}...")
|
| 174 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
|
| 175 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
|
| 176 |
+
|
| 177 |
+
# Ensure pad token exists
|
| 178 |
+
if tokenizer.pad_token is None:
|
| 179 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 180 |
+
|
| 181 |
+
# ββ Training config βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 182 |
+
# Conservative settings for CPU / low RAM (2-8GB)
|
| 183 |
+
sft_config = SFTConfig(
|
| 184 |
+
output_dir=str(MODEL_OUTPUT),
|
| 185 |
+
num_train_epochs=3,
|
| 186 |
+
per_device_train_batch_size=1, # CPU friendly
|
| 187 |
+
gradient_accumulation_steps=4, # effective batch size = 4
|
| 188 |
+
learning_rate=2e-5,
|
| 189 |
+
warmup_steps=10,
|
| 190 |
+
logging_steps=10,
|
| 191 |
+
save_steps=50,
|
| 192 |
+
save_total_limit=2,
|
| 193 |
+
fp16=False, # no GPU, no fp16
|
| 194 |
+
bf16=False,
|
| 195 |
+
dataloader_num_workers=0, # HF Spaces: no multiprocessing
|
| 196 |
+
report_to="none", # no wandb/tensorboard
|
| 197 |
+
max_seq_length=512, # SmolLM2 context limit
|
| 198 |
+
dataset_text_field="text",
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# ββ SFTTrainer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 202 |
+
logger.info("Initializing SFTTrainer...")
|
| 203 |
+
trainer = SFTTrainer(
|
| 204 |
+
model=model,
|
| 205 |
+
args=sft_config,
|
| 206 |
+
train_dataset=dataset,
|
| 207 |
+
tokenizer=tokenizer,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# ββ Train βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 211 |
+
logger.info("Starting finetuning...")
|
| 212 |
+
start = datetime.utcnow()
|
| 213 |
+
trainer.train()
|
| 214 |
+
duration = (datetime.utcnow() - start).total_seconds()
|
| 215 |
+
logger.info(f"Training complete in {duration:.0f}s")
|
| 216 |
+
|
| 217 |
+
# ββ Save locally ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
trainer.save_model(str(MODEL_OUTPUT))
|
| 219 |
+
tokenizer.save_pretrained(str(MODEL_OUTPUT))
|
| 220 |
+
logger.info(f"Model saved β {MODEL_OUTPUT}")
|
| 221 |
+
|
| 222 |
+
# ββ Push to HF private repo βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 223 |
+
token = model_module.TOKEN
|
| 224 |
+
private_repo = model_module.PRIVATE_MODEL
|
| 225 |
+
|
| 226 |
+
if token and private_repo:
|
| 227 |
+
logger.info(f"Pushing to HF: {private_repo}...")
|
| 228 |
+
try:
|
| 229 |
+
model.push_to_hub(private_repo, token=token, private=True)
|
| 230 |
+
tokenizer.push_to_hub(private_repo, token=token, private=True)
|
| 231 |
+
model_module.push_model_card({
|
| 232 |
+
"model_id": model_id,
|
| 233 |
+
"samples": len(dataset),
|
| 234 |
+
"epochs": 3,
|
| 235 |
+
"duration_sec": int(duration),
|
| 236 |
+
"finetuned_from": model_id,
|
| 237 |
+
})
|
| 238 |
+
logger.info(f"Model pushed β {private_repo}")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
logger.error(f"Push failed: {type(e).__name__}: {e}")
|
| 241 |
+
else:
|
| 242 |
+
logger.warning("No token or private repo configured β skipping HF push")
|
| 243 |
|
| 244 |
|
| 245 |
# =============================================================================
|
|
|
|
| 262 |
elif args.mode == "validate":
|
| 263 |
validate_adi()
|
| 264 |
elif args.mode == "finetune":
|
| 265 |
+
finetune()
|