contract-extractor / backend /scripts /train_classifier.py
myke69's picture
Add files using upload-large-folder tool
296a9b2 verified
Raw
History Blame Contribute Delete
5.82 kB
"""Fine-tune DeBERTa-v3 for multi-label clause classification on CUAD.
Consumes the JSONL produced by scripts/prepare_cuad.py and writes a model dir
that app/finetuned.py loads (CLASSIFIER=finetuned). Plain torch training loop
(no Trainer) so it runs the same on CPU / Apple MPS / CUDA, with positive-class
weighting because CUAD categories are sparse (most clauses are "none").
# 1. build data 2. train 3. score
python -m scripts.prepare_cuad
python -m scripts.train_classifier --epochs 3
python -m eval.run_eval --classifier finetuned --limit 50
Key knobs: --model (default microsoft/deberta-v3-base), --epochs, --batch,
--lr, --out, --limit (cap training rows for a quick smoke run), --max_len,
--device (auto | cpu | mps | cuda), --pos_clamp (cap on positive-class weights).
"""
from __future__ import annotations
import json
import pathlib
import sys
# Repository root: two directories above the folder that holds this script.
ROOT = pathlib.Path(__file__).resolve().parents[2]
BACKEND = ROOT.joinpath("backend")
# Put backend/ at the front of the import path.
sys.path.insert(0, str(BACKEND))
# Input data directory and default output directory for the trained model.
DATA = ROOT.joinpath("data", "cuad")
DEFAULT_OUT = BACKEND.joinpath("models", "clause-clf")
def arg(name: str, default):
    """Return the value of command-line option ``--name``, or *default*.

    Both ``--name=value`` and ``--name value`` are accepted; the module
    docstring's usage examples use the space-separated form. The first
    occurrence in ``sys.argv`` wins.

    Args:
        name: Option name without the leading dashes.
        default: Value returned when the option is absent. Unless it is
            ``None``, its type is also used to convert the raw string.

    Returns:
        The converted option value, the raw string when *default* is
        ``None``, or *default* when the option was not supplied.

    Raises:
        ValueError: If the raw string cannot be converted by
            ``type(default)`` (e.g. ``--epochs=abc`` with an int default).
    """
    flag = f"--{name}"
    prefix = f"{flag}="
    argv = sys.argv
    for pos, token in enumerate(argv):
        if token.startswith(prefix):
            raw = token[len(prefix):]
        elif (token == flag and pos + 1 < len(argv)
              and not argv[pos + 1].startswith("--")):
            # Space-separated form; a following "--other" token is another
            # option, not this option's value.
            raw = argv[pos + 1]
        else:
            continue
        return type(default)(raw) if default is not None else raw
    return default
