"""
Distill Gemini Flash summaries into Qwen3-0.6B.

Fine-tunes Qwen3-0.6B with LoRA to generate one-sentence summaries from
raw markdown text, distilling from 6,720 high-quality Gemini-generated
summaries. At inference time, feed any markdown text and get a summary
back. Runs on CPU for inference (~1-2 s per summary); see the
load_merged_model() sketch at the end of this file.

Input:  raw embedded_text (markdown)
Output: one-sentence summary (Gemini-quality, Qwen-speed)

Usage:
    python3 gpu_distill.py --data-dir /workspace/data --output-dir /workspace/output
"""
import argparse
import contextlib
import datetime
import json
import math
import os
import sys
import time

# Line-buffered stdout/stderr so logs stream through pipes and tee in real time.
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)


def log(msg, level="INFO"):
    ts = datetime.datetime.now().strftime("%H:%M:%S")
    print(f"[{ts}] [{level}] {msg}", flush=True)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default="/workspace/data")
    parser.add_argument("--output-dir", default="/workspace/output")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--max-input-len", type=int, default=384, help="Max input tokens")
    parser.add_argument("--max-output-len", type=int, default=64, help="Max output tokens")
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--sample-every", type=int, default=2)
    args = parser.parse_args()
| log("=" * 60) |
| log("DISTILLATION: Markdown β Summary (LoRA fine-tune)") |
| log("=" * 60) |
| log(f"Config: epochs={args.epochs} batch={args.batch_size} lr={args.lr} " |
| f"lora_rank={args.lora_rank} input_len={args.max_input_len} output_len={args.max_output_len}") |
|
|
| |
| import subprocess as _sp |
| for pkg in ["numpy", "transformers", "accelerate", "safetensors"]: |
| try: |
| __import__(pkg) |
| except ImportError: |
| log(f"Installing {pkg}...") |
| _sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", |
| "-q", pkg], check=True) |
|
|

    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    from transformers import AutoTokenizer, AutoModelForCausalLM

    log(f"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        props = torch.cuda.get_device_properties(0)
        log(f"GPU: {torch.cuda.get_device_name()} | VRAM: {props.total_memory / 1024**3:.1f} GB")

    os.makedirs(args.output_dir, exist_ok=True)

    def vram_mb():
        return torch.cuda.memory_allocated() / 1024**2 if device.type == "cuda" else 0

    metrics = {
        "config": vars(args), "device": str(device),
        "gpu": torch.cuda.get_device_name() if device.type == "cuda" else "cpu",
        "method": "distillation", "steps": [], "epochs": [], "samples": [],
        "start_time": time.time(),
    }

    # Load (text, summary) pairs.
    log("Loading data...")
    t0 = time.time()

    # Expected schema (inferred from the keys used below):
    #   texts.json:     [{"id": ..., "text": "<markdown>"}, ...]
    #   summaries.json: [{"id": ..., "summary": "<one sentence>"}, ...]
    with open(os.path.join(args.data_dir, "texts.json")) as f:
        text_data = json.load(f)
    with open(os.path.join(args.data_dir, "summaries.json")) as f:
        sum_data = json.load(f)

    # Join on id; drop empty or trivially short texts.
    sum_map = {s["id"]: s["summary"] for s in sum_data}
    pairs = [(t["text"], sum_map[t["id"]]) for t in text_data
             if t["id"] in sum_map and t["text"] and len(t["text"].strip()) > 20]
    log(f"Loaded {len(pairs)} (text, summary) pairs in {time.time()-t0:.1f}s")

    # Character-length statistics, mainly to sanity-check the truncation limits.
    text_lens = [len(t) for t, _ in pairs]
    sum_lens = [len(s) for _, s in pairs]
    log(f"Text lengths: mean={np.mean(text_lens):.0f} median={np.median(text_lens):.0f} "
        f"max={max(text_lens)} chars")
    log(f"Summary lengths: mean={np.mean(sum_lens):.0f} median={np.median(sum_lens):.0f} "
        f"max={max(sum_lens)} chars")

    # Load tokenizer and the frozen fp16 base model.
    log(f"Loading {args.model_name}...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=torch.float16, trust_remote_code=True,
    ).to(device)

    # Freeze every base parameter; only the LoRA adapters below are trained.
    for param in model.parameters():
        param.requires_grad = False

    hidden_dim = model.config.hidden_size
    log(f"Model loaded in {time.time()-t0:.1f}s: hidden={hidden_dim} | VRAM: {vram_mb():.0f}MB")

    # Minimal hand-rolled LoRA: wrap a frozen linear layer with a low-rank
    # residual path (fp32 adapters around the fp16 base weights).
    class LoRALayer(nn.Module):
        def __init__(self, original_layer, rank, alpha):
            super().__init__()
            self.original = original_layer
            in_f, out_f = original_layer.in_features, original_layer.out_features
            self.lora_A = nn.Linear(in_f, rank, bias=False)
            self.lora_B = nn.Linear(rank, out_f, bias=False)
            self.scaling = alpha / rank
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

        def forward(self, x):
            orig_out = self.original(x)
            lora_out = self.lora_B(self.lora_A(x.to(self.lora_A.weight.dtype)))
            return orig_out + lora_out.to(orig_out.dtype) * self.scaling
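
    # Only lora_A / lora_B receive gradients: y = W x + (alpha/rank) * B(A x).
    # Because B is zero-initialized, step 0 reproduces the frozen base model exactly.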

    # Swap q_proj / v_proj in every attention block for LoRA-wrapped versions.
    lora_modules = []
    n_adapted = 0
    for name, module in model.named_modules():
        if hasattr(module, 'q_proj') and isinstance(module.q_proj, nn.Linear):
            lora_q = LoRALayer(module.q_proj, args.lora_rank, args.lora_alpha).to(device)
            module.q_proj = lora_q
            lora_modules.append(lora_q)
            n_adapted += 1
        if hasattr(module, 'v_proj') and isinstance(module.v_proj, nn.Linear):
            lora_v = LoRALayer(module.v_proj, args.lora_rank, args.lora_alpha).to(device)
            module.v_proj = lora_v
            lora_modules.append(lora_v)
            n_adapted += 1

    lora_params = []
    for lm in lora_modules:
        lora_params.extend(lm.lora_A.parameters())
        lora_params.extend(lm.lora_B.parameters())

    lora_total = sum(p.numel() for p in lora_params)
    log(f"LoRA applied to {n_adapted} layers | {lora_total:,} trainable params | VRAM: {vram_mb():.0f}MB")

    PROMPT_TEMPLATE = "Summarize in one sentence:\n{text}\n\nSummary:"
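
    # The same template is reused verbatim for sampling below; every training
    # target also ends with EOS, which teaches the model to stop after one sentence.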

    class DistillDataset(Dataset):
        def __init__(self, pairs, tokenizer, max_input, max_output):
            self.items = []
            for text, summary in pairs:
                # Character-level pre-truncation keeps tokenization fast.
                prompt = PROMPT_TEMPLATE.format(text=text[:2000])

                prompt_enc = tokenizer(prompt, truncation=True, max_length=max_input,
                                       return_tensors="pt")
                summary_enc = tokenizer(summary, truncation=True, max_length=max_output,
                                        return_tensors="pt")

                # Full sequence: prompt + summary + EOS.
                input_ids = torch.cat([
                    prompt_enc["input_ids"].squeeze(0),
                    summary_enc["input_ids"].squeeze(0),
                    torch.tensor([tokenizer.eos_token_id]),
                ])

                # Loss is computed on the summary (and EOS) only; prompt tokens
                # are masked out with -100.
                n_prompt = prompt_enc["input_ids"].shape[1]
                labels = input_ids.clone()
                labels[:n_prompt] = -100

                # Hard cap on total sequence length.
                max_total = max_input + max_output
                if len(input_ids) > max_total:
                    input_ids = input_ids[:max_total]
                    labels = labels[:max_total]

                self.items.append((input_ids, labels))

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    def collate_fn(batch):
        input_ids_list, labels_list = zip(*batch)
        max_len = max(ids.shape[0] for ids in input_ids_list)

        input_ids = torch.full((len(batch), max_len), tokenizer.pad_token_id, dtype=torch.long)
        labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
        attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)

        for i, (ids, lab) in enumerate(zip(input_ids_list, labels_list)):
            # Left-pad each sequence to the batch maximum.
            offset = max_len - ids.shape[0]
            input_ids[i, offset:] = ids
            labels[i, offset:] = lab
            attention_mask[i, offset:] = 1

        return input_ids, labels, attention_mask
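
    # Left-padding (matching tokenizer.padding_side = "left" above) keeps each
    # sequence's final tokens right-aligned, the layout generate() also expects.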

    # 90/10 train/val split with a fixed seed for reproducibility.
    n_val = max(int(len(pairs) * 0.1), 1)
    rng = np.random.RandomState(42)
    indices = rng.permutation(len(pairs))
    val_pairs = [pairs[i] for i in indices[:n_val]]
    train_pairs = [pairs[i] for i in indices[n_val:]]

    train_ds = DistillDataset(train_pairs, tokenizer, args.max_input_len, args.max_output_len)
    val_ds = DistillDataset(val_pairs, tokenizer, args.max_input_len, args.max_output_len)
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                          drop_last=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    steps_per_epoch = len(train_dl)
    total_steps = steps_per_epoch * args.epochs
    log(f"Data: train={len(train_ds)} val={len(val_ds)} | {steps_per_epoch} steps/epoch, "
        f"{total_steps} total")

    # Optimizer over LoRA params only, cosine LR decay, AMP grad scaler on GPU.
    optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
    scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
    best_val_loss = float("inf")
    global_step = 0

    log("")
    log("=" * 60)
    log("TRAINING START")
    log("=" * 60)
    train_start = time.time()

    for epoch in range(args.epochs):
        model.train()
        epoch_loss, epoch_tokens = 0.0, 0
        epoch_start = time.time()
        log("")
        log(f"── Epoch {epoch+1}/{args.epochs} ──")

        for step, (input_ids, labels, attn_mask) in enumerate(train_dl):
            step_start = time.time()
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            attn_mask = attn_mask.to(device)

            optimizer.zero_grad()

            if scaler:
                with torch.amp.autocast("cuda"):
                    outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
                    loss = outputs.loss
                if torch.isnan(loss):
                    log(f"NaN at step {step+1}!", "ERROR")
                    break
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, 1.0).item()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
                loss = outputs.loss
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, 1.0).item()
                optimizer.step()

            scheduler.step()

            # Token-weighted running loss (outputs.loss is already a per-token mean).
            n_tokens = (labels != -100).sum().item()
            step_time = time.time() - step_start
            tps = n_tokens / step_time if step_time > 0 else 0
            epoch_loss += loss.item() * n_tokens
            epoch_tokens += n_tokens
            global_step += 1

            metrics["steps"].append({
                "epoch": epoch+1, "step": step+1, "global_step": global_step,
                "loss": round(loss.item(), 4), "lr": scheduler.get_last_lr()[0],
                "grad_norm": round(grad_norm, 4), "vram_mb": round(vram_mb()),
                "tokens_per_sec": round(tps),
            })

            if step % args.log_every == 0:
                elapsed = time.time() - train_start
                eta = elapsed / global_step * (total_steps - global_step) if global_step > 0 else 0
                log(f"  step {step+1:>3}/{steps_per_epoch} | loss={loss.item():.4f} | "
                    f"lr={scheduler.get_last_lr()[0]:.1e} | grad={grad_norm:.3f} | "
                    f"VRAM={vram_mb():.0f}MB | {tps:.0f} tok/s | ETA={eta/60:.0f}m")

            # Also catches NaN from the non-AMP path above.
            if torch.isnan(loss):
                break

        avg_train = epoch_loss / max(epoch_tokens, 1)

        # Validation: same token-weighted loss over the held-out split.
        log("  Validating...")
        model.eval()
        val_loss, val_tokens = 0.0, 0
        with torch.no_grad():
            for input_ids, labels, attn_mask in val_dl:
                input_ids, labels, attn_mask = input_ids.to(device), labels.to(device), attn_mask.to(device)
                with torch.amp.autocast("cuda") if device.type == "cuda" else contextlib.nullcontext():
                    outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
                n = (labels != -100).sum().item()
                val_loss += outputs.loss.item() * n
                val_tokens += n

        avg_val = val_loss / max(val_tokens, 1)
        epoch_time = time.time() - epoch_start
        is_best = avg_val < best_val_loss

        metrics["epochs"].append({
            "epoch": epoch+1, "train_loss": round(avg_train, 4),
            "val_loss": round(avg_val, 4), "time_s": round(epoch_time, 1), "best": is_best,
        })

        marker = " ✓ NEW BEST" if is_best else ""
        log(f"  Epoch {epoch+1}/{args.epochs} DONE | train={avg_train:.4f} val={avg_val:.4f} | "
            f"{epoch_time:.0f}s{marker}")

        if device.type == "cuda":
            torch.cuda.empty_cache()

        # Save only the LoRA tensors (a few MB) rather than the full model.
        if is_best:
            best_val_loss = avg_val
            lora_state = {}
            for name, module in model.named_modules():
                if isinstance(module, LoRALayer):
                    lora_state[name + ".lora_A"] = module.lora_A.state_dict()
                    lora_state[name + ".lora_B"] = module.lora_B.state_dict()
            torch.save({
                "epoch": epoch, "val_loss": avg_val,
                "lora_state": lora_state,
                "config": vars(args),
            }, os.path.join(args.output_dir, "best_distill.pt"))

        # Periodic qualitative check: greedy-decode a few validation samples.
        if (epoch + 1) % args.sample_every == 0 or epoch == args.epochs - 1 or is_best:
            try:
                log("  Generating samples...")
                model.eval()
                sample_rng = np.random.RandomState(epoch)
                sample_idx = sample_rng.choice(len(val_pairs), size=min(3, len(val_pairs)), replace=False)

                for si in sample_idx:
                    text, ref = val_pairs[si]
                    prompt = PROMPT_TEMPLATE.format(text=text[:1500])
                    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                                       max_length=args.max_input_len).to(device)

                    with torch.no_grad():
                        gen = model.generate(
                            **inputs, max_new_tokens=args.max_output_len,
                            do_sample=False,  # greedy decoding
                            pad_token_id=tokenizer.pad_token_id,
                        )
                    gen_text = tokenizer.decode(gen[0][inputs["input_ids"].shape[1]:],
                                                skip_special_tokens=True)

                    del gen
                    if device.type == "cuda":
                        torch.cuda.empty_cache()

                    metrics["samples"].append({"epoch": epoch+1, "ref": ref[:200], "gen": gen_text[:200]})
                    log(f"    REF: {ref[:100]}")
                    log(f"    GEN: {gen_text[:100]}")
                    log("")
            except Exception as e:
                log(f"  Sample generation failed: {e}", "WARN")

        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Final artifacts: metrics JSON next to the best checkpoint.
    total_time = time.time() - train_start
    metrics["total_time_s"] = round(total_time, 1)
    metrics["best_val_loss"] = round(best_val_loss, 4)

    with open(os.path.join(args.output_dir, "training_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    log("")
    log("=" * 60)
    log("TRAINING COMPLETE")
    log("=" * 60)
    log(f"Total time: {total_time/60:.1f} minutes")
    log(f"Best val loss: {best_val_loss:.4f}")
    log("")
    log("Epoch | Train Loss | Val Loss | Time | Best")
    log("-" * 50)
    for e in metrics["epochs"]:
        m = " ✓" if e["best"] else ""
        log(f"  {e['epoch']:>3} | {e['train_loss']:.4f} | {e['val_loss']:.4f} | {e['time_s']:.0f}s{m}")


if __name__ == "__main__":
    main()