"""Fine-tune DeBERTa-v3 for multi-label clause classification on CUAD. Consumes the JSONL produced by scripts/prepare_cuad.py and writes a model dir that app/finetuned.py loads (CLASSIFIER=finetuned). Plain torch training loop (no Trainer) so it runs the same on CPU / Apple MPS / CUDA, with positive-class weighting because CUAD categories are sparse (most clauses are "none"). # 1. build data 2. train 3. score python -m scripts.prepare_cuad python -m scripts.train_classifier --epochs 3 python -m eval.run_eval --classifier finetuned --limit 50 Key knobs: --model (default microsoft/deberta-v3-base), --epochs, --batch, --lr, --out, --limit (cap training rows for a quick smoke run). """ from __future__ import annotations import json import pathlib import sys ROOT = pathlib.Path(__file__).resolve().parents[2] BACKEND = ROOT / "backend" sys.path.insert(0, str(BACKEND)) DATA = ROOT / "data" / "cuad" DEFAULT_OUT = BACKEND / "models" / "clause-clf" def arg(name: str, default): for a in sys.argv: if a.startswith(f"--{name}="): v = a.split("=", 1)[1] return type(default)(v) if default is not None else v return default def load_jsonl(path: pathlib.Path): return [json.loads(line) for line in path.read_text().splitlines() if line.strip()] def main() -> None: import torch from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer model_id = arg("model", "microsoft/deberta-v3-base") epochs = int(arg("epochs", 3)) batch = int(arg("batch", 8)) lr = float(arg("lr", 2e-5)) limit = int(arg("limit", 0)) # 0 = all max_len = int(arg("max_len", 512)) force_device = arg("device", "auto") # auto | cpu | mps | cuda out = pathlib.Path(arg("out", str(DEFAULT_OUT))) labels = json.loads((DATA / "labels.json").read_text()) lab2i = {l: i for i, l in enumerate(labels)} train = load_jsonl(DATA / "train.jsonl") val = load_jsonl(DATA / "val.jsonl") if limit: train = train[:limit] print(f"train={len(train)} val={len(val)} labels={len(labels)} model={model_id}") if force_device != "auto": device = force_device else: device = ("cuda" if torch.cuda.is_available() else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu") print(f"device={device} max_len={max_len} batch={batch}") tok = AutoTokenizer.from_pretrained(model_id) def encode(ex): y = torch.zeros(len(labels)) for l in ex["labels"]: y[lab2i[l]] = 1.0 return ex["text"], y class DS(Dataset): def __init__(self, rows): self.rows = [encode(r) for r in rows] def __len__(self): return len(self.rows) def __getitem__(self, i): return self.rows[i] def collate(b): texts, ys = zip(*b) enc = tok(list(texts), truncation=True, max_length=max_len, padding=True, return_tensors="pt") return enc, torch.stack(ys) # positive weighting: rarer label -> higher weight, so the model can't win # by predicting all-zeros under the heavy class imbalance. Clamp matters: # too high (e.g. 50) makes the model over-predict (high recall, ~0 # precision). ~10 is a balanced default; tune with --pos_clamp. pos_clamp = float(arg("pos_clamp", 10.0)) counts = torch.zeros(len(labels)) for _, y in DS(train).rows: counts += y pos_weight = ((len(train) - counts).clamp(min=1) / counts.clamp(min=1)).clamp(max=pos_clamp) model = AutoModelForSequenceClassification.from_pretrained( model_id, num_labels=len(labels), problem_type="multi_label_classification").to(device) loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device)) opt = torch.optim.AdamW(model.parameters(), lr=lr) dl = DataLoader(DS(train), batch_size=batch, shuffle=True, collate_fn=collate) model.train() for ep in range(epochs): total = 0.0 for enc, y in dl: enc = {k: v.to(device) for k, v in enc.items()} opt.zero_grad() loss = loss_fn(model(**enc).logits, y.to(device)) loss.backward() opt.step() total += loss.item() print(f"epoch {ep + 1}/{epochs} loss={total / max(1, len(dl)):.4f}") out.mkdir(parents=True, exist_ok=True) model.save_pretrained(out) tok.save_pretrained(out) (out / "labels.json").write_text(json.dumps(labels, indent=2)) print(f"saved model -> {out}") # quick val macro-F1 @0.5 (sanity; run_eval gives the comparable CUAD number) model.eval() tp = {l: 0 for l in labels}; fp = dict(tp); fn = dict(tp) with torch.no_grad(): for enc, y in DataLoader(DS(val), batch_size=batch, collate_fn=collate): enc = {k: v.to(device) for k, v in enc.items()} pred = (torch.sigmoid(model(**enc).logits) >= 0.5).cpu() for pr, gt in zip(pred, y): for j, l in enumerate(labels): if pr[j] and gt[j]: tp[l] += 1 elif pr[j] and not gt[j]: fp[l] += 1 elif not pr[j] and gt[j]: fn[l] += 1 f1s = [] for l in labels: p = tp[l] / (tp[l] + fp[l]) if tp[l] + fp[l] else 0.0 r = tp[l] / (tp[l] + fn[l]) if tp[l] + fn[l] else 0.0 f = 2 * p * r / (p + r) if p + r else 0.0 f1s.append(f) print(f" {l:<16} P={p:.2f} R={r:.2f} F1={f:.2f}") print(f"val macro-F1 @0.5 = {sum(f1s) / len(f1s):.3f}") if __name__ == "__main__": main()