Spaces:
Sleeping
Sleeping
"""Fine-tune DeBERTa-v3 for multi-label clause classification on CUAD.

Consumes the JSONL produced by scripts/prepare_cuad.py and writes a model dir
that app/finetuned.py loads (CLASSIFIER=finetuned). Plain torch training loop
(no Trainer) so it runs the same on CPU / Apple MPS / CUDA, with positive-class
weighting because CUAD categories are sparse (most clauses are "none").

    # 1. build data  2. train  3. score
    python -m scripts.prepare_cuad
    python -m scripts.train_classifier --epochs=3
    python -m eval.run_eval --classifier finetuned --limit 50

Key knobs: --model (default microsoft/deberta-v3-base), --epochs, --batch,
--lr, --out, --limit (cap training rows for a quick smoke run).
Also: --max_len, --device (auto | cpu | mps | cuda), --pos_clamp.
Pass this script's knobs as --name=value.
"""
| from __future__ import annotations | |
| import json | |
| import pathlib | |
| import sys | |
# Repository root: the grandparent of the directory that holds this script.
ROOT = pathlib.Path(__file__).resolve().parents[2]

# Well-known locations derived from the repository root.
BACKEND = ROOT.joinpath("backend")
DATA = ROOT.joinpath("data", "cuad")
DEFAULT_OUT = BACKEND.joinpath("models", "clause-clf")

# Put backend/ at the front of the import path (presumably so backend
# packages resolve when this file is run as a script -- confirm with callers).
sys.path.insert(0, str(BACKEND))
def arg(name: str, default):
    """Return the value of command-line flag ``--name``, or *default*.

    Both ``--name=value`` and ``--name value`` are accepted; the first
    occurrence in ``sys.argv`` wins. A bare ``--name`` that is last on the
    command line, or is directly followed by another ``--flag``, is treated
    as absent.

    The raw string is converted with ``type(default)`` (so an ``int`` default
    yields an ``int``); when *default* is ``None`` the string is returned
    unchanged. A value that cannot be converted raises whatever the
    constructor raises (typically ``ValueError``).
    """
    flag = f"--{name}"
    prefix = flag + "="
    argv = sys.argv
    for pos, token in enumerate(argv):
        if token.startswith(prefix):
            raw = token[len(prefix):]
        elif (
            token == flag
            and pos + 1 < len(argv)
            and not argv[pos + 1].startswith("--")
        ):
            # Space-separated form: the value is the next token.
            raw = argv[pos + 1]
        else:
            continue
        return type(default)(raw) if default is not None else raw
    return default
def load_jsonl(path: pathlib.Path):
    """Parse a JSON Lines file into a list, one element per non-blank line.

    The file is decoded as UTF-8 regardless of the platform locale. Lines are
    taken from the file iterator, which only breaks on real newlines; a raw
    U+2028 / U+2029 / U+0085 inside a JSON string therefore stays part of its
    record (``str.splitlines`` would have split the record there).

    Raises ``json.JSONDecodeError`` if a non-blank line is not valid JSON.
    """
    with path.open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]
