#!/usr/bin/env python3
"""
Distill Gemini Flash summaries into Qwen3-0.6B.
Fine-tunes Qwen3-0.6B with LoRA to generate one-sentence summaries from
raw markdown text, distilling from 6,720 high-quality Gemini-generated
summaries. At inference time, feed any markdown text and get a summary
back. Runs on CPU for inference (~1-2s per summary).
Input: raw embedded_text (markdown)
Output: one-sentence summary (Gemini-quality, Qwen-speed)
Usage:
python3 gpu_distill.py --data-dir /workspace/data --output-dir /workspace/output
"""
import json
import os
import sys
import time
import datetime
import argparse
import math
import contextlib
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def log(msg, level="INFO"):
ts = datetime.datetime.now().strftime("%H:%M:%S")
print(f"[{ts}] [{level}] {msg}", flush=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", default="/workspace/data")
parser.add_argument("--output-dir", default="/workspace/output")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--lr", type=float, default=2e-4)
parser.add_argument("--lora-rank", type=int, default=16)
parser.add_argument("--lora-alpha", type=int, default=32)
parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
parser.add_argument("--max-input-len", type=int, default=384, help="Max input tokens")
parser.add_argument("--max-output-len", type=int, default=64, help="Max output tokens")
parser.add_argument("--log-every", type=int, default=10)
parser.add_argument("--sample-every", type=int, default=2)
args = parser.parse_args()
log("=" * 60)
log("DISTILLATION: Markdown β†’ Summary (LoRA fine-tune)")
log("=" * 60)
log(f"Config: epochs={args.epochs} batch={args.batch_size} lr={args.lr} "
f"lora_rank={args.lora_rank} input_len={args.max_input_len} output_len={args.max_output_len}")
# Auto-install missing deps (don't touch torch; use the image's version)
import subprocess as _sp
for pkg in ["numpy", "transformers", "accelerate", "safetensors"]:
try:
__import__(pkg)
except ImportError:
log(f"Installing {pkg}...")
_sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages",
"-q", pkg], check=True)
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
log(f"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
props = torch.cuda.get_device_properties(0)
log(f"GPU: {torch.cuda.get_device_name()} | VRAM: {props.total_memory / 1024**3:.1f} GB")
os.makedirs(args.output_dir, exist_ok=True)
def vram_mb():
return torch.cuda.memory_allocated() / 1024**2 if device.type == "cuda" else 0
metrics = {
"config": vars(args), "device": str(device),
"gpu": torch.cuda.get_device_name() if device.type == "cuda" else "cpu",
"method": "distillation", "steps": [], "epochs": [], "samples": [],
"start_time": time.time(),
}
# ── Load data ──────────────────────────────────────────────────────
log("Loading data...")
t0 = time.time()
# Load texts (embedded_text from clouderic.db) and summaries
with open(os.path.join(args.data_dir, "texts.json")) as f:
text_data = json.load(f) # [{"id": str, "text": str}]
with open(os.path.join(args.data_dir, "summaries.json")) as f:
sum_data = json.load(f) # [{"id": str, "summary": str}]
sum_map = {s["id"]: s["summary"] for s in sum_data}
pairs = [(t["text"], sum_map[t["id"]]) for t in text_data
if t["id"] in sum_map and t["text"] and len(t["text"].strip()) > 20]
log(f"Loaded {len(pairs)} (text, summary) pairs in {time.time()-t0:.1f}s")
# Stats
text_lens = [len(t) for t, _ in pairs]
sum_lens = [len(s) for _, s in pairs]
log(f"Text lengths: mean={np.mean(text_lens):.0f} median={np.median(text_lens):.0f} "
f"max={max(text_lens)} chars")
log(f"Summary lengths: mean={np.mean(sum_lens):.0f} median={np.median(sum_lens):.0f} "
f"max={max(sum_lens)} chars")
# ── Load model ─────────────────────────────────────────────────────
log(f"Loading {args.model_name}...")
t0 = time.time()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left" # for decoder-only models
model = AutoModelForCausalLM.from_pretrained(
args.model_name, torch_dtype=torch.float16, trust_remote_code=True,
).to(device)
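# Freeze every base-model parameter; only the LoRA adapters added below are trained.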
for param in model.parameters():
param.requires_grad = False
hidden_dim = model.config.hidden_size
log(f"Model loaded in {time.time()-t0:.1f}s: hidden={hidden_dim} | VRAM: {vram_mb():.0f}MB")
# ── LoRA ───────────────────────────────────────────────────────────
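# Each adapted projection computes y = W0·x + (alpha/rank)·B·A·x, where W0 is the
# frozen pretrained weight and only the low-rank A (rank×in) and B (out×rank)
# matrices are trained. B is initialized to zero, so the adapter contributes
# nothing at step 0 and the model starts from the pretrained behavior.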
class LoRALayer(nn.Module):
def __init__(self, original_layer, rank, alpha):
super().__init__()
self.original = original_layer
in_f, out_f = original_layer.in_features, original_layer.out_features
self.lora_A = nn.Linear(in_f, rank, bias=False)
self.lora_B = nn.Linear(rank, out_f, bias=False)
self.scaling = alpha / rank
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x):
orig_out = self.original(x)
lora_out = self.lora_B(self.lora_A(x.to(self.lora_A.weight.dtype)))
return orig_out + lora_out.to(orig_out.dtype) * self.scaling
lora_modules = []
n_adapted = 0
for name, module in model.named_modules():
if hasattr(module, 'q_proj') and isinstance(module.q_proj, nn.Linear):
lora_q = LoRALayer(module.q_proj, args.lora_rank, args.lora_alpha).to(device)
module.q_proj = lora_q
lora_modules.append(lora_q)
n_adapted += 1
if hasattr(module, 'v_proj') and isinstance(module.v_proj, nn.Linear):
lora_v = LoRALayer(module.v_proj, args.lora_rank, args.lora_alpha).to(device)
module.v_proj = lora_v
lora_modules.append(lora_v)
n_adapted += 1
lora_params = []
for lm in lora_modules:
lora_params.extend(lm.lora_A.parameters())
lora_params.extend(lm.lora_B.parameters())
lora_total = sum(p.numel() for p in lora_params)
log(f"LoRA applied to {n_adapted} layers | {lora_total:,} trainable params | VRAM: {vram_mb():.0f}MB")
# ── Dataset ────────────────────────────────────────────────────────
PROMPT_TEMPLATE = "Summarize in one sentence:\n{text}\n\nSummary:"
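# Each training example is the prompt followed by the target summary plus EOS.
# Labels mask the prompt region with -100 so cross-entropy is computed only on
# the summary tokens (standard causal-LM distillation on the teacher's outputs).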
class DistillDataset(Dataset):
def __init__(self, pairs, tokenizer, max_input, max_output):
self.items = []
for text, summary in pairs:
# Truncate text to fit
prompt = PROMPT_TEMPLATE.format(text=text[:2000])
# Tokenize prompt and summary separately
prompt_enc = tokenizer(prompt, truncation=True, max_length=max_input,
return_tensors="pt")
summary_enc = tokenizer(summary, truncation=True, max_length=max_output,
return_tensors="pt")
# Concatenate: [prompt_tokens] [summary_tokens] [eos]
input_ids = torch.cat([
prompt_enc["input_ids"].squeeze(0),
summary_enc["input_ids"].squeeze(0),
torch.tensor([tokenizer.eos_token_id]),
])
# Labels: -100 for prompt, actual ids for summary+eos
n_prompt = prompt_enc["input_ids"].shape[1]
labels = input_ids.clone()
labels[:n_prompt] = -100
# Truncate total to max_input + max_output
max_total = max_input + max_output
if len(input_ids) > max_total:
input_ids = input_ids[:max_total]
labels = labels[:max_total]
self.items.append((input_ids, labels))
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.items[idx]
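# Batches are left-padded to the longest sequence in the batch: pad token ids,
# -100 labels, and a zeroed attention mask fill the front so real tokens stay
# right-aligned, the usual convention for decoder-only models.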
def collate_fn(batch):
input_ids_list, labels_list = zip(*batch)
max_len = max(ids.shape[0] for ids in input_ids_list)
input_ids = torch.full((len(batch), max_len), tokenizer.pad_token_id, dtype=torch.long)
labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
for i, (ids, lab) in enumerate(zip(input_ids_list, labels_list)):
# Right-align (pad on left for decoder-only)
offset = max_len - ids.shape[0]
input_ids[i, offset:] = ids
labels[i, offset:] = lab
attention_mask[i, offset:] = 1
return input_ids, labels, attention_mask
# Split
n_val = max(int(len(pairs) * 0.1), 1)
rng = np.random.RandomState(42)
indices = rng.permutation(len(pairs))
val_pairs = [pairs[i] for i in indices[:n_val]]
train_pairs = [pairs[i] for i in indices[n_val:]]
train_ds = DistillDataset(train_pairs, tokenizer, args.max_input_len, args.max_output_len)
val_ds = DistillDataset(val_pairs, tokenizer, args.max_input_len, args.max_output_len)
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
drop_last=True, collate_fn=collate_fn)
val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
steps_per_epoch = len(train_dl)
total_steps = steps_per_epoch * args.epochs
log(f"Data: train={len(train_ds)} val={len(val_ds)} | {steps_per_epoch} steps/epoch, "
f"{total_steps} total")
# ── Training ───────────────────────────────────────────────────────
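# Only the LoRA A/B parameters are optimized; the base model stays frozen.
# The cosine schedule decays the learning rate from args.lr to eta_min over
# the full run of total_steps.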
optimizer = torch.optim.AdamW(lora_params, lr=args.lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
best_val_loss = float("inf")
global_step = 0
log("")
log("=" * 60)
log("TRAINING START")
log("=" * 60)
train_start = time.time()
for epoch in range(args.epochs):
model.train()
epoch_loss, epoch_tokens = 0.0, 0
epoch_start = time.time()
log(f"")
log(f"── Epoch {epoch+1}/{args.epochs} ──")
for step, (input_ids, labels, attn_mask) in enumerate(train_dl):
step_start = time.time()
input_ids = input_ids.to(device)
labels = labels.to(device)
attn_mask = attn_mask.to(device)
optimizer.zero_grad()
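# Mixed-precision path: scale the loss before backward, unscale the gradients
# before clipping so the 1.0 grad-norm threshold applies to the true gradients,
# then step and update through the GradScaler.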
if scaler:
with torch.amp.autocast("cuda"):
outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
loss = outputs.loss
if torch.isnan(loss):
log(f"NaN at step {step+1}!", "ERROR")
break
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, 1.0).item()
scaler.step(optimizer)
scaler.update()
else:
outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
loss = outputs.loss
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, 1.0).item()
optimizer.step()
scheduler.step()
n_tokens = (labels != -100).sum().item()
step_time = time.time() - step_start
tps = n_tokens / step_time if step_time > 0 else 0
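# Weight the running loss by the number of supervised (non -100) tokens so the
# epoch average is per-token rather than per-batch.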
epoch_loss += loss.item() * n_tokens
epoch_tokens += n_tokens
global_step += 1
metrics["steps"].append({
"epoch": epoch+1, "step": step+1, "global_step": global_step,
"loss": round(loss.item(), 4), "lr": scheduler.get_last_lr()[0],
"grad_norm": round(grad_norm, 4), "vram_mb": round(vram_mb()),
"tokens_per_sec": round(tps),
})
if step % args.log_every == 0:
elapsed = time.time() - train_start
eta = elapsed / global_step * (total_steps - global_step) if global_step > 0 else 0
log(f" step {step+1:>3}/{steps_per_epoch} | loss={loss.item():.4f} | "
f"lr={scheduler.get_last_lr()[0]:.1e} | grad={grad_norm:.3f} | "
f"VRAM={vram_mb():.0f}MB | {tps:.0f} tok/s | ETA={eta/60:.0f}m")
if torch.isnan(loss):
break
avg_train = epoch_loss / max(epoch_tokens, 1)
# Validation
log(f" Validating...")
model.eval()
val_loss, val_tokens = 0.0, 0
with torch.no_grad():
for input_ids, labels, attn_mask in val_dl:
input_ids, labels, attn_mask = input_ids.to(device), labels.to(device), attn_mask.to(device)
with (torch.amp.autocast("cuda") if device.type == "cuda" else contextlib.nullcontext()):
outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
n = (labels != -100).sum().item()
val_loss += outputs.loss.item() * n
val_tokens += n
avg_val = val_loss / max(val_tokens, 1)
epoch_time = time.time() - epoch_start
is_best = avg_val < best_val_loss
metrics["epochs"].append({
"epoch": epoch+1, "train_loss": round(avg_train, 4),
"val_loss": round(avg_val, 4), "time_s": round(epoch_time, 1), "best": is_best,
})
marker = " ★ NEW BEST" if is_best else ""
log(f" Epoch {epoch+1}/{args.epochs} DONE | train={avg_train:.4f} val={avg_val:.4f} | "
f"{epoch_time:.0f}s{marker}")
if device.type == "cuda":
torch.cuda.empty_cache()
if is_best:
best_val_loss = avg_val
lora_state = {}
for name, module in model.named_modules():
if isinstance(module, LoRALayer):
lora_state[name + ".lora_A"] = module.lora_A.state_dict()
lora_state[name + ".lora_B"] = module.lora_B.state_dict()
torch.save({
"epoch": epoch, "val_loss": avg_val,
"lora_state": lora_state,
"config": vars(args),
}, os.path.join(args.output_dir, "best_distill.pt"))
# Samples
if (epoch + 1) % args.sample_every == 0 or epoch == args.epochs - 1 or is_best:
try:
log(f" Generating samples...")
model.eval()
sample_rng = np.random.RandomState(epoch)
sample_idx = sample_rng.choice(len(val_pairs), size=min(3, len(val_pairs)), replace=False)
for si in sample_idx:
text, ref = val_pairs[si]
prompt = PROMPT_TEMPLATE.format(text=text[:1500])
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
max_length=args.max_input_len).to(device)
with torch.no_grad():
gen = model.generate(
**inputs, max_new_tokens=args.max_output_len,
do_sample=False,  # greedy decoding; temperature is ignored without sampling
pad_token_id=tokenizer.pad_token_id,
)
gen_text = tokenizer.decode(gen[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True)
del gen
if device.type == "cuda":
torch.cuda.empty_cache()
metrics["samples"].append({"epoch": epoch+1, "ref": ref[:200], "gen": gen_text[:200]})
log(f" REF: {ref[:100]}")
log(f" GEN: {gen_text[:100]}")
log(f"")
except Exception as e:
log(f" Sample generation failed: {e}", "WARN")
if device.type == "cuda":
torch.cuda.empty_cache()
# ── Summary ────────────────────────────────────────────────────────
total_time = time.time() - train_start
metrics["total_time_s"] = round(total_time, 1)
metrics["best_val_loss"] = round(best_val_loss, 4)
with open(os.path.join(args.output_dir, "training_metrics.json"), "w") as f:
json.dump(metrics, f, indent=2)
log("")
log("=" * 60)
log("TRAINING COMPLETE")
log("=" * 60)
log(f"Total time: {total_time/60:.1f} minutes")
log(f"Best val loss: {best_val_loss:.4f}")
log(f"")
log("Epoch | Train Loss | Val Loss | Time | Best")
log("-" * 50)
for e in metrics["epochs"]:
m = " β˜…" if e["best"] else ""
log(f" {e['epoch']:>3} | {e['train_loss']:.4f} | {e['val_loss']:.4f} | {e['time_s']:.0f}s{m}")
if __name__ == "__main__":
main()
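# ── Inference sketch (illustrative only; not executed by this script) ─────────
# best_distill.pt stores only the LoRA A/B matrices plus the run config. A
# minimal CPU reload could mirror the wrapping logic above, assuming the same
# LoRALayer class and the q_proj/v_proj targets used during training:
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
#   ckpt = torch.load("best_distill.pt", map_location="cpu")
#   for m in model.modules():
#       for proj in ("q_proj", "v_proj"):
#           if hasattr(m, proj) and isinstance(getattr(m, proj), nn.Linear):
#               setattr(m, proj, LoRALayer(getattr(m, proj),
#                       ckpt["config"]["lora_rank"], ckpt["config"]["lora_alpha"]))
#   for name, m in model.named_modules():
#       if isinstance(m, LoRALayer):
#           m.lora_A.load_state_dict(ckpt["lora_state"][name + ".lora_A"])
#           m.lora_B.load_state_dict(ckpt["lora_state"][name + ".lora_B"])
#   # markdown_text is a placeholder for the document to summarize
#   prompt = "Summarize in one sentence:\n" + markdown_text + "\n\nSummary:"
#   inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=384)
#   out = model.generate(**inputs, max_new_tokens=64, do_sample=False,
#                        pad_token_id=tokenizer.eos_token_id)
#   summary = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)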