# Source: google-links / train.py
# Uploaded by dejanseo ("Upload 22 files"), commit f29b6e6 (verified).
#!/usr/bin/env python3
# train.py
#
# Fine-tune mdeberta-v3-base for binary token classification (LINK vs O).
# Inputs (same dir): train_windows.jsonl, val_windows.jsonl (from prep.py)
# JSONL per line: {"doc_id": int, "window_id": int, "input_ids": [...], "attention_mask": [...], "labels": [...]}
import os
import json
import argparse
import inspect
from typing import List, Dict, Any
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForTokenClassification,
Trainer,
TrainingArguments,
set_seed,
)
# --------------------------
# Dataset
# --------------------------
class JsonlTokenDataset(Dataset):
    """Token-classification dataset backed by a JSONL file produced by prep.py.

    Each non-blank line must be a JSON object with at least ``input_ids``,
    ``attention_mask`` and ``labels`` lists of equal length. Labels on tokens
    the tokenizer reports as special are replaced with -100 so that the loss
    and the metrics ignore them.

    Args:
        path: Path to the JSONL file.
        tokenizer: Tokenizer used to identify special tokens. Its
            ``get_special_tokens_mask`` is used when available; otherwise
            membership in ``all_special_ids`` is used.

    Raises:
        ValueError: if a record's ``labels`` or ``attention_mask`` length
            differs from its ``input_ids`` length.
    """

    def __init__(self, path: str, tokenizer: AutoTokenizer):
        self.path = path
        self.tokenizer = tokenizer
        self.samples: List[Dict[str, Any]] = []
        with open(self.path, "r", encoding="utf-8") as f:
            for line in f:
                # Skip blank lines (e.g. a trailing newline at EOF);
                # json.loads("") would raise.
                if not line.strip():
                    continue
                self.samples.append(json.loads(line))

        # Built lazily, and only once, if the tokenizer cannot produce a
        # special-tokens mask itself.
        fallback_special_ids = None
        for idx, rec in enumerate(self.samples):
            input_ids = rec["input_ids"]
            labels = rec["labels"]
            attention_mask = rec["attention_mask"]
            if len(labels) != len(input_ids) or len(attention_mask) != len(input_ids):
                # zip() below would silently truncate mismatched labels.
                raise ValueError(
                    f"{self.path}: record {idx} (doc_id={rec.get('doc_id')}, "
                    f"window_id={rec.get('window_id')}) has {len(labels)} labels and "
                    f"{len(attention_mask)} attention_mask entries for "
                    f"{len(input_ids)} input_ids"
                )
            try:
                special_mask = tokenizer.get_special_tokens_mask(
                    input_ids, already_has_special_tokens=True
                )
            except Exception:
                # Deliberate best-effort fallback for tokenizers without a
                # usable get_special_tokens_mask.
                if fallback_special_ids is None:
                    fallback_special_ids = set(tokenizer.all_special_ids or [])
                special_mask = [1 if t in fallback_special_ids else 0 for t in input_ids]
            rec["labels"] = [
                -100 if sm == 1 else int(l) for l, sm in zip(labels, special_mask)
            ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one window as 1-D ``torch.long`` tensors."""
        r = self.samples[idx]
        return {
            "input_ids": torch.tensor(r["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(r["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(r["labels"], dtype=torch.long),
        }
# --------------------------
# Collator (pads inputs + labels)
# --------------------------
class SimpleTokenCollator:
    """Right-pad a list of token-classification features into one batch.

    ``input_ids`` are padded with the tokenizer's pad id (0 if it has none),
    ``attention_mask`` with 0 and ``labels`` with -100. Every row is padded to
    the longest ``input_ids`` in the batch. That length is rounded up to a
    multiple of ``pad_to_multiple_of`` when one is given.
    """

    def __init__(self, tokenizer: AutoTokenizer, pad_to_multiple_of: int = None):
        pad_token_id = tokenizer.pad_token_id
        self.pad_id = 0 if pad_token_id is None else pad_token_id
        self.pad_to_multiple = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # (key, fill value) pairs, in output order.
        fills = (("input_ids", self.pad_id), ("attention_mask", 0), ("labels", -100))
        rows = {key: [feat[key].tolist() for feat in features] for key, _ in fills}

        width = max(map(len, rows["input_ids"]))
        step = self.pad_to_multiple
        if step and width % step != 0:
            width += step - width % step

        batch = {}
        for key, fill in fills:
            padded = [seq + [fill] * (width - len(seq)) for seq in rows[key]]
            batch[key] = torch.tensor(padded, dtype=torch.long)
        return batch
# --------------------------
# Class weights
# --------------------------
def compute_class_weights(dataset: JsonlTokenDataset) -> torch.Tensor:
    """Return CrossEntropy class weights ``[neg, pos]`` from the training labels.

    Positions labelled -100 are skipped. Label 1 counts as positive and any
    other label as negative. The negative class keeps weight 1.0. The positive
    class gets neg/pos, or 1.0 when there are no positives.
    """
    scored = (
        label
        for record in dataset.samples
        for label in record["labels"]
        if label != -100
    )
    n_pos = n_neg = 0
    for label in scored:
        if label == 1:
            n_pos += 1
        else:
            n_neg += 1
    positive_weight = n_neg / n_pos if n_pos > 0 else 1.0
    return torch.tensor([1.0, positive_weight], dtype=torch.float)
# --------------------------
# Metrics
# --------------------------
def compute_metrics_fn(eval_pred):
    """Compute token-level binary metrics for the LINK class (label 1).

    Args:
        eval_pred: Either a ``(predictions, label_ids)`` tuple or an object
            exposing ``.predictions`` and ``.label_ids`` (as the HF Trainer's
            EvalPrediction does). ``predictions`` are logits whose last axis
            indexes classes. A tuple of outputs is also accepted, in which
            case its first element is used as the logits. Predictions and
            labels are assumed to share the same leading shape.

    Returns:
        A dict with the following keys:
          - accuracy, precision, recall, f1 for class 1;
          - pos_rate_true: fraction of gold labels equal to 1;
          - pos_rate_pred: fraction of predictions equal to 1.
        Positions labelled -100 are ignored. All values are 0.0 when no
        position is scored.
    """
    if hasattr(eval_pred, "predictions") and hasattr(eval_pred, "label_ids"):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
    else:
        logits, labels = eval_pred
    # Models that return several outputs yield a tuple; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    preds = np.argmax(np.asarray(logits), axis=-1)
    labels = np.asarray(labels)
    # Vectorised masking instead of a per-token Python loop, which is slow on
    # large evaluation sets.
    mask = labels != -100
    if not mask.any():
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0,
                "pos_rate_true": 0.0, "pos_rate_pred": 0.0}

    y_true = labels[mask].astype(np.int64)
    y_pred = preds[mask].astype(np.int64)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))

    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    prec = tp / max(1, tp + fp)
    rec = tp / max(1, tp + fn)
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1,
            "pos_rate_true": float(np.mean(y_true)),
            "pos_rate_pred": float(np.mean(y_pred))}
# --------------------------
# Weighted Trainer (accepts extra kwargs like num_items_in_batch)
# --------------------------
class WeightedCELossTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over labelled tokens.

    Positions labelled -100 (special tokens and padding) are excluded from the
    loss. The optional ``class_weights`` tensor is passed as ``weight`` to
    ``torch.nn.CrossEntropyLoss``; its length must equal the number of labels.
    Extra keyword arguments that newer Trainer versions pass to
    ``compute_loss`` (e.g. ``num_items_in_batch``) are accepted and ignored.
    """

    def __init__(self, class_weights: torch.Tensor = None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs["labels"]
        # Labels are deliberately not forwarded to the model, so it does not
        # compute its own unweighted loss.
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.logits
        # Derive the class count from the logits instead of hard-coding 2;
        # reshape (unlike view) also handles non-contiguous tensors.
        num_labels = logits.size(-1)
        flat_logits = logits.reshape(-1, num_labels)
        flat_labels = labels.reshape(-1)
        active = flat_labels.ne(-100)

        if not bool(active.any()):
            # CrossEntropyLoss over an empty selection is NaN; return a zero
            # that stays connected to the graph instead.
            loss = flat_logits.sum() * 0.0
        else:
            weight = None
            if self.class_weights is not None:
                weight = self.class_weights.to(device=logits.device, dtype=torch.float32)
            loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
            # Compute in float32 so half-precision logits match the float32 weights.
            loss = loss_fct(flat_logits[active].float(), flat_labels[active])
        # NOTE(review): this returns a per-batch mean. With gradient accumulation
        # on newer transformers releases, confirm how Trainer normalises a custom
        # compute_loss when num_items_in_batch is supplied.
        return (loss, outputs) if return_outputs else loss
# --------------------------
# TrainingArguments compatibility
# --------------------------
def build_training_arguments(args) -> TrainingArguments:
    """Build TrainingArguments from parsed CLI args across transformers versions.

    Only keyword arguments accepted by the installed ``TrainingArguments``
    signature are passed. Step-based evaluation and saving are enabled, with
    best-model selection on eval F1, when the installed version lets both the
    eval and the save strategy be configured.

    Args:
        args: argparse.Namespace produced by parse_args().

    Returns:
        A configured TrainingArguments instance.
    """
    sig = set(inspect.signature(TrainingArguments.__init__).parameters)

    # transformers renamed ``evaluation_strategy`` to ``eval_strategy`` and
    # later dropped the old name. Checking only the old name silently turned
    # off evaluation and best-model loading on newer releases.
    if "eval_strategy" in sig:
        eval_strategy_key = "eval_strategy"
    elif "evaluation_strategy" in sig:
        eval_strategy_key = "evaluation_strategy"
    else:
        eval_strategy_key = None

    kw = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.train_batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "logging_steps": args.logging_steps,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "seed": args.seed,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "fp16": args.fp16,
        "bf16": args.bf16,
        "gradient_checkpointing": args.gradient_checkpointing,
        "log_level": "info",
        "dataloader_num_workers": args.num_workers,
    }
    if "report_to" in sig:
        # In transformers 4.x report_to=None means "all installed
        # integrations"; the string "none" is what disables reporting.
        kw["report_to"] = "none" if args.report_to == "none" else ["wandb"]

    # Pair the strategies only if BOTH can be set, to avoid mismatches.
    if eval_strategy_key is not None and "save_strategy" in sig:
        kw[eval_strategy_key] = "steps"
        kw["save_strategy"] = "steps"
        kw["logging_strategy"] = "steps"
        # Per the TrainingArguments docs, load_best_model_at_end with "steps"
        # requires save_steps to be a round multiple of eval_steps.
        kw["load_best_model_at_end"] = True
        kw["metric_for_best_model"] = "f1"
        kw["greater_is_better"] = True
    elif "evaluate_during_training" in sig and args.eval_steps > 0:
        # Legacy transformers versions predating evaluation strategies.
        kw["evaluate_during_training"] = True

    # Drop anything the installed version does not accept.
    kw = {k: v for k, v in kw.items() if k in sig}
    return TrainingArguments(**kw)
# --------------------------
# Args
# --------------------------
def parse_args():
    """Parse the command-line options for fine-tuning.

    Returns:
        argparse.Namespace with model/data/output locations, training
        hyperparameters, precision switches and reporting options.
    """
    parser = argparse.ArgumentParser(
        description="Train binary token classification model for link anchors."
    )
    # Locations: model to start from, data to read, directory to write.
    for flag, default, help_text in (
        ("--model_name", "microsoft/mdeberta-v3-base", "HF model name or local path."),
        ("--train_path", "train_windows.jsonl", "Training JSONL."),
        ("--val_path", "val_windows.jsonl", "Validation JSONL."),
        ("--output_dir", "model_link_token_cls", "Output directory."),
    ):
        parser.add_argument(flag, default=default, help=help_text)
    # Numeric hyperparameters as (flag, type, default).
    for flag, kind, default in (
        ("--epochs", int, 3),
        ("--lr", float, 2e-5),
        ("--weight_decay", float, 0.01),
        ("--train_batch_size", int, 16),
        ("--eval_batch_size", int, 32),
        ("--logging_steps", int, 50),
        ("--eval_steps", int, 500),
        ("--save_steps", int, 500),
        ("--seed", int, 42),
        ("--gradient_accumulation_steps", int, 1),
    ):
        parser.add_argument(flag, type=kind, default=default)
    # On/off switches.
    for flag in ("--fp16", "--bf16", "--gradient_checkpointing"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--report_to", default="wandb", choices=["wandb", "none"])
    parser.add_argument("--pad_to_multiple_of", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=2)
    return parser.parse_args()
# --------------------------
# Main
# --------------------------
def main():
    """Fine-tune, evaluate and save the LINK token classifier.

    Writes the following into ``--output_dir``:
      - the model and tokenizer;
      - the eval metrics and the trainer state;
      - ``label_map.json``.
    """
    args = parse_args()
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    train_ds = JsonlTokenDataset(args.train_path, tokenizer)
    val_ds = JsonlTokenDataset(args.val_path, tokenizer)

    id2label = {0: "O", 1: "LINK"}
    label2id = {label: idx for idx, label in id2label.items()}
    config = AutoConfig.from_pretrained(
        args.model_name, num_labels=len(id2label), id2label=id2label, label2id=label2id
    )
    model = AutoModelForTokenClassification.from_pretrained(args.model_name, config=config)

    class_weights = compute_class_weights(train_ds)
    collator = SimpleTokenCollator(
        tokenizer=tokenizer,
        # Length rounding is only applied when CUDA is available.
        pad_to_multiple_of=(args.pad_to_multiple_of if torch.cuda.is_available() else None),
    )
    training_args = build_training_arguments(args)

    # Newer transformers releases renamed Trainer's ``tokenizer`` argument to
    # ``processing_class``; pass whichever the installed version accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer = WeightedCELossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
        compute_metrics=compute_metrics_fn,
        class_weights=class_weights,
        **{tokenizer_kw: tokenizer},
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    trainer.save_state()

    # JSON object keys must be strings.
    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({str(idx): label for idx, label in id2label.items()}, f)

    print("=== Training complete ===")
    print(f"Output dir: {args.output_dir}")
    print(f"Class weights [neg, pos]: [{class_weights[0].item():.4f}, {class_weights[1].item():.4f}]")
    print(f"Eval metrics: {metrics}")


if __name__ == "__main__":
    main()