#!/usr/bin/env python3
"""Train the MiniCPM5-1B quest-classification LoRA adapter on Modal.

The dataset (chat-JSONL produced by hackathon_advisor.quest_dataset) is sent to a GPU
container, fine-tuned with PEFT LoRA, and self-evaluated (on a held-out slice when
``eval_holdout`` is set, otherwise on the whole dataset). The adapter is returned as a
zip that the local entrypoint unpacks under artifacts/.

Smoke test the GPU first:
    modal run scripts/modal_train_quest_lora.py::smoke

Train:
    modal run scripts/modal_train_quest_lora.py --dataset data/quest_sft.jsonl
"""

from __future__ import annotations

import argparse
from pathlib import Path

import modal

APP_NAME = "hackathon-advisor-quest-lora"
BASE_MODEL = "openbmb/MiniCPM5-1B"
GPU = "L40S"

app = modal.App(APP_NAME)

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "torch>=2.4,<3",
        "transformers>=4.55,<5",
        "peft>=0.13,<1",
        "accelerate>=1.0,<2",
        "huggingface-hub>=0.36,<1",
        "datasets>=3,<4",
        "sentencepiece>=0.2,<1",
    )
    .add_local_python_source("hackathon_advisor", copy=True)
)


@app.function(image=image, gpu=GPU, timeout=3600)
def smoke() -> dict:
    """Report CUDA availability, the device name and the torch version of the container."""
    import torch

    return {
        "cuda": torch.cuda.is_available(),
        "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
        "torch": torch.__version__,
    }