def main() -> None:
    """Train the multi-label clause classifier, save it, and print a val report.

    Reads ``labels.json``, ``train.jsonl`` and ``val.jsonl`` from ``DATA``;
    each JSONL row must provide a ``"text"`` string and a ``"labels"`` list
    whose entries all appear in ``labels.json`` (an unknown label raises
    ``KeyError``). Hyper-parameters come from the command line via ``arg``.

    Side effects: writes the model, tokenizer and ``labels.json`` into the
    ``--out`` directory and prints progress, per-label precision / recall / F1
    and the macro-F1 at a 0.5 threshold to stdout.
    """
    # Third-party imports are local to main().
    import torch
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_id = arg("model", "microsoft/deberta-v3-base")
    epochs = int(arg("epochs", 3))
    batch = int(arg("batch", 8))
    lr = float(arg("lr", 2e-5))
    limit = int(arg("limit", 0))  # 0 = all
    max_len = int(arg("max_len", 512))
    force_device = arg("device", "auto")  # auto | cpu | mps | cuda
    out = pathlib.Path(arg("out", str(DEFAULT_OUT)))

    labels = json.loads((DATA / "labels.json").read_text(encoding="utf-8"))
    lab2i = {name: i for i, name in enumerate(labels)}
    train = load_jsonl(DATA / "train.jsonl")
    val = load_jsonl(DATA / "val.jsonl")
    if limit:
        train = train[:limit]
    print(f"train={len(train)} val={len(val)} labels={len(labels)} model={model_id}")

    if force_device != "auto":
        device = force_device
    else:
        device = ("cuda" if torch.cuda.is_available()
                  else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
                  else "cpu")
    print(f"device={device} max_len={max_len} batch={batch}")

    tok = AutoTokenizer.from_pretrained(model_id)

    def encode(ex):
        """Turn one row into (text, multi-hot float target vector)."""
        y = torch.zeros(len(labels))
        for name in ex["labels"]:
            y[lab2i[name]] = 1.0
        return ex["text"], y

    class DS(Dataset):
        """In-memory dataset of pre-encoded (text, target) pairs."""

        def __init__(self, rows):
            self.rows = [encode(r) for r in rows]

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, i):
            return self.rows[i]

    def collate(b):
        """Tokenize a batch of texts together (padded per batch) and stack targets."""
        texts, ys = zip(*b)
        enc = tok(list(texts), truncation=True, max_length=max_len, padding=True, return_tensors="pt")
        return enc, torch.stack(ys)

    # Encode the training rows once; the same dataset feeds both the label
    # statistics below and the training DataLoader.
    train_ds = DS(train)

    # positive weighting: rarer label -> higher weight, so the model can't win
    # by predicting all-zeros under the heavy class imbalance. Clamp matters:
    # too high (e.g. 50) makes the model over-predict (high recall, ~0
    # precision). ~10 is a balanced default; tune with --pos_clamp.
    pos_clamp = float(arg("pos_clamp", 10.0))
    counts = torch.zeros(len(labels))
    for _, y in train_ds.rows:
        counts += y
    # negatives / positives per label; both clamped to >= 1 to avoid 0-division.
    pos_weight = ((len(train) - counts).clamp(min=1) / counts.clamp(min=1)).clamp(max=pos_clamp)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=len(labels), problem_type="multi_label_classification").to(device)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    dl = DataLoader(train_ds, batch_size=batch, shuffle=True, collate_fn=collate)

    model.train()
    for ep in range(epochs):
        total = 0.0
        for enc, y in dl:
            enc = {k: v.to(device) for k, v in enc.items()}
            opt.zero_grad()
            loss = loss_fn(model(**enc).logits, y.to(device))
            loss.backward()
            opt.step()
            total += loss.item()
        print(f"epoch {ep + 1}/{epochs} loss={total / max(1, len(dl)):.4f}")

    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(out)
    tok.save_pretrained(out)
    (out / "labels.json").write_text(json.dumps(labels, indent=2), encoding="utf-8")
    print(f"saved model -> {out}")

    # quick val macro-F1 @0.5 (sanity; run_eval gives the comparable CUAD number)
    model.eval()
    n_labels = len(labels)
    tp = torch.zeros(n_labels, dtype=torch.long)
    fp = torch.zeros(n_labels, dtype=torch.long)
    fn = torch.zeros(n_labels, dtype=torch.long)
    with torch.no_grad():
        for enc, y in DataLoader(DS(val), batch_size=batch, collate_fn=collate):
            enc = {k: v.to(device) for k, v in enc.items()}
            pred = (torch.sigmoid(model(**enc).logits) >= 0.5).cpu()
            gold = y.bool()
            # Per-label confusion counts, summed over the batch dimension.
            tp += (pred & gold).sum(dim=0)
            fp += (pred & ~gold).sum(dim=0)
            fn += (~pred & gold).sum(dim=0)

    f1s = []
    for j, name in enumerate(labels):
        n_tp, n_fp, n_fn = int(tp[j]), int(fp[j]), int(fn[j])
        p = n_tp / (n_tp + n_fp) if n_tp + n_fp else 0.0
        r = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        f1s.append(f)
        print(f" {name:<16} P={p:.2f} R={r:.2f} F1={f:.2f}")
    # max(1, ...) keeps an empty label list from raising ZeroDivisionError.
    print(f"val macro-F1 @0.5 = {sum(f1s) / max(1, len(f1s)):.3f}")


if __name__ == "__main__":
    main()