#!/usr/bin/env python3
# train.py
#
# Fine-tune mdeberta-v3-base for binary token classification (LINK vs O).
# Inputs (same dir): train_windows.jsonl, val_windows.jsonl (from prep.py)
# JSONL per line: {"doc_id": int, "window_id": int, "input_ids": [...], "attention_mask": [...], "labels": [...]}

import os
import json
import argparse
import inspect
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)


# --------------------------
# Dataset
# --------------------------
class JsonlTokenDataset(Dataset):
    """Loads JSONL produced by prep.py. Masks special tokens in labels to -100.

    Each record is expected to carry parallel "input_ids", "attention_mask"
    and "labels" lists (per-token binary labels: 1 = LINK, 0 = O).
    """

    def __init__(self, path: str, tokenizer: AutoTokenizer):
        self.path = path
        self.tokenizer = tokenizer
        self.samples: List[Dict[str, Any]] = []
        with open(self.path, "r", encoding="utf-8") as f:
            for line in f:
                self.samples.append(json.loads(line))
        # Mask special tokens (CLS/SEP/PAD...) to -100 so both the loss and
        # the metrics ignore them.
        for rec in self.samples:
            input_ids = rec["input_ids"]
            labels = rec["labels"]
            try:
                special_mask = tokenizer.get_special_tokens_mask(
                    input_ids, already_has_special_tokens=True
                )
            except Exception:
                # Fallback for tokenizers that cannot recompute the mask:
                # compare ids against the declared special-token ids.
                spec = set(tokenizer.all_special_ids or [])
                special_mask = [1 if t in spec else 0 for t in input_ids]
            rec["labels"] = [
                -100 if sm == 1 else int(l) for l, sm in zip(labels, special_mask)
            ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        r = self.samples[idx]
        return {
            "input_ids": torch.tensor(r["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(r["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(r["labels"], dtype=torch.long),
        }


# --------------------------
# Collator (pads inputs + labels)
# --------------------------
class SimpleTokenCollator:
    """Pads input_ids with pad_token_id, attention_mask with 0, labels with -100.

    Optionally rounds the padded length up to a multiple of
    ``pad_to_multiple_of`` (useful for tensor-core efficiency on GPU).
    """

    def __init__(self, tokenizer: AutoTokenizer,
                 pad_to_multiple_of: Optional[int] = None):
        # Fall back to 0 for tokenizers that declare no pad token.
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.pad_to_multiple = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        ids = [f["input_ids"].tolist() for f in features]
        att = [f["attention_mask"].tolist() for f in features]
        lab = [f["labels"].tolist() for f in features]
        max_len = max(len(x) for x in ids)
        if self.pad_to_multiple and max_len % self.pad_to_multiple != 0:
            max_len = ((max_len // self.pad_to_multiple) + 1) * self.pad_to_multiple

        def pad(seq, val):
            return seq + [val] * (max_len - len(seq))

        ids = [pad(x, self.pad_id) for x in ids]
        att = [pad(x, 0) for x in att]
        lab = [pad(x, -100) for x in lab]  # -100 keeps padding out of the loss
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(att, dtype=torch.long),
            "labels": torch.tensor(lab, dtype=torch.long),
        }


# --------------------------
# Class weights
# --------------------------
def compute_class_weights(dataset: JsonlTokenDataset) -> torch.Tensor:
    """Return [w_neg, w_pos] with w_pos = neg/pos to counter class imbalance.

    Labels equal to -100 (masked specials) are excluded from the counts.
    """
    pos = 0
    neg = 0
    for rec in dataset.samples:
        for l in rec["labels"]:
            if l == -100:
                continue
            if l == 1:
                pos += 1
            else:
                neg += 1
    return torch.tensor(
        [1.0, (neg / max(1, pos)) if pos > 0 else 1.0], dtype=torch.float
    )


# --------------------------
# Metrics
# --------------------------
def compute_metrics_fn(eval_pred):
    """Token-level binary metrics over non-masked positions.

    Returns accuracy/precision/recall/f1 plus the true and predicted
    positive rates (handy for spotting collapsed predictions).
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    y_true, y_pred = [], []
    for p, l in zip(preds, labels):
        for pi, li in zip(p, l):
            if li == -100:  # skip special/padding positions
                continue
            y_true.append(int(li))
            y_pred.append(int(pi))
    if not y_true:
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0,
                "pos_rate_true": 0.0, "pos_rate_pred": 0.0}
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    prec = tp / max(1, tp + fp)
    rec = tp / max(1, tp + fn)
    f1 = (2 * prec * rec / max(1e-12, prec + rec)) if (prec + rec) > 0 else 0.0
    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1,
            "pos_rate_true": float(np.mean(y_true)),
            "pos_rate_pred": float(np.mean(y_pred))}


# --------------------------
# Weighted Trainer (accepts extra kwargs like num_items_in_batch)
# --------------------------
class WeightedCELossTrainer(Trainer):
    """Trainer with class-weighted cross-entropy over non-masked tokens."""

    def __init__(self, class_weights: torch.Tensor = None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights = class_weights

    # **kwargs absorbs newer-Trainer extras such as num_items_in_batch.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs["labels"]
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"])
        logits = outputs.logits  # [B, T, 2]
        loss_fct = torch.nn.CrossEntropyLoss(
            weight=(self.class_weights.to(logits.device)
                    if self.class_weights is not None else None)
        )
        # Select only positions with a real label (-100 = masked).
        mask = labels.ne(-100)
        loss = loss_fct(logits.view(-1, 2)[mask.view(-1)],
                        labels.view(-1)[mask.view(-1)])
        return (loss, outputs) if return_outputs else loss


# --------------------------
# TrainingArguments compatibility
# --------------------------
def build_training_arguments(args) -> TrainingArguments:
    """Build TrainingArguments across transformers versions.

    Probes TrainingArguments.__init__'s signature and only passes kwargs it
    supports. Handles the `evaluation_strategy` -> `eval_strategy` rename in
    newer transformers releases.
    """
    sig = set(inspect.signature(TrainingArguments.__init__).parameters.keys())
    # Newer transformers renamed evaluation_strategy to eval_strategy;
    # accept either spelling.
    if "evaluation_strategy" in sig:
        eval_strategy_key = "evaluation_strategy"
    elif "eval_strategy" in sig:
        eval_strategy_key = "eval_strategy"
    else:
        eval_strategy_key = None
    supports_save_strategy = "save_strategy" in sig
    supports_log_strategy = "logging_strategy" in sig
    supports_report_to = "report_to" in sig
    supports_load_best = "load_best_model_at_end" in sig
    supports_metric_forbest = "metric_for_best_model" in sig
    supports_workers = "dataloader_num_workers" in sig

    kw = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.train_batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "logging_steps": args.logging_steps,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "seed": args.seed,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "fp16": args.fp16,
        "bf16": args.bf16,
        "gradient_checkpointing": args.gradient_checkpointing,
        "log_level": "info",
    }
    if supports_workers:
        kw["dataloader_num_workers"] = args.num_workers
    if supports_report_to:
        # NOTE: report_to=None means "all installed integrations" in modern
        # transformers, so "none" must map to an empty list, not None.
        kw["report_to"] = ([] if args.report_to == "none" else ["wandb"])
    # Pair eval/save strategies only if BOTH are supported: best-model
    # selection requires them to match.
    if eval_strategy_key is not None and supports_save_strategy:
        kw[eval_strategy_key] = "steps"
        kw["save_strategy"] = "steps"
        if supports_log_strategy:
            kw["logging_strategy"] = "steps"
        if supports_load_best:
            kw["load_best_model_at_end"] = True
        if supports_metric_forbest:
            kw["metric_for_best_model"] = "f1"
            if "greater_is_better" in sig:
                kw["greater_is_better"] = True
    # Very old transformers used a boolean flag instead of strategies.
    if "evaluate_during_training" in sig and args.eval_steps > 0:
        kw["evaluate_during_training"] = True
    # Final safety net: drop anything this version does not accept.
    kw = {k: v for k, v in kw.items() if k in sig}
    return TrainingArguments(**kw)


# --------------------------
# Args
# --------------------------
def parse_args():
    """CLI arguments for training."""
    ap = argparse.ArgumentParser(
        description="Train binary token classification model for link anchors."
    )
    ap.add_argument("--model_name", default="microsoft/mdeberta-v3-base",
                    help="HF model name or local path.")
    ap.add_argument("--train_path", default="train_windows.jsonl", help="Training JSONL.")
    ap.add_argument("--val_path", default="val_windows.jsonl", help="Validation JSONL.")
    ap.add_argument("--output_dir", default="model_link_token_cls", help="Output directory.")
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--train_batch_size", type=int, default=16)
    ap.add_argument("--eval_batch_size", type=int, default=32)
    ap.add_argument("--logging_steps", type=int, default=50)
    ap.add_argument("--eval_steps", type=int, default=500)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=1)
    ap.add_argument("--fp16", action="store_true")
    ap.add_argument("--bf16", action="store_true")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--report_to", default="wandb", choices=["wandb", "none"])
    ap.add_argument("--pad_to_multiple_of", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    return ap.parse_args()


# --------------------------
# Main
# --------------------------
def main():
    """End-to-end fine-tuning: load data, train, evaluate, save artifacts."""
    args = parse_args()
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    train_ds = JsonlTokenDataset(args.train_path, tokenizer)
    val_ds = JsonlTokenDataset(args.val_path, tokenizer)

    id2label = {0: "O", 1: "LINK"}
    label2id = {"O": 0, "LINK": 1}
    config = AutoConfig.from_pretrained(
        args.model_name, num_labels=2, id2label=id2label, label2id=label2id
    )
    model = AutoModelForTokenClassification.from_pretrained(args.model_name, config=config)

    class_weights = compute_class_weights(train_ds)
    collator = SimpleTokenCollator(
        tokenizer=tokenizer,
        # Pad-to-multiple only pays off on GPU (tensor cores); skip on CPU.
        pad_to_multiple_of=(args.pad_to_multiple_of if torch.cuda.is_available() else None),
    )

    training_args = build_training_arguments(args)
    trainer = WeightedCELossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
        class_weights=class_weights,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    trainer.save_state()

    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({"0": "O", "1": "LINK"}, f)

    print("=== Training complete ===")
    print(f"Output dir: {args.output_dir}")
    print(f"Class weights [neg, pos]: [{class_weights[0].item():.4f}, {class_weights[1].item():.4f}]")
    print(f"Eval metrics: {metrics}")


if __name__ == "__main__":
    main()