# srl_bert_model/trainer.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
# from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from SRL_preprocessing import data_processing_for_loader_conll, srl_collate
from model import PredicateAwareSRL
from utils import save_pkl
import argparse, json, os
try:
import _jsonnet
except ImportError:
_jsonnet = None
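
# Illustrative invocation (the config path is a placeholder, not shipped with this repo):
#   python trainer.py --config configs/srl_fr.jsonnet --out_dir ./checkpoints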
def load_cfg_from_jsonnet():
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="Path to .jsonnet config")
parser.add_argument("--out_dir", default=None, help="Override training.out_dir")
parser.add_argument("--best_model_path", default=None, help="Override best model save path")
parser.add_argument("--save_history_path", default=None, help="Override history pickle path")
args, unknown = parser.parse_known_args()
if _jsonnet is None:
raise RuntimeError("Please `pip install jsonnet` to use --config")
cfg = json.loads(_jsonnet.evaluate_file(args.config))
    # Apply CLI overrides (create the "training" section if the config lacks one)
    cfg.setdefault("training", {})
    if args.out_dir:
        cfg["training"]["out_dir"] = args.out_dir
    # Ensure out_dir exists & derive default file paths if missing
    out_dir = cfg["training"].get("out_dir", "./checkpoints")
os.makedirs(out_dir, exist_ok=True)
# Derive defaults if not provided in config
cfg["training"].setdefault("best_model_path", os.path.join(out_dir, "best_srl_fr.ckpt"))
cfg["training"].setdefault("save_history_path", os.path.join(out_dir, "loss_history_fr.pkl"))
# Allow explicit overrides
if args.best_model_path:
cfg["training"]["best_model_path"] = args.best_model_path
if args.save_history_path:
cfg["training"]["save_history_path"] = args.save_history_path
return cfg
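
# Sketch of the expected .jsonnet layout, inferred from the keys read in __main__
# (all values below are placeholders):
# {
#   data: { conll_train: "train.conll", conll_valid: "dev.conll",
#           word_col_idx: 1, srl_first_col_idx: 10 },
#   model: { bert_name: "bert-base-cased", resume_from: null,
#            replace_encoder_with: "camembert-base",
#            tokenizer: { name: "camembert-base" } },
#   training: { out_dir: "./checkpoints", num_epochs: 10, batch_size: 16, lr: 3e-5,
#               weight_decay: 0.01, grad_accum_steps: 1, warmup_ratio: 0.1,
#               amp: true, max_grad_norm: 1.0 },
# }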
# ==============================================================
# 1. Training Loop
# ==============================================================
def train_one_epoch(
model,
dataloader,
optimizer,
device="cuda",
scheduler=None,
grad_accum_steps=1,
amp=True,
max_grad_norm=1.0,
):
model.train()
total_loss, n_steps = 0.0, 0
use_amp = amp and torch.cuda.is_available()
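    # Mixed precision only applies on CUDA; with enabled=False the GradScaler calls below
    # (scale / step / update) are pass-throughs, so the same loop also runs on CPU.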
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimizer.zero_grad(set_to_none=True)
for step, batch in enumerate(dataloader, 1):
batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
_, loss = model(**batch) # model must return (logits, loss)
total_loss += float(loss.detach().item())
n_steps += 1
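        # Scale the loss so that gradients summed over grad_accum_steps batches match the
        # gradient of one effective (larger) batch.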
loss = loss / grad_accum_steps
if use_amp:
scaler.scale(loss).backward()
else:
loss.backward()
if step % grad_accum_steps == 0:
if use_amp:
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
if use_amp:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if scheduler is not None:
scheduler.step()
return total_loss / max(1, n_steps)
# ==============================================================
# 2. Evaluation Loop
# ==============================================================
@torch.no_grad()
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda"):
model.eval()
total_loss, n_batches = 0.0, 0
correct, total = 0, 0
for batch in dataloader:
gold = batch["labels"] # CPU
mask = (gold != -100) # valid word positions
batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
logits, loss = model(**batch)
total_loss += float(loss.item()); n_batches += 1
preds = logits.argmax(-1).cpu()
# micro-F1 == accuracy for single-label classification
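        # (every valid token gets exactly one gold and one predicted label, so FP == FN
        #  and micro precision == recall == accuracy)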
correct += int((preds[mask] == gold[mask]).sum())
total += int(mask.sum())
micro_f1 = (correct / total) if total > 0 else 0.0
return total_loss / max(1, n_batches), micro_f1
# ==============================================================
# 3. Flexible Model Loader (English → French transfer)
# ==============================================================
def load_model(
bert_name: str,
label2id,
resume_path: str = None,
replace_encoder_with: str = None,
**kwargs
):
"""
Creates a PredicateAwareSRL model.
- If resume_path is given: loads SRL weights (English model)
- If replace_encoder_with is given: replaces only the BERT encoder
(e.g., replace 'bert-base-cased' with 'camembert-base')
"""
print(f"🧩 Loading model backbone: {bert_name}")
model = PredicateAwareSRL(
bert_name=bert_name,
num_labels=len(label2id),
use_indicator=kwargs.get("use_indicator", True),
use_distance=kwargs.get("use_distance", True),
indicator_dim=kwargs.get("indicator_dim", 10),
lstm_hidden=kwargs.get("lstm_hidden", 768),
mlp_hidden=kwargs.get("mlp_hidden", 300),
pos_dim=kwargs.get("pos_dim", 50),
max_distance=kwargs.get("max_distance", 128),
dropout=kwargs.get("dropout", 0.1),
)
if resume_path and os.path.exists(resume_path):
print(f"🔁 Loading SRL checkpoint from: {resume_path}")
state = torch.load(resume_path, map_location="cpu")
state_dict = state.get("model_state", state)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f" → missing: {len(missing)}, unexpected: {len(unexpected)}")
if replace_encoder_with:
print(f"🌍 Replacing encoder with: {replace_encoder_with}")
from transformers import AutoModel
model.bert = AutoModel.from_pretrained(replace_encoder_with)
return model
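
# Illustrative English → French transfer call (the checkpoint path is a placeholder):
#   model = load_model("bert-base-cased", label2id,
#                      resume_path="checkpoints/best_srl_en.ckpt",
#                      replace_encoder_with="camembert-base")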
# ==============================================================
# 4. Main
# ==============================================================
if __name__ == "__main__":
# ------------------------------
# ⚙️ Configuration
# ------------------------------
cfg = load_cfg_from_jsonnet()
# read values from cfg as usual:
conll_train_path = cfg["data"]["conll_train"]
conll_valid_path = cfg["data"].get("conll_valid")
conll_test_path = cfg["data"].get("conll_test")
word_col_idx = cfg["data"]["word_col_idx"]
srl_first_col_idx= cfg["data"]["srl_first_col_idx"]
bert_name = cfg["model"]["bert_name"]
resume_from = cfg["model"].get("resume_from")
replace_encoder_with = cfg["model"].get("replace_encoder_with")
tok_name = (cfg["model"].get("tokenizer", {}) or {}).get("name", replace_encoder_with or bert_name)
out_dir = cfg["training"]["out_dir"]
num_epochs = cfg["training"]["num_epochs"]
batch_size = cfg["training"]["batch_size"]
lr = cfg["training"]["lr"]
weight_decay = cfg["training"]["weight_decay"]
grad_accum = cfg["training"]["grad_accum_steps"]
warmup_ratio = cfg["training"]["warmup_ratio"]
amp = cfg["training"]["amp"]
max_grad_norm = cfg["training"]["max_grad_norm"]
best_model_path = cfg["training"]["best_model_path"]
save_history_path = cfg["training"]["save_history_path"]
device = "cuda" if torch.cuda.is_available() else "cpu"
# ------------------------------
# 🧩 Tokenizer + data loading
# ------------------------------
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    print(f"Using tokenizer: {tok_name}")
# print(f"Loading multilingual CoNLL data: {conll_train_path}")
# train_bf_loader, dev_bf_loader, test_bf_loader, label2id, id2label = \
train_bf_loader, dev_bf_loader, label2id, id2label = \
data_processing_for_loader_conll(
train_conll=conll_train_path,
dev_conll=conll_valid_path,
# test_conll=conll_test_path,
tokenizer=tokenizer,
word_col_idx=word_col_idx,
srl_first_col_idx=srl_first_col_idx,
max_length=256,
)
# pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
pad_token_id = getattr(tokenizer, "pad_token_id", None)
if pad_token_id is None:
# prefer reusing an existing special token
if getattr(tokenizer, "pad_token", None) is None:
if getattr(tokenizer, "eos_token", None) is not None:
tokenizer.pad_token = tokenizer.eos_token
elif getattr(tokenizer, "sep_token", None) is not None:
tokenizer.pad_token = tokenizer.sep_token
else:
# last resort: add a new PAD token (if you do this, resize embeddings after model init)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
pad_token_id = tokenizer.pad_token_id or 0 # ensure int
collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True, collate_fn=collate)
dev_loader = DataLoader(dev_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if dev_bf_loader else None
# test_loader = DataLoader(test_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if test_bf_loader else None
# ------------------------------
# 🧠 Model initialization
# ------------------------------
model = load_model(
bert_name=bert_name,
label2id=label2id,
resume_path=resume_from,
replace_encoder_with=replace_encoder_with,
use_indicator=True,
use_distance=True,
indicator_dim=10,
lstm_hidden=768,
mlp_hidden=300,
pos_dim=50,
max_distance=128,
dropout=0.1,
).to(device)
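    # If a brand-new [PAD] token was added during tokenizer setup above, the vocabulary grew,
    # so the encoder's embedding matrix must be resized to match (this assumes the encoder is
    # exposed as `model.bert`, as in load_model).
    if len(tokenizer) != model.bert.get_input_embeddings().num_embeddings:
        model.bert.resize_token_embeddings(len(tokenizer))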
# ------------------------------
# 🔧 Optimizer + Scheduler
# ------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
total_steps = len(train_loader) * num_epochs // max(1, grad_accum)
warmup_steps = int(warmup_ratio * total_steps)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
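    # Note: train_one_epoch calls scheduler.step() once per optimizer update (i.e. every
    # grad_accum_steps batches), so total_steps above is counted in optimizer updates too.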
# ------------------------------
# 🏋️ Training Loop
# ------------------------------
history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
    best_dev = -1.0  # best checkpoint is written to best_model_path from the config
for epoch in range(num_epochs):
tr_loss = train_one_epoch(
model, train_loader, optimizer, device=device,
scheduler=scheduler, grad_accum_steps=grad_accum,
amp=amp, max_grad_norm=max_grad_norm,
)
        if dev_loader is not None:
            dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
        else:
            # no validation split configured; skip evaluation this epoch
            dev_loss, dev_f1 = float("nan"), -1.0
history["epoch"].append(epoch + 1)
history["train_loss"].append(tr_loss)
history["dev_loss"].append(dev_loss)
history["dev_f1"].append(dev_f1)
print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")
        if dev_f1 > best_dev or dev_loader is None:
            best_dev = dev_f1
            torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_model_path)
            print(f"  ↳ new best dev; saved to {best_model_path}")
    save_pkl(history, save_history_path)