import torch
import torch.nn as nn
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader

from SRL_preprocessing import data_processing_for_loader_conll, srl_collate
from model import PredicateAwareSRL
from utils import save_pkl

import argparse
import json
import os
from typing import Optional

try:
    import _jsonnet
except ImportError:
    _jsonnet = None


def load_cfg_from_jsonnet():
    """Parse --config (a .jsonnet file) plus optional CLI overrides into a dict."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to .jsonnet config")
    parser.add_argument("--out_dir", default=None, help="Override training.out_dir")
    parser.add_argument("--best_model_path", default=None, help="Override best model save path")
    parser.add_argument("--save_history_path", default=None, help="Override history pickle path")
    args, _ = parser.parse_known_args()

    if _jsonnet is None:
        raise RuntimeError("Please `pip install jsonnet` to use --config")
    cfg = json.loads(_jsonnet.evaluate_file(args.config))

    # Apply CLI overrides; create the "training" section if the config lacks it.
    training = cfg.setdefault("training", {})
    if args.out_dir:
        training["out_dir"] = args.out_dir

    # Ensure out_dir exists and derive default file paths if missing.
    out_dir = training.get("out_dir", "./checkpoints")
    os.makedirs(out_dir, exist_ok=True)
    training.setdefault("best_model_path", os.path.join(out_dir, "best_srl_fr.ckpt"))
    training.setdefault("save_history_path", os.path.join(out_dir, "loss_history_fr.pkl"))

    # Explicit CLI overrides win over config values and derived defaults.
    if args.best_model_path:
        training["best_model_path"] = args.best_model_path
    if args.save_history_path:
        training["save_history_path"] = args.save_history_path
    return cfg
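
# Example .jsonnet config (illustrative sketch; the field names follow the
# cfg[...] accesses in this script, but every value below is a placeholder,
# not a tested default):
#
#   {
#     data: {
#       conll_train: "data/fr/train.conll",
#       conll_valid: "data/fr/dev.conll",
#       conll_test: "data/fr/test.conll",
#       word_col_idx: 1,
#       srl_first_col_idx: 10,
#     },
#     model: {
#       bert_name: "bert-base-cased",
#       resume_from: "checkpoints/best_srl_en.ckpt",
#       replace_encoder_with: "camembert-base",
#       tokenizer: { name: "camembert-base" },
#     },
#     training: {
#       out_dir: "./checkpoints",
#       num_epochs: 10,
#       batch_size: 16,
#       lr: 3e-5,
#       weight_decay: 0.01,
#       grad_accum_steps: 1,
#       warmup_ratio: 0.1,
#       amp: true,
#       max_grad_norm: 1.0,
#     },
#   }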

# ==============================================================
# 1. Training Loop
# ==============================================================
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device="cuda",
    scheduler=None,
    grad_accum_steps=1,
    amp=True,
    max_grad_norm=1.0,
):
    """One epoch of mixed-precision training with gradient accumulation."""
    model.train()
    total_loss, n_steps = 0.0, 0
    use_amp = amp and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(dataloader, 1):
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
            _, loss = model(**batch)  # model must return (logits, loss)
        total_loss += float(loss.detach().item())
        n_steps += 1

        loss = loss / grad_accum_steps  # scale so accumulated gradients average correctly
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step the optimizer only every grad_accum_steps mini-batches.
        # (If len(dataloader) is not a multiple of grad_accum_steps, the last
        # partial accumulation is reset by zero_grad at the next epoch.)
        if step % grad_accum_steps == 0:
            if use_amp:
                scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            if scheduler is not None:
                scheduler.step()

    return total_loss / max(1, n_steps)


# ==============================================================
# 2. Evaluation Loop
# ==============================================================
@torch.no_grad()
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda"):
    """Average loss and token-level micro-F1.

    micro-F1 == accuracy for single-label classification, so no sklearn is
    needed. `id2label` is unused here; it is kept in the signature as a hook
    for per-label reporting.
    """
    model.eval()
    total_loss, n_batches = 0.0, 0
    correct, total = 0, 0
    for batch in dataloader:
        gold = batch["labels"]  # still on CPU
        mask = gold != -100     # valid word positions only
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        logits, loss = model(**batch)
        total_loss += float(loss.item())
        n_batches += 1
        preds = logits.argmax(-1).cpu()
        correct += int((preds[mask] == gold[mask]).sum())
        total += int(mask.sum())
    micro_f1 = (correct / total) if total > 0 else 0.0
    return total_loss / max(1, n_batches), micro_f1


# ==============================================================
# 3. Flexible Model Loader (English → French transfer)
# ==============================================================
def load_model(
    bert_name: str,
    label2id,
    resume_path: Optional[str] = None,
    replace_encoder_with: Optional[str] = None,
    **kwargs,
):
    """
    Create a PredicateAwareSRL model.
    - If resume_path is given: load SRL weights (e.g., an English model).
    - If replace_encoder_with is given: replace only the BERT encoder
      (e.g., swap 'bert-base-cased' for 'camembert-base').
    """
    print(f"🧩 Loading model backbone: {bert_name}")
    model = PredicateAwareSRL(
        bert_name=bert_name,
        num_labels=len(label2id),
        use_indicator=kwargs.get("use_indicator", True),
        use_distance=kwargs.get("use_distance", True),
        indicator_dim=kwargs.get("indicator_dim", 10),
        lstm_hidden=kwargs.get("lstm_hidden", 768),
        mlp_hidden=kwargs.get("mlp_hidden", 300),
        pos_dim=kwargs.get("pos_dim", 50),
        max_distance=kwargs.get("max_distance", 128),
        dropout=kwargs.get("dropout", 0.1),
    )
    if resume_path and os.path.exists(resume_path):
        print(f"🔁 Loading SRL checkpoint from: {resume_path}")
        state = torch.load(resume_path, map_location="cpu")
        state_dict = state.get("model_state", state)
        # strict=False tolerates encoder/head mismatches across languages
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f"   → missing: {len(missing)}, unexpected: {len(unexpected)}")
    if replace_encoder_with:
        print(f"🌍 Replacing encoder with: {replace_encoder_with}")
        from transformers import AutoModel
        model.bert = AutoModel.from_pretrained(replace_encoder_with)
    return model
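
# Illustrative transfer recipe (paths are placeholders): warm-start from an
# English SRL checkpoint, then swap in a French encoder, as the docstring
# above describes. The replacement encoder must be dimensionally compatible
# with the rest of the network (e.g., a 768-d base model here).
#
#   srl_model = load_model(
#       bert_name="bert-base-cased",
#       label2id=label2id,
#       resume_path="checkpoints/best_srl_en.ckpt",  # hypothetical English checkpoint
#       replace_encoder_with="camembert-base",
#   )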

# ==============================================================
# 4. Main
# ==============================================================
if __name__ == "__main__":
    # ------------------------------
    # ⚙️ Configuration
    # ------------------------------
    cfg = load_cfg_from_jsonnet()

    conll_train_path = cfg["data"]["conll_train"]
    conll_valid_path = cfg["data"].get("conll_valid")
    conll_test_path = cfg["data"].get("conll_test")  # reserved for test-set support below
    word_col_idx = cfg["data"]["word_col_idx"]
    srl_first_col_idx = cfg["data"]["srl_first_col_idx"]

    bert_name = cfg["model"]["bert_name"]
    resume_from = cfg["model"].get("resume_from")
    replace_encoder_with = cfg["model"].get("replace_encoder_with")
    tok_name = (cfg["model"].get("tokenizer", {}) or {}).get(
        "name", replace_encoder_with or bert_name
    )

    out_dir = cfg["training"]["out_dir"]
    num_epochs = cfg["training"]["num_epochs"]
    batch_size = cfg["training"]["batch_size"]
    lr = cfg["training"]["lr"]
    weight_decay = cfg["training"]["weight_decay"]
    grad_accum = cfg["training"]["grad_accum_steps"]
    warmup_ratio = cfg["training"]["warmup_ratio"]
    amp = cfg["training"]["amp"]
    max_grad_norm = cfg["training"]["max_grad_norm"]
    best_model_path = cfg["training"]["best_model_path"]
    save_history_path = cfg["training"]["save_history_path"]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------
    # 🧩 Tokenizer + data loading
    # ------------------------------
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    print(f"Using tokenizer: {tok_name}")

    # train_bf_loader, dev_bf_loader, test_bf_loader, label2id, id2label = \
    train_bf_loader, dev_bf_loader, label2id, id2label = data_processing_for_loader_conll(
        train_conll=conll_train_path,
        dev_conll=conll_valid_path,
        # test_conll=conll_test_path,
        tokenizer=tokenizer,
        word_col_idx=word_col_idx,
        srl_first_col_idx=srl_first_col_idx,
        max_length=256,
    )

    # Resolve a padding id; prefer reusing an existing special token.
    added_pad_token = False
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        if getattr(tokenizer, "pad_token", None) is None:
            if getattr(tokenizer, "eos_token", None) is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif getattr(tokenizer, "sep_token", None) is not None:
                tokenizer.pad_token = tokenizer.sep_token
            else:
                # Last resort: add a brand-new PAD token; the encoder's
                # embeddings are resized after model init below.
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                added_pad_token = True
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    def collate(batch):
        return srl_collate(batch, pad_token_id=pad_token_id, pad_label_id=-100)

    train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True, collate_fn=collate)
    dev_loader = (
        DataLoader(dev_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate)
        if dev_bf_loader else None
    )
    # test_loader = DataLoader(test_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if test_bf_loader else None
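
    # Optional sanity check (illustrative): peek at one collated batch to
    # verify shapes and label masking before training starts. The "labels"
    # key and the -100 pad id come from srl_collate above.
    sample_batch = next(iter(train_loader))
    print("Sample batch shapes:",
          {k: tuple(v.shape) for k, v in sample_batch.items() if torch.is_tensor(v)})
    print("Batch contains -100-masked label positions:",
          bool((sample_batch["labels"] == -100).any()))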

    # ------------------------------
    # 🧠 Model initialization
    # ------------------------------
    model = load_model(
        bert_name=bert_name,
        label2id=label2id,
        resume_path=resume_from,
        replace_encoder_with=replace_encoder_with,
        use_indicator=True,
        use_distance=True,
        indicator_dim=10,
        lstm_hidden=768,
        mlp_hidden=300,
        pos_dim=50,
        max_distance=128,
        dropout=0.1,
    ).to(device)
    if added_pad_token:
        # A new [PAD] token was appended to the vocab; grow the encoder's
        # embedding matrix to match the tokenizer.
        model.bert.resize_token_embeddings(len(tokenizer))

    # ------------------------------
    # 🔧 Optimizer + Scheduler
    # ------------------------------
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # total_steps counts optimizer updates, not mini-batches
    total_steps = len(train_loader) * num_epochs // max(1, grad_accum)
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps,
    )

    # ------------------------------
    # 🏋️ Training Loop
    # ------------------------------
    history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
    best_dev, best_path = float("-inf"), best_model_path
    for epoch in range(num_epochs):
        tr_loss = train_one_epoch(
            model, train_loader, optimizer,
            device=device, scheduler=scheduler,
            grad_accum_steps=grad_accum, amp=amp, max_grad_norm=max_grad_norm,
        )
        if dev_loader is not None:
            dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
        else:
            # No dev set: select on negated train loss so lower loss still wins.
            dev_loss, dev_f1 = float("nan"), -tr_loss

        history["epoch"].append(epoch + 1)
        history["train_loss"].append(tr_loss)
        history["dev_loss"].append(dev_loss)
        history["dev_f1"].append(dev_f1)
        print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} dev_loss={dev_loss:.4f} dev_F1={dev_f1:.4f}")

        if dev_f1 > best_dev:
            best_dev = dev_f1
            torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
            print(f"   ↳ new best dev; saved to {best_path}")

    save_pkl(history, save_history_path)
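
# Example invocation (illustrative; the script and config filenames are
# placeholders, and every path flag can instead be set inside the .jsonnet
# config):
#
#   python train_srl.py \
#       --config configs/srl_fr.jsonnet \
#       --out_dir ./checkpoints/fr \
#       --best_model_path ./checkpoints/fr/best_srl_fr.ckpt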