import argparse
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from SRL_preprocessing import data_processing_for_loader_conll, srl_collate
from model import PredicateAwareSRL
from utils import save_pkl

# jsonnet is optional at import time; a clear error is raised only when --config is used without it.
try:
    import _jsonnet
except ImportError:
    _jsonnet = None

def load_cfg_from_jsonnet():
    """Parse CLI arguments, evaluate the .jsonnet config, and apply CLI overrides."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to .jsonnet config")
    parser.add_argument("--out_dir", default=None, help="Override training.out_dir")
    parser.add_argument("--best_model_path", default=None, help="Override best model save path")
    parser.add_argument("--save_history_path", default=None, help="Override history pickle path")
    args, _ = parser.parse_known_args()

    if _jsonnet is None:
        raise RuntimeError("Please `pip install jsonnet` to use --config")

    cfg = json.loads(_jsonnet.evaluate_file(args.config))
    training = cfg.setdefault("training", {})

    # CLI --out_dir takes precedence over the config value.
    if args.out_dir:
        training["out_dir"] = args.out_dir

    out_dir = training.get("out_dir", "./checkpoints")
    os.makedirs(out_dir, exist_ok=True)

    # Default output paths live inside out_dir unless the config or CLI overrides them.
    training.setdefault("best_model_path", os.path.join(out_dir, "best_srl_fr.ckpt"))
    training.setdefault("save_history_path", os.path.join(out_dir, "loss_history_fr.pkl"))

    if args.best_model_path:
        training["best_model_path"] = args.best_model_path
    if args.save_history_path:
        training["save_history_path"] = args.save_history_path

    return cfg

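
# Illustrative .jsonnet layout covering the fields this script reads; every path and value
# below is a placeholder, not the configuration used in this project:
#
#   {
#     data: {
#       conll_train: "data/train.conll",
#       conll_valid: "data/dev.conll",
#       conll_test: "data/test.conll",
#       word_col_idx: 1,
#       srl_first_col_idx: 10,
#     },
#     model: {
#       bert_name: "bert-base-cased",
#       resume_from: null,                       // optional path to an English SRL checkpoint
#       replace_encoder_with: "camembert-base",  // optional encoder swap
#       tokenizer: { name: "camembert-base" },
#     },
#     training: {
#       out_dir: "./checkpoints",
#       num_epochs: 10,
#       batch_size: 16,
#       lr: 3e-5,
#       weight_decay: 0.01,
#       grad_accum_steps: 1,
#       warmup_ratio: 0.1,
#       amp: true,
#       max_grad_norm: 1.0,
#     },
#   }
#
# Typical invocation (script name assumed):
#   python train_srl_fr.py --config configs/srl_fr.jsonnet --out_dir runs/fr
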
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device="cuda",
    scheduler=None,
    grad_accum_steps=1,
    amp=True,
    max_grad_norm=1.0,
):
    """Run one training epoch; assumes model(**batch) returns (logits, loss)."""
    model.train()
    total_loss, n_steps = 0.0, 0

    # Mixed precision is only enabled when CUDA is actually available.
    use_amp = amp and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(dataloader, 1):
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

        with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
            _, loss = model(**batch)

        total_loss += float(loss.detach().item())
        n_steps += 1
        # Scale the loss so gradients accumulated over grad_accum_steps average correctly.
        loss = loss / grad_accum_steps

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step the optimizer (and scheduler) only every grad_accum_steps batches.
        if step % grad_accum_steps == 0:
            if use_amp:
                scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            optimizer.zero_grad(set_to_none=True)

            if scheduler is not None:
                scheduler.step()

    return total_loss / max(1, n_steps)

@torch.no_grad()
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
    """Return average loss and token-level micro-F1 (accuracy over non-ignored tokens).

    id2label and average are currently unused and kept for interface compatibility.
    """
    model.eval()
    total_loss, n_batches = 0.0, 0
    correct, total = 0, 0

    for batch in dataloader:
        # Keep gold labels on CPU; positions labelled -100 are padding / ignored sub-tokens.
        gold = batch["labels"]
        mask = (gold != -100)

        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        logits, loss = model(**batch)
        total_loss += float(loss.item())
        n_batches += 1

        preds = logits.argmax(-1).cpu()
        correct += int((preds[mask] == gold[mask]).sum())
        total += int(mask.sum())

    micro_f1 = (correct / total) if total > 0 else 0.0
    return total_loss / max(1, n_batches), micro_f1
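

# Optional sketch: the metric above scores individual tokens, not argument spans. If the label
# set is BIO-encoded and the optional `seqeval` package is installed, a stricter span-level F1
# could be derived from the same tensors roughly as follows (not called by the training loop):
def span_f1_bio(pred_ids, gold_ids, mask, id2label):
    """Span-level F1 via seqeval for one batch (sketch; assumes BIO labels)."""
    from seqeval.metrics import f1_score  # local import: seqeval is an optional dependency

    y_true, y_pred = [], []
    for p_row, g_row, m_row in zip(pred_ids, gold_ids, mask):
        y_true.append([id2label[int(g)] for g, m in zip(g_row, m_row) if m])
        y_pred.append([id2label[int(p)] for p, m in zip(p_row, m_row) if m])
    return f1_score(y_true, y_pred)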


def load_model(
    bert_name: str,
    label2id,
    resume_path: str = None,
    replace_encoder_with: str = None,
    **kwargs,
):
    """
    Create a PredicateAwareSRL model.
    - If resume_path is given: load pretrained SRL weights (e.g., the English model).
    - If replace_encoder_with is given: replace only the BERT encoder
      (e.g., swap 'bert-base-cased' for 'camembert-base').
    """
    print(f"🧩 Loading model backbone: {bert_name}")
    model = PredicateAwareSRL(
        bert_name=bert_name,
        num_labels=len(label2id),
        use_indicator=kwargs.get("use_indicator", True),
        use_distance=kwargs.get("use_distance", True),
        indicator_dim=kwargs.get("indicator_dim", 10),
        lstm_hidden=kwargs.get("lstm_hidden", 768),
        mlp_hidden=kwargs.get("mlp_hidden", 300),
        pos_dim=kwargs.get("pos_dim", 50),
        max_distance=kwargs.get("max_distance", 128),
        dropout=kwargs.get("dropout", 0.1),
    )

    if resume_path and os.path.exists(resume_path):
        print(f"🔁 Loading SRL checkpoint from: {resume_path}")
        state = torch.load(resume_path, map_location="cpu")
        # Checkpoints saved by this script wrap weights under "model_state"; fall back to a raw state_dict.
        state_dict = state.get("model_state", state)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f" → missing: {len(missing)}, unexpected: {len(unexpected)}")

    if replace_encoder_with:
        # Swap only the encoder; the task-specific layers keep their (possibly resumed) weights.
        print(f"🌍 Replacing encoder with: {replace_encoder_with}")
        from transformers import AutoModel
        model.bert = AutoModel.from_pretrained(replace_encoder_with)

    return model


if __name__ == "__main__":
    cfg = load_cfg_from_jsonnet()

    # --- Data configuration ---
    conll_train_path = cfg["data"]["conll_train"]
    conll_valid_path = cfg["data"].get("conll_valid")
    conll_test_path = cfg["data"].get("conll_test")
    word_col_idx = cfg["data"]["word_col_idx"]
    srl_first_col_idx = cfg["data"]["srl_first_col_idx"]

    # --- Model configuration ---
    bert_name = cfg["model"]["bert_name"]
    resume_from = cfg["model"].get("resume_from")
    replace_encoder_with = cfg["model"].get("replace_encoder_with")
    tok_name = (cfg["model"].get("tokenizer", {}) or {}).get("name", replace_encoder_with or bert_name)

    # --- Training configuration ---
    out_dir = cfg["training"]["out_dir"]
    num_epochs = cfg["training"]["num_epochs"]
    batch_size = cfg["training"]["batch_size"]
    lr = cfg["training"]["lr"]
    weight_decay = cfg["training"]["weight_decay"]
    grad_accum = cfg["training"]["grad_accum_steps"]
    warmup_ratio = cfg["training"]["warmup_ratio"]
    amp = cfg["training"]["amp"]
    max_grad_norm = cfg["training"]["max_grad_norm"]

    best_model_path = cfg["training"]["best_model_path"]
    save_history_path = cfg["training"]["save_history_path"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Tokenizer ---
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    print(f"Using tokenizer: {tok_name}")

    # --- Datasets ---
    train_bf_loader, dev_bf_loader, label2id, id2label = \
        data_processing_for_loader_conll(
            train_conll=conll_train_path,
            dev_conll=conll_valid_path,
            tokenizer=tokenizer,
            word_col_idx=word_col_idx,
            srl_first_col_idx=srl_first_col_idx,
            max_length=256,
        )

    # --- Padding token ---
    # Some tokenizers ship without a pad token; fall back to eos/sep and only add a
    # brand-new [PAD] token as a last resort.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        if getattr(tokenizer, "pad_token", None) is None:
            if getattr(tokenizer, "eos_token", None) is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif getattr(tokenizer, "sep_token", None) is not None:
                tokenizer.pad_token = tokenizer.sep_token
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        pad_token_id = tokenizer.pad_token_id or 0

    # --- DataLoaders ---
    collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)

    train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True, collate_fn=collate)
    dev_loader = DataLoader(dev_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if dev_bf_loader else None
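
    # Optional sanity check (illustrative, left commented out): peek at one collated batch to
    # confirm the tensor keys and shapes produced by srl_collate before building the model.
    # sample = next(iter(train_loader))
    # print({k: (tuple(v.shape) if torch.is_tensor(v) else type(v)) for k, v in sample.items()})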

    # --- Model ---
    model = load_model(
        bert_name=bert_name,
        label2id=label2id,
        resume_path=resume_from,
        replace_encoder_with=replace_encoder_with,
        use_indicator=True,
        use_distance=True,
        indicator_dim=10,
        lstm_hidden=768,
        mlp_hidden=300,
        pos_dim=50,
        max_distance=128,
        dropout=0.1,
    ).to(device)
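
    # If a brand-new [PAD] token was added above, the encoder's embedding matrix does not cover
    # it yet. A minimal fix, assuming model.bert is a Hugging Face encoder (as in load_model),
    # would be:
    # if len(tokenizer) > model.bert.get_input_embeddings().num_embeddings:
    #     model.bert.resize_token_embeddings(len(tokenizer))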

    # --- Optimizer & LR schedule ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # total_steps counts optimizer updates, hence the division by the accumulation factor.
    total_steps = len(train_loader) * num_epochs // max(1, grad_accum)
    warmup_steps = int(warmup_ratio * total_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    # --- Training loop ---
    history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
    best_dev, best_path = -1.0, best_model_path

    for epoch in range(num_epochs):
        tr_loss = train_one_epoch(
            model, train_loader, optimizer, device=device,
            scheduler=scheduler, grad_accum_steps=grad_accum,
            amp=amp, max_grad_norm=max_grad_norm,
        )
        # The dev split is optional; skip evaluation when no dev loader was built.
        if dev_loader is not None:
            dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
        else:
            dev_loss, dev_f1 = float("nan"), float("nan")

        history["epoch"].append(epoch + 1)
        history["train_loss"].append(tr_loss)
        history["dev_loss"].append(dev_loss)
        history["dev_f1"].append(dev_f1)

        print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")

        # Keep the checkpoint with the best dev F1 (or simply the latest one if there is no dev set).
        if dev_loader is None or dev_f1 > best_dev:
            best_dev = dev_f1
            torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
            print(f" ↳ new best dev; saved to {best_path}")

    save_pkl(history, save_history_path)
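
    # The saved checkpoint stores {"model_state": ..., "label2id": ...}, so it can later be
    # restored (sketch) with:
    #   ckpt = torch.load(best_path, map_location="cpu")
    #   model.load_state_dict(ckpt["model_state"])
    # provided the model is built with the same backbone and label set as during training.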
|