def load_jsonl(path: pathlib.Path):
    """Read a JSON Lines file into a list, one parsed object per non-blank line.

    The file is decoded as UTF-8 regardless of the platform default, and is
    split only on real newlines. ``str.splitlines`` also breaks on
    U+0085/U+2028/U+2029, which JSON permits unescaped inside strings, so it
    could cut a record in half.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The parsed objects in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]
def main() -> None:
    """Fine-tune the classifier on the prepared CUAD split and save it.

    Reads ``labels.json``, ``train.jsonl`` and ``val.jsonl`` from ``DATA``,
    trains with a positive-class-weighted BCE-with-logits loss, writes the
    model, tokenizer and ``labels.json`` to ``--out``, then prints per-label
    precision/recall/F1 and the macro-F1 on the validation split at a 0.5
    threshold. All options are read from ``sys.argv`` via ``arg``.
    """
    # Imported here rather than at module level, so merely importing this
    # module does not load torch/transformers.
    import torch
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_id = arg("model", "microsoft/deberta-v3-base")
    epochs = int(arg("epochs", 3))
    batch = int(arg("batch", 8))
    lr = float(arg("lr", 2e-5))
    limit = int(arg("limit", 0))  # 0 = all
    max_len = int(arg("max_len", 512))
    force_device = arg("device", "auto")  # auto | cpu | mps | cuda
    out = pathlib.Path(arg("out", str(DEFAULT_OUT)))
    pos_clamp = float(arg("pos_clamp", 10.0))

    labels = json.loads((DATA / "labels.json").read_text(encoding="utf-8"))
    lab2i = {l: i for i, l in enumerate(labels)}
    train = load_jsonl(DATA / "train.jsonl")
    val = load_jsonl(DATA / "val.jsonl")
    if limit:
        train = train[:limit]
    print(f"train={len(train)} val={len(val)} labels={len(labels)} model={model_id}")

    if force_device != "auto":
        device = force_device
    else:
        # Prefer CUDA, then Apple MPS (guarded: older torch builds have no
        # torch.backends.mps), then CPU.
        device = ("cuda" if torch.cuda.is_available()
                  else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
                  else "cpu")
    print(f"device={device} max_len={max_len} batch={batch}")

    tok = AutoTokenizer.from_pretrained(model_id)

    def encode(ex):
        """Turn one row into ``(text, multi-hot float label vector)``."""
        y = torch.zeros(len(labels))
        for l in ex["labels"]:
            y[lab2i[l]] = 1.0
        return ex["text"], y

    class DS(Dataset):
        """In-memory dataset of pre-encoded ``(text, label vector)`` pairs."""

        def __init__(self, rows):
            self.rows = [encode(r) for r in rows]

        def __len__(self):
            return len(self.rows)

        def __getitem__(self, i):
            return self.rows[i]

    def collate(b):
        """Tokenize a batch of texts (padded/truncated) and stack its labels."""
        texts, ys = zip(*b)
        enc = tok(list(texts), truncation=True, max_length=max_len, padding=True, return_tensors="pt")
        return enc, torch.stack(ys)

    # Encode the training rows once; the same dataset feeds both the label
    # statistics below and the training DataLoader.
    train_ds = DS(train)

    # positive weighting: rarer label -> higher weight, so the model can't win
    # by predicting all-zeros under the heavy class imbalance. Clamp matters:
    # too high (e.g. 50) makes the model over-predict (high recall, ~0
    # precision). ~10 is a balanced default; tune with --pos_clamp.
    counts = torch.zeros(len(labels))
    for _, y in train_ds.rows:
        counts += y
    # negatives / positives per label; both sides clamped to >= 1 so labels
    # with no positives (or no negatives) do not divide by or yield zero.
    pos_weight = ((len(train) - counts).clamp(min=1) / counts.clamp(min=1)).clamp(max=pos_clamp)

    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=len(labels), problem_type="multi_label_classification").to(device)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    dl = DataLoader(train_ds, batch_size=batch, shuffle=True, collate_fn=collate)

    model.train()
    for ep in range(epochs):
        total = 0.0
        for enc, y in dl:
            enc = {k: v.to(device) for k, v in enc.items()}
            opt.zero_grad()
            loss = loss_fn(model(**enc).logits, y.to(device))
            loss.backward()
            opt.step()
            total += loss.item()
        print(f"epoch {ep + 1}/{epochs} loss={total / max(1, len(dl)):.4f}")

    # Persist before validating, so the artefacts exist even if scoring fails.
    out.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(out)
    tok.save_pretrained(out)
    (out / "labels.json").write_text(json.dumps(labels, indent=2), encoding="utf-8")
    print(f"saved model -> {out}")

    # quick val macro-F1 @0.5 (sanity; run_eval gives the comparable CUAD number)
    model.eval()
    n_labels = len(labels)
    tp = torch.zeros(n_labels, dtype=torch.long)
    fp = torch.zeros(n_labels, dtype=torch.long)
    fn = torch.zeros(n_labels, dtype=torch.long)
    with torch.no_grad():
        for enc, y in DataLoader(DS(val), batch_size=batch, collate_fn=collate):
            enc = {k: v.to(device) for k, v in enc.items()}
            pred = (torch.sigmoid(model(**enc).logits) >= 0.5).cpu()
            gold = y.bool()
            # Per-label confusion counts for the whole batch at once, instead
            # of a Python loop over every (sample, label) pair.
            tp += (pred & gold).sum(dim=0)
            fp += (pred & ~gold).sum(dim=0)
            fn += (~pred & gold).sum(dim=0)

    f1s = []
    for l, n_tp, n_fp, n_fn in zip(labels, tp.tolist(), fp.tolist(), fn.tolist()):
        p = n_tp / (n_tp + n_fp) if n_tp + n_fp else 0.0
        r = n_tp / (n_tp + n_fn) if n_tp + n_fn else 0.0
        f = 2 * p * r / (p + r) if p + r else 0.0
        f1s.append(f)
        print(f" {l:<16} P={p:.2f} R={r:.2f} F1={f:.2f}")
    print(f"val macro-F1 @0.5 = {sum(f1s) / max(1, len(f1s)):.3f}")
# Script entry point: train only when this file is executed directly
# (e.g. ``python -m scripts.train_classifier``), not when it is imported.
if __name__ == "__main__":
    main()