File size: 11,716 Bytes
0c0db8b
 
 
b926eb7
0c0db8b
 
eb21fd5
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb6561
0c0db8b
 
3eb6561
 
0c0db8b
3eb6561
0c0db8b
 
 
 
3eb6561
 
 
 
 
 
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b70354a
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a6fe36
d117d6b
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b38540
0c0db8b
d117d6b
 
 
f1c3b7e
 
d117d6b
 
72b55dd
f1c3b7e
d117d6b
 
 
 
 
0c0db8b
20a1375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c0db8b
 
d117d6b
 
f1c3b7e
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f20fc03
f23ec2d
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f20fc03
d19224b
0c0db8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import torch
import torch.nn as nn
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
# from sklearn.metrics import f1_score
from torch.utils.data import DataLoader

from SRL_preprocessing import data_processing_for_loader_conll, srl_collate
from model import PredicateAwareSRL
from utils import save_pkl
import re, pathlib, argparse, json, os, sys


try:
    import _jsonnet
except ImportError:
    _jsonnet = None

def load_cfg_from_jsonnet():
    """Parse CLI flags, evaluate the Jsonnet config and apply overrides/defaults.

    Command-line flags (unknown flags are ignored via ``parse_known_args``):

    * ``--config`` (required): path to the ``.jsonnet`` configuration file.
    * ``--out_dir``: overrides ``training.out_dir``.
    * ``--best_model_path``: overrides ``training.best_model_path``.
    * ``--save_history_path``: overrides ``training.save_history_path``.

    After this call ``cfg["training"]`` always exists and contains
    ``out_dir`` (created on disk), ``best_model_path`` and
    ``save_history_path``.

    Returns:
        dict: the evaluated configuration with overrides and defaults applied.

    Raises:
        RuntimeError: if the ``jsonnet`` package is not installed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to .jsonnet config")
    parser.add_argument("--out_dir", default=None, help="Override training.out_dir")
    parser.add_argument("--best_model_path", default=None, help="Override best model save path")
    parser.add_argument("--save_history_path", default=None, help="Override history pickle path")
    # Tolerate extra flags (e.g. added by launchers) instead of exiting on them.
    args, _unknown = parser.parse_known_args()

    if _jsonnet is None:
        raise RuntimeError("Please `pip install jsonnet` to use --config")

    cfg = json.loads(_jsonnet.evaluate_file(args.config))

    # Guarantee the "training" section exists so the lookups below (and the
    # caller's cfg["training"][...] reads) cannot raise KeyError.
    training = cfg.setdefault("training", {})

    # CLI override takes precedence over the config value.
    if args.out_dir:
        training["out_dir"] = args.out_dir

    # Write the default back so downstream readers of cfg["training"]["out_dir"]
    # see the directory that was actually created (also covers a null value).
    out_dir = training.get("out_dir") or "./checkpoints"
    training["out_dir"] = out_dir
    os.makedirs(out_dir, exist_ok=True)

    # Derive file paths inside out_dir unless the config already set them.
    training.setdefault("best_model_path", os.path.join(out_dir, "best_srl_fr.ckpt"))
    training.setdefault("save_history_path", os.path.join(out_dir, "loss_history_fr.pkl"))

    # Explicit CLI overrides win over both config values and derived defaults.
    if args.best_model_path:
        training["best_model_path"] = args.best_model_path
    if args.save_history_path:
        training["save_history_path"] = args.save_history_path

    return cfg

# ==============================================================
# 1. Training Loop
# ==============================================================
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device="cuda",
    scheduler=None,
    grad_accum_steps=1,
    amp=True,
    max_grad_norm=1.0,
):
    """Run one training epoch with optional gradient accumulation and AMP.

    Args:
        model: module called as ``model(**batch)``; it must return a
            ``(logits, loss)`` pair where ``loss`` is a scalar tensor.
        dataloader: iterable of dict batches; tensor values are moved to
            ``device``, other values are passed through unchanged.
        optimizer: torch optimizer over the model's parameters.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped once per optimizer update
            (not once per batch).
        grad_accum_steps: number of batches whose gradients are accumulated
            before each optimizer update; values below 1 are treated as 1.
            A trailing group shorter than this is still applied at the end
            of the epoch.
        amp: request float16 autocast + GradScaler; only honoured when CUDA
            is available and ``device`` is a CUDA device.
        max_grad_norm: gradient-norm clipping threshold applied before every
            optimizer update.

    Returns:
        float: mean of the un-normalised per-batch losses, or 0.0 when the
        dataloader yields no batches.
    """
    model.train()
    accum = max(1, int(grad_accum_steps))
    total_loss, n_steps = 0.0, 0

    # fp16 autocast / GradScaler only apply to CUDA tensors.
    use_amp = bool(amp) and torch.cuda.is_available() and str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def _apply_update():
        # Unscale first so clipping operates on the true gradient magnitudes.
        if use_amp:
            scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if scheduler is not None:
            scheduler.step()

    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(dataloader, 1):
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

        with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
            _, loss = model(**batch)  # model must return (logits, loss)

        total_loss += float(loss.detach().item())
        n_steps += 1
        # Scale down so the accumulated gradient is an average over the group.
        loss = loss / accum

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if step % accum == 0:
            _apply_update()

    # Flush a trailing partial accumulation group; otherwise its gradients
    # would be discarded by the next zero_grad without ever being applied.
    if n_steps % accum != 0:
        _apply_update()

    return total_loss / max(1, n_steps)


# ==============================================================
# 2. Evaluation Loop
# ==============================================================
@torch.no_grad()
def eval_loss_and_token_f1(model, dataloader, id2label=None, device="cuda", average="micro"):
    """Return the mean batch loss and token-level micro-F1 over a dataloader.

    Tokens whose gold label equals -100 are ignored when scoring. Every
    scored token carries exactly one label, so micro-F1 coincides with token
    accuracy, and that is the value computed here.

    ``id2label`` and ``average`` are accepted for interface compatibility but
    are not used.

    Returns:
        tuple[float, float]: (loss averaged over batches, token micro-F1);
        each is 0.0 when there is nothing to average.
    """
    model.eval()
    loss_sum = 0.0
    batch_count = 0
    hits = 0
    scored = 0

    for raw in dataloader:
        # Gold labels and their mask are taken from the un-moved batch.
        gold = raw["labels"]
        keep = gold != -100

        on_device = {}
        for key, value in raw.items():
            on_device[key] = value.to(device) if torch.is_tensor(value) else value

        logits, loss = model(**on_device)
        loss_sum += float(loss.item())
        batch_count += 1

        predicted = logits.argmax(-1).cpu()
        hits += int((predicted[keep] == gold[keep]).sum())
        scored += int(keep.sum())

    f1 = 0.0 if scored <= 0 else hits / scored
    return loss_sum / max(1, batch_count), f1



# ==============================================================
# 3. Flexible Model Loader (English → French transfer)
# ==============================================================
def load_model(
    bert_name: str,
    label2id,
    resume_path: str = None,
    replace_encoder_with: str = None,
    **kwargs
):
    """
    Build a PredicateAwareSRL model, optionally warm-starting it.

    Args:
        bert_name: encoder name passed to PredicateAwareSRL as ``bert_name``.
        label2id: label-to-index mapping; ``len(label2id)`` sizes the output
            layer, and it is compared with the ``"label2id"`` stored in the
            checkpoint (if present) to warn about label-order differences.
        resume_path: optional SRL checkpoint (e.g. an English model). Either a
            raw state dict or a dict holding it under ``"model_state"``.
            Tensors whose shape differs from the new model (e.g. a classifier
            sized for another label set) are skipped rather than aborting the
            load. A path that does not exist is reported and ignored.
        replace_encoder_with: optional Hugging Face model name; after any
            checkpoint load, ``model.bert`` is replaced with
            ``AutoModel.from_pretrained(replace_encoder_with)``
            (e.g. 'bert-base-cased' -> 'camembert-base'). NOTE(review): the
            downstream layers presumably require the same hidden size as the
            original encoder — confirm.
        **kwargs: optional PredicateAwareSRL hyper-parameters (use_indicator,
            use_distance, indicator_dim, lstm_hidden, mlp_hidden, pos_dim,
            max_distance, dropout); any other key is ignored.

    Returns:
        The constructed model (not moved to any device here).
    """
    from collections import OrderedDict

    print(f"🧩 Loading model backbone: {bert_name}")
    model = PredicateAwareSRL(
        bert_name=bert_name,
        num_labels=len(label2id),
        use_indicator=kwargs.get("use_indicator", True),
        use_distance=kwargs.get("use_distance", True),
        indicator_dim=kwargs.get("indicator_dim", 10),
        lstm_hidden=kwargs.get("lstm_hidden", 768),
        mlp_hidden=kwargs.get("mlp_hidden", 300),
        pos_dim=kwargs.get("pos_dim", 50),
        max_distance=kwargs.get("max_distance", 128),
        dropout=kwargs.get("dropout", 0.1),
    )

    if resume_path:
        if os.path.exists(resume_path):
            print(f"🔁 Loading SRL checkpoint from: {resume_path}")
            state = torch.load(resume_path, map_location="cpu")
            state_dict = state.get("model_state", state)

            # Same label count but a different mapping would load silently
            # into misaligned output rows, so surface it.
            ckpt_label2id = state.get("label2id")
            if ckpt_label2id is not None and ckpt_label2id != label2id:
                print("  ⚠️ checkpoint label2id differs from the current label2id; "
                      "output-layer rows may not correspond to the same labels")

            # strict=False tolerates missing/unexpected keys but still raises
            # on shape mismatches, so drop those tensors explicitly.
            own_state = model.state_dict()
            skipped = {
                name for name, tensor in state_dict.items()
                if name in own_state
                and hasattr(tensor, "shape")
                and tuple(tensor.shape) != tuple(own_state[name].shape)
            }
            if skipped:
                print(f"  → skipping {len(skipped)} shape-mismatched tensor(s): {sorted(skipped)}")
                filtered = OrderedDict(
                    (k, v) for k, v in state_dict.items() if k not in skipped
                )
                # Keep torch's per-module `_metadata` that load_state_dict reads.
                metadata = getattr(state_dict, "_metadata", None)
                if metadata is not None:
                    filtered._metadata = metadata
                state_dict = filtered

            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            print(f"  → missing: {len(missing)}, unexpected: {len(unexpected)}")
        else:
            print(f"⚠️ resume_path not found, skipping checkpoint load: {resume_path}")

    if replace_encoder_with:
        print(f"🌍 Replacing encoder with: {replace_encoder_with}")
        from transformers import AutoModel
        model.bert = AutoModel.from_pretrained(replace_encoder_with)

    return model


# ==============================================================
# 4. Main
# ==============================================================
if __name__ == "__main__":
    # ------------------------------
    # ⚙️ Configuration
    # ------------------------------
    # Script entry point: train PredicateAwareSRL on CoNLL-format data as
    # described by the Jsonnet file given with --config.
    cfg = load_cfg_from_jsonnet()

    conll_train_path  = cfg["data"]["conll_train"]
    conll_valid_path  = cfg["data"].get("conll_valid")
    conll_test_path   = cfg["data"].get("conll_test")  # test split currently disabled below
    word_col_idx      = cfg["data"]["word_col_idx"]
    srl_first_col_idx = cfg["data"]["srl_first_col_idx"]

    bert_name            = cfg["model"]["bert_name"]
    resume_from          = cfg["model"].get("resume_from")
    replace_encoder_with = cfg["model"].get("replace_encoder_with")
    # An explicit model.tokenizer.name wins; otherwise match the encoder that
    # will actually be inside the model.
    tok_name = (
        (cfg["model"].get("tokenizer", {}) or {}).get("name")
        or replace_encoder_with
        or bert_name
    )

    out_dir         = cfg["training"]["out_dir"]
    num_epochs      = cfg["training"]["num_epochs"]
    batch_size      = cfg["training"]["batch_size"]
    lr              = cfg["training"]["lr"]
    weight_decay    = cfg["training"]["weight_decay"]
    grad_accum      = cfg["training"]["grad_accum_steps"]
    warmup_ratio    = cfg["training"]["warmup_ratio"]
    amp             = cfg["training"]["amp"]
    max_grad_norm   = cfg["training"]["max_grad_norm"]

    best_model_path    = cfg["training"]["best_model_path"]
    save_history_path  = cfg["training"]["save_history_path"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------
    # 🧩 Tokenizer + data loading
    # ------------------------------
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    print(f"Using tokenizer: {tok_name}")

    train_bf_loader, dev_bf_loader, label2id, id2label = \
        data_processing_for_loader_conll(
            train_conll=conll_train_path,
            dev_conll=conll_valid_path,
            tokenizer=tokenizer,
            word_col_idx=word_col_idx,
            srl_first_col_idx=srl_first_col_idx,
            max_length=256,
        )

    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    added_pad_token = False

    if pad_token_id is None:
        # Prefer reusing an existing special token as padding.
        if getattr(tokenizer, "pad_token", None) is None:
            if getattr(tokenizer, "eos_token", None) is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif getattr(tokenizer, "sep_token", None) is not None:
                tokenizer.pad_token = tokenizer.sep_token
            else:
                # Last resort: a brand-new token; the encoder's embedding
                # table is resized after model init (see below).
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                added_pad_token = True
        pad_token_id = tokenizer.pad_token_id or 0  # ensure int

    collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)

    train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True,  collate_fn=collate)
    dev_loader   = DataLoader(dev_bf_loader,   batch_size=batch_size, shuffle=False, collate_fn=collate) if dev_bf_loader else None

    # ------------------------------
    # 🧠 Model initialization
    # ------------------------------
    model = load_model(
        bert_name=bert_name,
        label2id=label2id,
        resume_path=resume_from,
        replace_encoder_with=replace_encoder_with,
        use_indicator=True,
        use_distance=True,
        indicator_dim=10,
        lstm_hidden=768,
        mlp_hidden=300,
        pos_dim=50,
        max_distance=128,
        dropout=0.1,
    ).to(device)

    if added_pad_token and hasattr(getattr(model, "bert", None), "resize_token_embeddings"):
        # The new [PAD] id is outside the encoder's original vocabulary.
        model.bert.resize_token_embeddings(len(tokenizer))

    # ------------------------------
    # 🔧 Optimizer + Scheduler
    # ------------------------------
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    accum = max(1, grad_accum)
    # Rounded up: a trailing partial accumulation group may also trigger an
    # update. Over-estimating only leaves the LR marginally above zero at the
    # end; under-estimating would clamp it to zero before training finishes.
    updates_per_epoch = (len(train_loader) + accum - 1) // accum
    total_steps = updates_per_epoch * num_epochs
    warmup_steps = int(warmup_ratio * total_steps)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    # ------------------------------
    # 🏋️ Training Loop
    # ------------------------------
    history = {"epoch": [], "train_loss": [], "dev_loss": [], "dev_f1": []}
    best_dev, best_path = -1.0, best_model_path
    best_dir = os.path.dirname(best_path)
    if best_dir:
        os.makedirs(best_dir, exist_ok=True)

    for epoch in range(num_epochs):
        tr_loss = train_one_epoch(
            model, train_loader, optimizer, device=device,
            scheduler=scheduler, grad_accum_steps=grad_accum,
            amp=amp, max_grad_norm=max_grad_norm,
        )
        if dev_loader is not None:
            dev_loss, dev_f1 = eval_loss_and_token_f1(model, dev_loader, id2label, device=device)
        else:
            # No dev split configured: nothing to evaluate or select on.
            dev_loss, dev_f1 = float("nan"), float("nan")

        history["epoch"].append(epoch + 1)
        history["train_loss"].append(tr_loss)
        history["dev_loss"].append(dev_loss)
        history["dev_f1"].append(dev_f1)

        print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f}  dev_loss={dev_loss:.4f}  dev_F1={dev_f1:.4f}")

        if dev_loader is None:
            # Without a dev set, keep the most recent weights.
            torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
            print(f"  ↳ no dev set; saved latest model to {best_path}")
        elif dev_f1 > best_dev:
            best_dev = dev_f1
            torch.save({"model_state": model.state_dict(), "label2id": label2id}, best_path)
            print(f"  ↳ new best dev; saved to {best_path}")

    history_dir = os.path.dirname(save_history_path)
    if history_dir:
        os.makedirs(history_dir, exist_ok=True)
    save_pkl(history, save_history_path)