Dasuperhub commited on
Commit
cfb45f0
·
verified ·
1 Parent(s): 6191c28

Add A100 training script for v4 retrain

Browse files
Files changed (1) hide show
  1. training/weight-swap-a100.py +464 -0
training/weight-swap-a100.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GUINIUS DA — Soussou Curriculum LoRA Fine-Tune
4
+ ==============================================
5
+ Target: Google Colab A100 (40GB VRAM)
6
+ Dataset: soussou-curriculum-v4-CLEAN.jsonl (10,869 examples, 96.5% GT-verified + native-validated)
7
+ Base Model: Qwen/Qwen3-0.6B
8
+ Method: Single LoRA fine-tune → merge → GGUF export
9
+ Philosophy: Teach the OPERATING SYSTEM, not the dictionary.
10
+
11
+ USAGE (in Colab):
12
+ 1. Upload soussou-curriculum-v2.jsonl
13
+ 2. Run all cells top to bottom
14
+ 3. Download the GGUF
15
+
16
+ A100 Time Estimate: ~15 minutes total
17
+ """
18
+
19
+ # ==============================================================================
20
+ # CELL 1: Configuration
21
+ # ==============================================================================
22
+
23
+ BASE_MODEL = "Qwen/Qwen3-0.6B"
24
+ HF_REPO = "Dasuperhub/DA-MLC"
25
+ DATASET_FILE = "soussou-curriculum-v4-CLEAN.jsonl"
26
+
27
+ # Training hyperparams — tuned for 10.8K validated examples
28
+ EPOCHS = 3 # 3 passes sufficient for 10K+ examples
29
+ LR = 2e-4 # Slightly higher LR with more data
30
+ LORA_R = 64 # Rank 64 — more capacity for 1,106 unique Soussou tokens
31
+ LORA_ALPHA = 32 # Alpha = rank (standard)
32
+ BATCH_SIZE = 4 # Small batches, more gradient updates
33
+ GRAD_ACCUM = 4 # Effective batch = 16
34
+ MAX_SEQ_LEN = 512 # Curriculum examples are short
35
+ WARMUP_STEPS = 10 # Short warmup for small dataset
36
+
37
+ print(f"Config: {EPOCHS} epochs | lr={LR} | LoRA r={LORA_R} | batch={BATCH_SIZE}x{GRAD_ACCUM}")
38
+ print(f"Dataset: {DATASET_FILE}")
39
+ print(f"Base: {BASE_MODEL}")
40
+
41
+
42
+ # ==============================================================================
43
+ # CELL 2: Install Dependencies
44
+ # ==============================================================================
45
+
46
+ import subprocess, sys
47
+
48
+ def install(packages):
49
+ for pkg in packages:
50
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkg.split())
51
+
52
+ install([
53
+ "unsloth",
54
+ "--no-deps trl peft accelerate bitsandbytes",
55
+ "huggingface_hub",
56
+ ])
57
+
58
+ print("Dependencies installed.")
59
+
60
+
61
+ # ==============================================================================
62
+ # CELL 3: Upload Dataset
63
+ # ==============================================================================
64
+
65
+ import os, json
66
+
67
+ if not os.path.exists(DATASET_FILE):
68
+ print(f"{DATASET_FILE} not found. Upload it:")
69
+ print(" A) Drag-and-drop to Colab file browser")
70
+ print(" B) From Google Drive:")
71
+ print(" from google.colab import drive; drive.mount('/content/drive')")
72
+ print(" !cp /content/drive/MyDrive/guinius/soussou-curriculum.jsonl .")
73
+ try:
74
+ from google.colab import files
75
+ uploaded = files.upload()
76
+ except:
77
+ pass
78
+
79
+ assert os.path.exists(DATASET_FILE), f"{DATASET_FILE} not found!"
80
+
81
+ # Count and preview
82
+ line_count = sum(1 for _ in open(DATASET_FILE))
83
+ print(f"\nDataset: {line_count} examples")
84
+
85
+ # Show layer distribution
86
+ layer_counts = {}
87
+ with open(DATASET_FILE) as f:
88
+ for line in f:
89
+ ex = json.loads(line)
90
+ sys_msg = ex["messages"][0]["content"] if ex["messages"] else ""
91
+ if "Grammar Assistant" in sys_msg:
92
+ layer_counts["Grammar"] = layer_counts.get("Grammar", 0) + 1
93
+ elif "Guinius" in sys_msg:
94
+ layer_counts["Identity/Social"] = layer_counts.get("Identity/Social", 0) + 1
95
+ else:
96
+ layer_counts["Other"] = layer_counts.get("Other", 0) + 1
97
+
98
+ print("Distribution:", layer_counts)
99
+
100
+ # Preview first example
101
+ with open(DATASET_FILE) as f:
102
+ first = json.loads(f.readline())
103
+ print(f"\nSample:")
104
+ for msg in first["messages"]:
105
+ print(f" [{msg['role']}] {msg['content'][:100]}")
106
+
107
+
108
+ # ==============================================================================
109
+ # CELL 4: Load Base Model
110
+ # ==============================================================================
111
+
112
+ from unsloth import FastLanguageModel
113
+ import torch
114
+
115
+ print(f"Loading {BASE_MODEL}...")
116
+
117
+ model, tokenizer = FastLanguageModel.from_pretrained(
118
+ model_name=BASE_MODEL,
119
+ max_seq_length=MAX_SEQ_LEN,
120
+ dtype=torch.bfloat16, # A100 native
121
+ load_in_4bit=False, # Full precision — A100 has the VRAM
122
+ )
123
+
124
+ total_params = sum(p.numel() for p in model.parameters())
125
+ print(f"Model loaded: {total_params:,} parameters")
126
+ print(f"Device: {torch.cuda.get_device_name()}")
127
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
128
+
129
+
130
+ # ==============================================================================
131
+ # CELL 5: Baseline Evaluation (BEFORE training)
132
+ # ==============================================================================
133
+
134
+ EVAL_PROMPTS = [
135
+ # Soussou grammar (should learn)
136
+ {"prompt": "How do you say 'I am going' in Soussou?", "expected": "sigafe", "cat": "soussou"},
137
+ {"prompt": "Translate to Soussou: 'We are eating'", "expected": "donsefe", "cat": "soussou"},
138
+ {"prompt": "What are the Soussou pronouns?", "expected": "n", "cat": "soussou"},
139
+ {"prompt": "How do you say 'he came' in Soussou?", "expected": "faxi", "cat": "soussou"},
140
+ {"prompt": "What is the Soussou future tense marker?", "expected": "fama", "cat": "soussou"},
141
+
142
+ # Code-switching (should learn)
143
+ {"prompt": "How would a Guinean say 'I'm going to the market'?", "expected": "marché", "cat": "code-switch"},
144
+ {"prompt": "N na sigafe école ra — what does this mean?", "expected": "school", "cat": "code-switch"},
145
+
146
+ # French retention (should keep)
147
+ {"prompt": "Explique-moi ce qu'est l'intelligence artificielle.", "expected": "artificielle", "cat": "french"},
148
+ {"prompt": "Bonjour, comment vas-tu?", "expected": "bien", "cat": "french"},
149
+
150
+ # English retention (should keep)
151
+ {"prompt": "What is machine learning?", "expected": "data", "cat": "english"},
152
+ {"prompt": "Explain what a neural network does.", "expected": "network", "cat": "english"},
153
+
154
+ # Identity
155
+ {"prompt": "I khili mun di?", "expected": "Guinius", "cat": "identity"},
156
+
157
+ # Language mirroring
158
+ {"prompt": "Apprends-moi le soussou!", "expected": "Soussou", "cat": "mirror"},
159
+ {"prompt": "Teach me Soussou!", "expected": "Soussou", "cat": "mirror"},
160
+ ]
161
+
162
+ def evaluate(model, tokenizer, label=""):
163
+ """Run evaluation prompts and score."""
164
+ FastLanguageModel.for_inference(model)
165
+ import re
166
+
167
+ SYSTEM = "I khili Guinius, DA AI. N kelixi Soussou, Français, English."
168
+
169
+ results = {"total": 0, "hits": 0, "by_cat": {}}
170
+
171
+ print(f"\n{'='*60}")
172
+ print(f" EVALUATION: {label}")
173
+ print(f"{'='*60}")
174
+
175
+ for ep in EVAL_PROMPTS:
176
+ messages = [
177
+ {"role": "system", "content": SYSTEM},
178
+ {"role": "user", "content": ep["prompt"]},
179
+ ]
180
+ inputs = tokenizer.apply_chat_template(
181
+ messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
182
+ ).to("cuda")
183
+
184
+ with torch.no_grad():
185
+ outputs = model.generate(
186
+ input_ids=inputs,
187
+ max_new_tokens=150,
188
+ temperature=0.6,
189
+ top_p=0.9,
190
+ do_sample=True,
191
+ )
192
+
193
+ response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
194
+ response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
195
+
196
+ hit = ep["expected"].lower() in response.lower()
197
+ cat = ep["cat"]
198
+
199
+ results["total"] += 1
200
+ results["hits"] += int(hit)
201
+ if cat not in results["by_cat"]:
202
+ results["by_cat"][cat] = {"hits": 0, "total": 0}
203
+ results["by_cat"][cat]["total"] += 1
204
+ results["by_cat"][cat]["hits"] += int(hit)
205
+
206
+ status = "PASS" if hit else "FAIL"
207
+ print(f" [{status}] {ep['prompt']}")
208
+ print(f" -> {response[:200]}")
209
+
210
+ # Summary
211
+ print(f"\n SCORE: {results['hits']}/{results['total']} = {results['hits']/max(results['total'],1)*100:.0f}%")
212
+ for cat, s in results["by_cat"].items():
213
+ print(f" {cat:15s}: {s['hits']}/{s['total']}")
214
+
215
+ return results
216
+
217
+ baseline = evaluate(model, tokenizer, "BASELINE (before training)")
218
+
219
+
220
+ # ==============================================================================
221
+ # CELL 6: Apply LoRA
222
+ # ==============================================================================
223
+
224
+ model = FastLanguageModel.get_peft_model(
225
+ model,
226
+ r=LORA_R,
227
+ target_modules=[
228
+ "q_proj", "k_proj", "v_proj", "o_proj", # Attention
229
+ "gate_proj", "up_proj", "down_proj", # MLP
230
+ ],
231
+ lora_alpha=LORA_ALPHA,
232
+ lora_dropout=0,
233
+ bias="none",
234
+ use_gradient_checkpointing="unsloth",
235
+ random_state=42,
236
+ )
237
+
238
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
239
+ total = sum(p.numel() for p in model.parameters())
240
+ print(f"LoRA applied: {trainable:,} trainable / {total:,} total = {trainable/total*100:.2f}%")
241
+
242
+
243
+ # ==============================================================================
244
+ # CELL 7: Prepare Dataset
245
+ # ==============================================================================
246
+
247
+ from datasets import load_dataset
248
+
249
+ dataset = load_dataset("json", data_files=DATASET_FILE, split="train")
250
+ print(f"Loaded: {len(dataset)} examples")
251
+
252
+ def format_chatml(example):
253
+ """Format messages into ChatML text for SFTTrainer."""
254
+ text = ""
255
+ for msg in example["messages"]:
256
+ text += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
257
+ text += "<|im_start|>assistant\n"
258
+ return {"text": text}
259
+
260
+ dataset = dataset.map(format_chatml, num_proc=2)
261
+
262
+ # Token length distribution
263
+ lengths = []
264
+ for ex in dataset:
265
+ toks = tokenizer(ex["text"], return_length=True)
266
+ lengths.append(toks["length"][0])
267
+ print(f"Token lengths: min={min(lengths)}, median={sorted(lengths)[len(lengths)//2]}, max={max(lengths)}")
268
+ print(f"All fit in {MAX_SEQ_LEN}? {'YES' if max(lengths) <= MAX_SEQ_LEN else 'NO — increase MAX_SEQ_LEN!'}")
269
+
270
+
271
+ # ==============================================================================
272
+ # CELL 8: Train
273
+ # ==============================================================================
274
+
275
+ from trl import SFTTrainer
276
+ from transformers import TrainingArguments
277
+
278
+ trainer = SFTTrainer(
279
+ model=model,
280
+ tokenizer=tokenizer,
281
+ train_dataset=dataset,
282
+ dataset_text_field="text",
283
+ max_seq_length=MAX_SEQ_LEN,
284
+ dataset_num_proc=2,
285
+ packing=False,
286
+ args=TrainingArguments(
287
+ per_device_train_batch_size=BATCH_SIZE,
288
+ gradient_accumulation_steps=GRAD_ACCUM,
289
+ warmup_steps=WARMUP_STEPS,
290
+ num_train_epochs=EPOCHS,
291
+ learning_rate=LR,
292
+ bf16=True,
293
+ logging_steps=10,
294
+ optim="adamw_8bit",
295
+ weight_decay=0.01,
296
+ lr_scheduler_type="cosine",
297
+ seed=42,
298
+ output_dir="outputs",
299
+ report_to="none",
300
+ ),
301
+ )
302
+
303
+ total_steps = len(dataset) // (BATCH_SIZE * GRAD_ACCUM) * EPOCHS
304
+ print(f"\nStarting training...")
305
+ print(f" {len(dataset)} examples x {EPOCHS} epochs = {len(dataset)*EPOCHS} passes")
306
+ print(f" ~{total_steps} optimization steps")
307
+ print(f" Estimated time: ~5-15 min on A100")
308
+
309
+ stats = trainer.train()
310
+
311
+ print(f"\nTraining complete!")
312
+ print(f" Final loss: {stats.training_loss:.4f}")
313
+ print(f" Runtime: {stats.metrics['train_runtime']:.0f}s")
314
+ print(f" Samples/sec: {stats.metrics['train_samples_per_second']:.1f}")
315
+
316
+
317
+ # ==============================================================================
318
+ # CELL 9: Post-Training Evaluation
319
+ # ==============================================================================
320
+
321
+ post_train = evaluate(model, tokenizer, "AFTER TRAINING")
322
+
323
+ # Compare
324
+ print(f"\n{'='*60}")
325
+ print(f" BEFORE vs AFTER")
326
+ print(f"{'='*60}")
327
+ print(f" Baseline: {baseline['hits']}/{baseline['total']}")
328
+ print(f" Trained: {post_train['hits']}/{post_train['total']}")
329
+ for cat in baseline["by_cat"]:
330
+ b = baseline["by_cat"][cat]
331
+ a = post_train["by_cat"].get(cat, {"hits": 0, "total": 0})
332
+ delta = a["hits"] - b["hits"]
333
+ arrow = "+" if delta > 0 else ("=" if delta == 0 else "")
334
+ print(f" {cat:15s}: {b['hits']}/{b['total']} -> {a['hits']}/{a['total']} {arrow}{delta if delta != 0 else ''}")
335
+
336
+
337
+ # ==============================================================================
338
+ # CELL 10: Merge LoRA into Base Model
339
+ # ==============================================================================
340
+
341
+ print("Merging LoRA into base weights...")
342
+
343
+ # Save LoRA adapter first
344
+ LORA_DIR = "guinius-lora"
345
+ model.save_pretrained(LORA_DIR)
346
+ tokenizer.save_pretrained(LORA_DIR)
347
+ print(f"LoRA adapter saved: {LORA_DIR}/")
348
+
349
+ # Free GPU memory
350
+ del model, trainer
351
+ torch.cuda.empty_cache()
352
+
353
+ # Merge on CPU
354
+ from transformers import AutoModelForCausalLM, AutoTokenizer
355
+ from peft import PeftModel
356
+
357
+ print("Loading base model on CPU...")
358
+ base_model = AutoModelForCausalLM.from_pretrained(
359
+ BASE_MODEL,
360
+ torch_dtype=torch.float16,
361
+ device_map="cpu",
362
+ )
363
+ base_tok = AutoTokenizer.from_pretrained(BASE_MODEL)
364
+
365
+ print("Applying LoRA adapter...")
366
+ model_with_lora = PeftModel.from_pretrained(base_model, LORA_DIR)
367
+
368
+ print("Merging weights...")
369
+ merged = model_with_lora.merge_and_unload()
370
+
371
+ MERGED_DIR = "guinius-merged"
372
+ merged.save_pretrained(MERGED_DIR)
373
+ base_tok.save_pretrained(MERGED_DIR)
374
+ print(f"Merged model saved: {MERGED_DIR}/")
375
+
376
+ del base_model, model_with_lora, merged
377
+ torch.cuda.empty_cache()
378
+
379
+
380
+ # ==============================================================================
381
+ # CELL 11: Install MLC-LLM (for WebLLM-ready output)
382
+ # ==============================================================================
383
+
384
+ print("Installing MLC-LLM for direct WebLLM export...")
385
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
386
+ "--pre", "-f", "https://mlc.ai/wheels",
387
+ "mlc-ai-nightly-cu124", "mlc-llm-nightly-cu124"])
388
+ print("MLC-LLM installed.")
389
+
390
+
391
+ # ==============================================================================
392
+ # CELL 12: Convert to MLC (WebLLM-ready)
393
+ # ==============================================================================
394
+
395
+ MLC_DIR = "DA-MLC"
396
+
397
+ print(f"Converting merged model → MLC format...")
398
+ print(f" Input: {MERGED_DIR}/")
399
+ print(f" Output: {MLC_DIR}/")
400
+
401
+ # Step 1: Convert weights to q4f16_1 quantization
402
+ subprocess.run([
403
+ sys.executable, "-m", "mlc_llm", "convert_weight", MERGED_DIR,
404
+ "--quantization", "q4f16_1",
405
+ "--output", MLC_DIR,
406
+ ], check=True)
407
+ print("Weights converted.")
408
+
409
+ # Step 2: Generate MLC config
410
+ subprocess.run([
411
+ sys.executable, "-m", "mlc_llm", "gen_config", MLC_DIR,
412
+ "--quantization", "q4f16_1",
413
+ "--conv-template", "chatml",
414
+ "--context-window-size", "2048",
415
+ "--output", MLC_DIR,
416
+ ], check=True)
417
+ print("Config generated.")
418
+
419
+ # Show output
420
+ total_size = 0
421
+ for f in os.listdir(MLC_DIR):
422
+ fpath = os.path.join(MLC_DIR, f)
423
+ if os.path.isfile(fpath):
424
+ size_mb = os.path.getsize(fpath) / 1e6
425
+ total_size += size_mb
426
+ print(f" {f}: {size_mb:.1f} MB")
427
+ print(f" TOTAL: {total_size:.0f} MB")
428
+
429
+ print(f"\nMLC conversion complete! WebLLM can load this directly.")
430
+
431
+
432
+ # ==============================================================================
433
+ # CELL 13: Upload to HuggingFace → WebLLM loads it
434
+ # ==============================================================================
435
+
436
+ from huggingface_hub import HfApi, login
437
+
438
+ # Login — paste your HF token when prompted
439
+ token = os.environ.get("HF_TOKEN")
440
+ if token:
441
+ login(token=token)
442
+ else:
443
+ print("Paste your HuggingFace token:")
444
+ login()
445
+
446
+ api = HfApi()
447
+
448
+ print(f"\nUploading MLC model to {HF_REPO}...")
449
+ api.upload_folder(
450
+ folder_path=MLC_DIR,
451
+ repo_id=HF_REPO,
452
+ commit_message="Guinius DA v4 — Soussou curriculum (10,869 GT-verified + native-validated examples)",
453
+ delete_patterns=["*.bin", "*.safetensors", "*.gguf"], # Clean old files
454
+ )
455
+
456
+ print(f"\n{'='*60}")
457
+ print(f" DONE — WebLLM READY")
458
+ print(f"{'='*60}")
459
+ print(f" Model: Qwen3-0.6B + Soussou curriculum v4 (10,869 examples)")
460
+ print(f" HuggingFace: https://huggingface.co/{HF_REPO}")
461
+ print(f" WebLLM WASM: Qwen3-0.6B (same architecture, reuse existing)")
462
+ print()
463
+ print(f" Open guinius.dasuperhub.com — it loads from HuggingFace automatically.")
464
+ print(f" No GGUF. No local conversion. Direct to browser.")