@app.function(image=image, gpu=GPU, timeout=7800)
def train_remote(
    dataset_text: str,
    *,
    base_model: str = BASE_MODEL,
    rank: int = 64,
    alpha: int = 128,
    dropout: float = 0.0,
    learning_rate: float = 2e-4,
    epochs: float = 16.0,
    max_seq_length: int = 3072,
    eval_holdout: int = 0,
    upweight_variants: tuple = ("hard_negative", "remote_app_only", "contradiction", "empty"),
    upweight_factor: int = 3,
) -> dict:
    """Fine-tune a LoRA adapter on the quest dataset, evaluate it and return it zipped.

    Args:
        dataset_text: Raw chat-JSONL text, parsed with ``parse_quest_dataset_jsonl``.
        base_model: Hugging Face model id to load and adapt.
        rank: LoRA rank (``r``).
        alpha: LoRA ``lora_alpha``.
        dropout: LoRA dropout.
        learning_rate: Peak learning rate of the cosine schedule.
        epochs: Number of training epochs.
        max_seq_length: Sequences are truncated to this many tokens.
        eval_holdout: Number of shuffled examples to hold out for evaluation. The
            holdout is only used when the dataset has more than twice that many
            examples; otherwise (and when 0) the whole dataset is evaluated.
        upweight_variants: Values of an example's ``variant`` field to duplicate.
        upweight_factor: Total number of copies of each up-weighted example.

    Returns:
        A dict with ``adapter_zip`` (zip bytes of the adapter directory), ``eval``
        (metrics and mismatches), ``train_examples`` and ``loss_history``.

    Raises:
        RuntimeError: If no LoRA target modules are discovered, the tokenizer has no
            usable ``<|im_end|>`` token, or no example keeps a supervised token after
            truncation.
    """
    import gc
    import io
    import json
    import os
    import random
    import zipfile

    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    import torch
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

    from hackathon_advisor.quest_dataset import parse_quest_dataset_jsonl
    from hackathon_advisor.quest_taxonomy import normalize_match

    manifest, examples = parse_quest_dataset_jsonl(dataset_text)
    random.Random(42).shuffle(examples)  # representative holdout; keep edge cases mostly in train
    holdout = examples[-eval_holdout:] if eval_holdout and len(examples) > eval_holdout * 2 else []
    base_train = examples[: len(examples) - len(holdout)] if holdout else list(examples)

    # Up-weight the contrastive negatives so they outweigh the strong Off-the-Grid prior.
    upweighted = [
        ex
        for ex in base_train
        for _ in range(upweight_factor - 1)
        if ex.get("variant") in upweight_variants
    ]
    train_examples = base_train + upweighted
    random.Random(43).shuffle(train_examples)
    print(
        f"examples: total={len(examples)} base_train={len(base_train)} +upweighted={len(upweighted)} "
        f"-> train={len(train_examples)} holdout={len(holdout)}",
        flush=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.config.use_cache = False

    # Target every Linear layer by its leaf name, except the output head / embeddings.
    target_modules = sorted(
        {
            name.rsplit(".", 1)[-1]
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
            and name.rsplit(".", 1)[-1] not in {"lm_head", "embed_tokens"}
        }
    )
    if not target_modules:
        raise RuntimeError("no LoRA target modules discovered")
    print("LoRA targets:", target_modules, flush=True)

    model = get_peft_model(
        model,
        LoraConfig(
            r=rank,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=target_modules,
            task_type=TaskType.CAUSAL_LM,
        ),
    )
    model.print_trainable_parameters()
    model.enable_input_require_grads()  # required for gradient checkpointing over a frozen base

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # An unknown token maps to None or the unk id; training on that would silently teach
    # the wrong turn terminator and generation would never stop on it.
    if im_end_id is None or im_end_id == tokenizer.unk_token_id:
        raise RuntimeError(f"tokenizer for {base_model} has no <|im_end|> token")

    def template(messages, *, add_generation_prompt):
        """Render messages with the chat template, disabling thinking when supported."""
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
                enable_thinking=False,
            )
        except TypeError:
            # Fallback for templates that reject the enable_thinking keyword.
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=add_generation_prompt
            )

    def encode(example: dict) -> dict:
        """Tokenize one example with the loss restricted to the completion tokens."""
        # Build the sequence as the EXACT inference prompt (which includes the empty
        # thinking block emitted with enable_thinking=False) followed by the
        # strict-JSON completion and the <|im_end|> turn terminator. The prompt is
        # tokenized identically to inference so the model never sees a shifted context.
        messages = example["messages"]
        prompt_text = template(messages[:-1], add_generation_prompt=True)
        prompt_ids = tokenizer(prompt_text)["input_ids"]
        completion_ids = tokenizer(messages[-1]["content"], add_special_tokens=False)["input_ids"] + [
            im_end_id
        ]
        input_ids = (prompt_ids + completion_ids)[:max_seq_length]
        labels = ([-100] * len(prompt_ids) + completion_ids)[:max_seq_length]
        return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels}

    class DS(torch.utils.data.Dataset):
        """In-memory dataset of pre-encoded examples."""

        def __init__(self, rows):
            encoded = [encode(r) for r in rows]
            # A prompt that fills max_seq_length leaves no supervised token after
            # truncation; such rows add no gradient and can turn a batch loss into NaN.
            self.rows = [row for row in encoded if any(label != -100 for label in row["labels"])]
            dropped = len(encoded) - len(self.rows)
            if dropped:
                print(
                    f"dropped {dropped} example(s) with no completion tokens within "
                    f"max_seq_length={max_seq_length}",
                    flush=True,
                )

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, i):
            return self.rows[i]

    def collate(batch):
        """Right-pad a batch to its longest sequence; padding is masked out of the loss."""
        maxlen = max(len(b["input_ids"]) for b in batch)
        pad_id = tokenizer.pad_token_id
        input_ids, attn, labels = [], [], []
        for b in batch:
            n = maxlen - len(b["input_ids"])
            input_ids.append(b["input_ids"] + [pad_id] * n)
            attn.append(b["attention_mask"] + [0] * n)
            labels.append(b["labels"] + [-100] * n)
        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attn),
            "labels": torch.tensor(labels),
        }

    train_dataset = DS(train_examples)
    if len(train_dataset) == 0:
        raise RuntimeError("no trainable examples left after encoding")

    args = TrainingArguments(
        output_dir="/tmp/quest-lora",
        num_train_epochs=epochs,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        logging_steps=5,
        bf16=True,
        save_strategy="no",
        report_to=[],
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset, data_collator=collate)
    trainer.train()

    out = Path("/tmp/quest-lora-adapter")
    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(out)
    tokenizer.save_pretrained(out)
    (out / "training-recipe.json").write_text(
        json.dumps(
            {
                "type": "lora_training_recipe",
                "base_model": base_model,
                "adapter_task": manifest.get("adapter_task"),
                "method": "LoRA SFT (completion-only loss)",
                "example_count": len(train_examples),
                "epochs": epochs,
                "rank": rank,
                "alpha": alpha,
                "dropout": dropout,
                "learning_rate": learning_rate,
                "max_seq_length": max_seq_length,
                "target_modules": target_modules,
                "gpu": GPU,
            },
            ensure_ascii=False,
            indent=2,
        ),
        encoding="utf-8",
    )

    # --- full-dataset eval: does the adapter reproduce the gold quest set for EVERY example? ---
    # The goal is correct judgement across the whole dataset, so we score all of it
    # (unless an explicit holdout was carved out above).
    loss_history = [h.get("loss") for h in trainer.state.log_history if "loss" in h]
    del trainer  # release trainer/optimizer state before generation
    gc.collect()
    torch.cuda.empty_cache()
    model.config.use_cache = True
    try:
        model.gradient_checkpointing_disable()
    except Exception as error:  # noqa: BLE001 - best effort; evaluation still works without it
        print(
            f"gradient_checkpointing_disable failed: {type(error).__name__}: {error}",
            flush=True,
        )
    model.eval()

    def gold_quests(ex):
        """Return the set of quest names in the example's gold (last) message."""
        return {m["quest"] for m in json.loads(ex["messages"][-1]["content"]).get("matches", [])}

    valid = exact = 0
    tp = fp = fn = 0
    mismatches = []
    eval_set = holdout if holdout else examples
    try:
        for ex in eval_set:
            prompt_text = template(ex["messages"][:-1], add_generation_prompt=True)
            inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
            inputs.pop("token_type_ids", None)
            with torch.inference_mode():
                gen = model.generate(
                    **inputs, max_new_tokens=512, do_sample=False, eos_token_id=im_end_id
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            text = tokenizer.decode(
                gen[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
            ).strip()
            gold = gold_quests(ex)
            try:
                payload = json.loads(text)
                pred = set()
                for m in payload["matches"]:
                    # NOTE(review): return value unused; presumably raises on an
                    # invalid match - confirm against quest_taxonomy.
                    normalize_match(m)
                    pred.add(m["quest"])
                valid += 1
            except Exception:  # noqa: BLE001
                # Any parse/validation failure counts as invalid output: all gold missed.
                mismatches.append(
                    {
                        "project_id": ex.get("project_id", ""),
                        "variant": ex.get("variant", ""),
                        "gold": sorted(gold),
                        "pred": "INVALID_JSON",
                        "output": text[:300],
                    }
                )
                fn += len(gold)
                continue
            tp += len(gold & pred)
            fp += len(pred - gold)
            fn += len(gold - pred)
            if pred == gold:
                exact += 1
            else:
                mismatches.append(
                    {
                        "project_id": ex.get("project_id", ""),
                        "variant": ex.get("variant", ""),
                        "gold": sorted(gold),
                        "pred": sorted(pred),
                    }
                )
    except Exception as error:  # noqa: BLE001 - keep the adapter even if eval breaks
        print(f"eval aborted: {type(error).__name__}: {error}", flush=True)

    n = len(eval_set)
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    print(
        f"full-eval: valid_json {valid}/{n} | quest-set exact {exact}/{n} "
        f"| micro P/R/F1 {precision:.3f}/{recall:.3f}/{f1:.3f} | mismatches {len(mismatches)}",
        flush=True,
    )

    # Zip the adapter directory in memory so it can travel back as the function result.
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        for path in sorted(out.rglob("*")):
            if path.is_file():
                zf.write(path, path.relative_to(out).as_posix())

    return {
        "adapter_zip": buffer.getvalue(),
        "eval": {
            "n": n,
            "valid_json": valid,
            "quest_set_exact": exact,
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
            "mismatches": mismatches,
        },
        "train_examples": len(train_examples),
        "loss_history": loss_history,
    }


@app.local_entrypoint()
def main(
    dataset: str = "data/quest_sft.jsonl",
    out_dir: str = "artifacts/quest-lora",
    epochs: float = 8.0,
) -> None:
    """Send the dataset to ``train_remote`` and unpack the returned adapter locally.

    Args:
        dataset: Path of the chat-JSONL dataset to read.
        out_dir: Directory that receives the adapter files and ``self-eval.json``.
        epochs: Number of training epochs forwarded to ``train_remote``.
    """
    import io
    import json
    import zipfile

    dataset_text = Path(dataset).read_text(encoding="utf-8")
    result = train_remote.remote(dataset_text, epochs=epochs)

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(io.BytesIO(result["adapter_zip"])) as zf:
        zf.extractall(out)

    ev = result["eval"]
    (out / "self-eval.json").write_text(json.dumps(ev, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"adapter written to {out}")
    print(
        f"full-eval: valid_json {ev['valid_json']}/{ev['n']} | quest-set exact {ev['quest_set_exact']}/{ev['n']} "
        f"| micro F1 {ev['f1']} | mismatches {len(ev['mismatches'])}"
    )
    print(f"loss history: {result['loss_history']}")


def _cli() -> None:
    """Print the ``modal run`` command matching the given arguments.

    Running this file directly does not train anything; training is started through
    the Modal CLI. Defaults mirror those of ``main``.
    """
    parser = argparse.ArgumentParser(description="Train the quest-classification LoRA on Modal.")
    parser.add_argument("--dataset", default="data/quest_sft.jsonl")
    parser.add_argument("--out-dir", default="artifacts/quest-lora")
    parser.add_argument("--epochs", type=float, default=8.0)
    args = parser.parse_args()
    print(
        "Run via: modal run scripts/modal_train_quest_lora.py "
        f"--dataset {args.dataset} --out-dir {args.out_dir} --epochs {args.epochs}"
    )


if __name__ == "__main__":
    _cli()