|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import inspect
import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JsonlTokenDataset(Dataset):
    """Token-classification dataset backed by a JSONL file produced by prep.py.

    Each record must contain parallel integer lists ``input_ids``,
    ``attention_mask`` and ``labels``.  On load, labels aligned with special
    tokens (CLS/SEP/PAD, ...) are overwritten with -100 so they are ignored
    by ``CrossEntropyLoss`` (its default ``ignore_index``).
    """

    def __init__(self, path: str, tokenizer: "AutoTokenizer"):
        """Load all records from *path* and mask special-token labels in place.

        Args:
            path: Path to the JSONL file.
            tokenizer: Tokenizer used to locate special tokens in the ids.
        """
        self.path = path
        self.tokenizer = tokenizer
        self.samples: List[Dict[str, Any]] = []
        with open(self.path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing lines instead of crashing
                    # on json.loads("").
                    continue
                self.samples.append(json.loads(line))

        for rec in self.samples:
            input_ids = rec["input_ids"]
            labels = rec["labels"]
            try:
                special_mask = tokenizer.get_special_tokens_mask(
                    input_ids, already_has_special_tokens=True
                )
            except Exception:
                # Fallback for tokenizers that cannot recompute the mask:
                # compare each id against the tokenizer's special-id set.
                spec = set(tokenizer.all_special_ids or [])
                special_mask = [1 if t in spec else 0 for t in input_ids]
            # -100 == CrossEntropyLoss ignore_index; keep other labels as ints.
            rec["labels"] = [
                -100 if sm == 1 else int(lab)
                for lab, sm in zip(labels, special_mask)
            ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        r = self.samples[idx]
        return {
            "input_ids": torch.tensor(r["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(r["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(r["labels"], dtype=torch.long),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleTokenCollator:
    """Pads input_ids with pad_token_id, attention_mask with 0, labels with -100."""

    def __init__(self, tokenizer: "AutoTokenizer", pad_to_multiple_of: Optional[int] = None):
        """Remember the pad id and the optional length-rounding factor.

        Args:
            tokenizer: Supplies ``pad_token_id``; 0 is used if it has none.
            pad_to_multiple_of: When set, the padded length is rounded up to
                a multiple of this value (helps tensor-core kernels on GPU).
        """
        # Fall back to 0 when the tokenizer defines no pad token.
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.pad_to_multiple = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Pad a list of per-sample tensors into one rectangular batch."""
        ids = [f["input_ids"].tolist() for f in features]
        att = [f["attention_mask"].tolist() for f in features]
        lab = [f["labels"].tolist() for f in features]
        max_len = max(len(x) for x in ids)
        if self.pad_to_multiple and max_len % self.pad_to_multiple != 0:
            # Round up to the next multiple.
            max_len = ((max_len // self.pad_to_multiple) + 1) * self.pad_to_multiple

        def pad(seq, val):
            return seq + [val] * (max_len - len(seq))

        return {
            "input_ids": torch.tensor([pad(x, self.pad_id) for x in ids], dtype=torch.long),
            "attention_mask": torch.tensor([pad(x, 0) for x in att], dtype=torch.long),
            # -100 so padded positions are ignored by the loss.
            "labels": torch.tensor([pad(x, -100) for x in lab], dtype=torch.long),
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_class_weights(dataset: "JsonlTokenDataset") -> torch.Tensor:
    """Return CE class weights ``[w_neg, w_pos]`` for an imbalanced label set.

    Positives (label 1) are up-weighted by the negative/positive ratio so rare
    anchor tokens contribute proportionally to the loss.  Labels equal to -100
    (masked special tokens) are excluded from the counts.  If there are no
    positives at all, both weights default to 1.0.
    """
    pos = 0
    neg = 0
    for rec in dataset.samples:
        for lab in rec["labels"]:
            if lab == -100:
                continue
            if lab == 1:
                pos += 1
            else:
                neg += 1
    # neg/pos > 1 when positives are rare, boosting their loss contribution.
    # (The old max(1, pos) guard was dead code: pos > 0 already holds here.)
    pos_weight = (neg / pos) if pos > 0 else 1.0
    return torch.tensor([1.0, pos_weight], dtype=torch.float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_metrics_fn(eval_pred):
    """Binary token-level metrics; positions labeled -100 are excluded."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Flatten both arrays and keep only the positions that carry a real label.
    flat_true = np.asarray(labels).reshape(-1)
    flat_pred = np.asarray(preds).reshape(-1)
    keep = flat_true != -100
    y_true = flat_true[keep].astype(int)
    y_pred = flat_pred[keep].astype(int)

    if y_true.size == 0:
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0,
                "pos_rate_true": 0.0, "pos_rate_pred": 0.0}

    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))

    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    prec = tp / max(1, tp + fp)
    rec = tp / max(1, tp + fn)
    f1 = (2 * prec * rec / max(1e-12, prec + rec)) if (prec + rec) > 0 else 0.0

    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1,
            "pos_rate_true": float(np.mean(y_true)),
            "pos_rate_pred": float(np.mean(y_pred))}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WeightedCELossTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over unmasked tokens."""

    def __init__(self, class_weights: torch.Tensor = None, **kwargs):
        """Store optional per-class CE weights (e.g. from compute_class_weights)."""
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Cross-entropy over positions whose label != -100, with class weights."""
        labels = inputs["labels"]
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.logits
        loss_fct = torch.nn.CrossEntropyLoss(
            weight=(self.class_weights.to(logits.device) if self.class_weights is not None else None)
        )
        # Select only real-label positions; padded/special tokens carry -100.
        mask = labels.ne(-100)
        # Use the model's actual label count instead of hard-coding 2 so the
        # trainer also works for heads with num_labels != 2.
        num_labels = logits.size(-1)
        loss = loss_fct(
            logits.view(-1, num_labels)[mask.view(-1)],
            labels.view(-1)[mask.view(-1)],
        )
        return (loss, outputs) if return_outputs else loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_training_arguments(args) -> TrainingArguments:
    """Construct TrainingArguments across multiple transformers versions.

    Introspects ``TrainingArguments.__init__`` and only passes keyword
    arguments the installed version actually accepts, falling back to the
    legacy ``evaluate_during_training`` flag for very old releases.

    Args:
        args: Parsed CLI namespace from parse_args().
    """
    sig = set(inspect.signature(TrainingArguments.__init__).parameters.keys())

    # Feature probes: each flag records whether the installed transformers
    # version accepts the corresponding keyword.
    supports_eval_strategy = "evaluation_strategy" in sig
    supports_save_strategy = "save_strategy" in sig
    supports_log_strategy = "logging_strategy" in sig
    supports_report_to = "report_to" in sig
    supports_load_best = "load_best_model_at_end" in sig
    supports_metric_forbest = "metric_for_best_model" in sig
    supports_workers = "dataloader_num_workers" in sig

    kw = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.train_batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "logging_steps": args.logging_steps,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "seed": args.seed,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "fp16": args.fp16,
        "bf16": args.bf16,
        "gradient_checkpointing": args.gradient_checkpointing,
        "log_level": "info",
    }
    if supports_workers:
        kw["dataloader_num_workers"] = args.num_workers
    if supports_report_to:
        # BUGFIX: the old code passed None for "none", but transformers treats
        # report_to=None as "report to ALL installed integrations".  An empty
        # list is the unambiguous way to disable reporting.
        kw["report_to"] = ([] if args.report_to == "none" else ["wandb"])

    if supports_eval_strategy and supports_save_strategy:
        kw["evaluation_strategy"] = "steps"
        kw["save_strategy"] = "steps"
        if supports_log_strategy:
            kw["logging_strategy"] = "steps"
        if supports_load_best:
            kw["load_best_model_at_end"] = True
        if supports_metric_forbest:
            kw["metric_for_best_model"] = "f1"
            if "greater_is_better" in sig:
                kw["greater_is_better"] = True
    else:
        # Legacy transformers: strip strategy-style kwargs and use the old
        # evaluate_during_training flag instead.
        for k in ("evaluation_strategy", "save_strategy", "logging_strategy",
                  "load_best_model_at_end", "metric_for_best_model",
                  "greater_is_better"):
            kw.pop(k, None)
        if "evaluate_during_training" in sig and args.eval_steps > 0:
            kw["evaluate_during_training"] = True

    # Final safety net: drop anything this version's signature doesn't accept.
    kw = {k: v for k, v in kw.items() if k in sig}
    return TrainingArguments(**kw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for training.

    Args:
        argv: Optional explicit argument list (useful for tests and
            programmatic invocation).  Defaults to ``sys.argv[1:]``, which
            preserves the original no-argument calling convention.

    Returns:
        argparse.Namespace with all training options.
    """
    ap = argparse.ArgumentParser(description="Train binary token classification model for link anchors.")
    # Model and data locations.
    ap.add_argument("--model_name", default="microsoft/mdeberta-v3-base", help="HF model name or local path.")
    ap.add_argument("--train_path", default="train_windows.jsonl", help="Training JSONL.")
    ap.add_argument("--val_path", default="val_windows.jsonl", help="Validation JSONL.")
    ap.add_argument("--output_dir", default="model_link_token_cls", help="Output directory.")

    # Optimization hyperparameters.
    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--train_batch_size", type=int, default=16)
    ap.add_argument("--eval_batch_size", type=int, default=32)
    ap.add_argument("--logging_steps", type=int, default=50)
    ap.add_argument("--eval_steps", type=int, default=500)
    ap.add_argument("--save_steps", type=int, default=500)

    # Runtime / hardware knobs.
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=1)
    ap.add_argument("--fp16", action="store_true")
    ap.add_argument("--bf16", action="store_true")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--report_to", default="wandb", choices=["wandb","none"])
    ap.add_argument("--pad_to_multiple_of", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    return ap.parse_args(argv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """End-to-end training: load data, fit the model, evaluate, save artifacts."""
    args = parse_args()
    set_seed(args.seed)

    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)

    train_dataset = JsonlTokenDataset(args.train_path, tok)
    eval_dataset = JsonlTokenDataset(args.val_path, tok)

    # Binary tagging scheme: O (outside) vs LINK (anchor token).
    id2label = {0: "O", 1: "LINK"}
    label2id = {"O": 0, "LINK": 1}
    config = AutoConfig.from_pretrained(args.model_name, num_labels=2, id2label=id2label, label2id=label2id)
    model = AutoModelForTokenClassification.from_pretrained(args.model_name, config=config)

    class_weights = compute_class_weights(train_dataset)

    # Pad-to-multiple only helps GPU kernels; skip the rounding on CPU.
    collator = SimpleTokenCollator(
        tokenizer=tok,
        pad_to_multiple_of=(args.pad_to_multiple_of if torch.cuda.is_available() else None),
    )

    training_args = build_training_arguments(args)

    trainer = WeightedCELossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        tokenizer=tok,
        compute_metrics=compute_metrics_fn,
        class_weights=class_weights,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)

    # Final evaluation on the validation split, logged and persisted.
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    trainer.save_state()

    # Persist the label map next to the weights for downstream inference code.
    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({"0":"O","1":"LINK"}, f)

    print("=== Training complete ===")
    print(f"Output dir: {args.output_dir}")
    print(f"Class weights [neg, pos]: [{class_weights[0].item():.4f}, {class_weights[1].item():.4f}]")
    print(f"Eval metrics: {metrics}")
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|