manpreet88 commited on
Commit
58616ba
·
1 Parent(s): 35a7589

Create Transformer.py

Browse files
Files changed (1) hide show
  1. PolyFusion/Transformer.py +603 -0
PolyFusion/Transformer.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fingerprint_mlm_training.py
2
+ import os
3
+ import json
4
+ import time
5
+ import shutil
6
+ import sys
7
+ import csv
8
+
9
+ # Increase max CSV field size limit (some fingerprint fields can be long)
10
+ csv.field_size_limit(sys.maxsize)
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import pandas as pd
17
+ from sklearn.model_selection import train_test_split
18
+ from torch.utils.data import Dataset, DataLoader
19
+
20
+ from transformers import TrainingArguments, Trainer
21
+ from transformers.trainer_callback import TrainerCallback
22
+ from sklearn.metrics import accuracy_score, f1_score
23
+ from typing import List
24
+
25
+ # ---------------------------
26
+ # Configuration / Constants
27
+ # ---------------------------
28
+ # MLM mask probability
29
+ P_MASK = 0.15
30
+
31
+ # Fingerprint specifics
32
+ FINGERPRINT_KEY = "morgan_r3_bits" # inside the JSON stored under 'fingerprints' column
33
+ FP_LENGTH = 2048 # expected fingerprint vector length (bits)
34
+ # Vocabulary: {0, 1, MASK} where 0/1 are real bits and MASK token id = 2 used as masked input
35
+ MASK_TOKEN_ID = 2
36
+ VOCAB_SIZE = 3
37
+
38
+ # Model / encoder hyperparams
39
+ HIDDEN_DIM = 256
40
+ TRANSFORMER_NUM_LAYERS = 4
41
+ TRANSFORMER_NHEAD = 8
42
+ TRANSFORMER_FF = 1024
43
+ DROPOUT = 0.1
44
+
45
+ # Training / data hyperparams
46
+ TRAIN_BATCH_SIZE = 16 # number of molecules per batch
47
+ EVAL_BATCH_SIZE = 8
48
+ GRADIENT_ACCUMULATION_STEPS = 4
49
+ NUM_EPOCHS = 25
50
+ LEARNING_RATE = 1e-4
51
+ WEIGHT_DECAY = 0.01
52
+
53
+ # File locations (changed as requested)
54
+ CSV_PATH = "Polymer_Foundational_Model/polymer_structures_unified_processed.csv"
55
+ OUTPUT_DIR = "./fingerprint_mlm_output_5M"
56
+ BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
57
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
58
+
59
+ # ---------------------------
60
+ # 1. Load Data (chunked to avoid OOM) - read fingerprints column
61
+ # ---------------------------
62
+ TARGET_ROWS = 5000000
63
+ CHUNKSIZE = 50000
64
+
65
+ fp_lists: List[List[int]] = []
66
+ rows_read = 0
67
+
68
+ # Expect 'fingerprints' column value to be a JSON string we can json.loads()
69
+ # that contains e.g. {"morgan_r3_bits": ["0","1","0",...]}
70
+ for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
71
+ # some files might already have parsed JSON-like dicts; ensure we handle strings
72
+ fps_chunk = chunk["fingerprints"]
73
+ for fpval in fps_chunk:
74
+ if pd.isna(fpval):
75
+ # skip or use zeros
76
+ fp_lists.append([0] * FP_LENGTH)
77
+ continue
78
+
79
+ # If it's already a dict-like object, use directly; else parse JSON string
80
+ if isinstance(fpval, str):
81
+ try:
82
+ fp_json = json.loads(fpval)
83
+ except Exception:
84
+ # fallback: try to fix common quoting issues
85
+ try:
86
+ fp_json = json.loads(fpval.replace("'", '"'))
87
+ except Exception:
88
+ # as last fallback, treat the string as a comma separated "0,1,0,..."
89
+ parts = [p.strip().strip('"').strip("'") for p in fpval.split(",")]
90
+ bits = [1 if p in ("1", "True", "true") else 0 for p in parts[:FP_LENGTH]]
91
+ if len(bits) < FP_LENGTH:
92
+ bits += [0] * (FP_LENGTH - len(bits))
93
+ fp_lists.append(bits)
94
+ continue
95
+ elif isinstance(fpval, dict):
96
+ fp_json = fpval
97
+ else:
98
+ # Unknown type, zero pad
99
+ fp_lists.append([0] * FP_LENGTH)
100
+ continue
101
+
102
+ # Extract the fingerprint bit list
103
+ bits = fp_json.get(FINGERPRINT_KEY, None)
104
+ if bits is None:
105
+ # fallback if top-level is already list
106
+ if isinstance(fp_json, list):
107
+ bits = fp_json
108
+ else:
109
+ # default zero vector
110
+ bits = [0] * FP_LENGTH
111
+
112
+ # bits may be list of strings "0"/"1" or ints
113
+ # normalize to ints and ensure length
114
+ normalized = []
115
+ for b in bits:
116
+ if isinstance(b, str):
117
+ b_clean = b.strip().strip('"').strip("'")
118
+ normalized.append(1 if b_clean in ("1", "True", "true") else 0)
119
+ elif isinstance(b, (int, np.integer)):
120
+ normalized.append(1 if int(b) != 0 else 0)
121
+ else:
122
+ normalized.append(0)
123
+ if len(normalized) >= FP_LENGTH:
124
+ break
125
+
126
+ if len(normalized) < FP_LENGTH:
127
+ # pad with zeros
128
+ normalized.extend([0] * (FP_LENGTH - len(normalized)))
129
+
130
+ fp_lists.append(normalized[:FP_LENGTH])
131
+
132
+ rows_read += len(chunk)
133
+ if rows_read >= TARGET_ROWS:
134
+ break
135
+
136
+ print(f"Loaded {len(fp_lists)} fingerprint vectors (using FP_LENGTH={FP_LENGTH}).")
137
+
138
+ # ---------------------------
139
+ # 2. Train/Val Split
140
+ # ---------------------------
141
+ train_idx, val_idx = train_test_split(list(range(len(fp_lists))), test_size=0.2, random_state=42)
142
+ train_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx]
143
+ val_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx]
144
+
145
+ # ---------------------------
146
+ # Compute class weights (for weighted CE to mitigate bit imbalance)
147
+ # (we compute but will not apply them to match previous MLM-style loss behavior)
148
+ # ---------------------------
149
+ # We'll compute weights for classes {0,1} only (targets).
150
+ counts = np.ones((2,), dtype=np.float64) # initialize with 1 to avoid zero division
151
+ for fp in train_fps:
152
+ vals = fp.cpu().numpy().astype(int)
153
+ counts[0] += np.sum(vals == 0)
154
+ counts[1] += np.sum(vals == 1)
155
+
156
+ freq = counts / counts.sum()
157
+ inv_freq = 1.0 / (freq + 1e-12)
158
+ class_weights_arr = inv_freq / inv_freq.mean()
159
+ class_weights = torch.tensor(class_weights_arr, dtype=torch.float) # shape [2]
160
+ print("Class weights (for bit 0 and bit 1):", class_weights.numpy())
161
+
162
+ # ---------------------------
163
+ # 3. Dataset and Collator (fingerprint MLM)
164
+ # ---------------------------
165
+ class FingerprintDataset(Dataset):
166
+ def __init__(self, fps: List[torch.Tensor]):
167
+ self.fps = fps
168
+
169
+ def __len__(self):
170
+ return len(self.fps)
171
+
172
+ def __getitem__(self, idx):
173
+ # Return the tensor directly (not wrapped in a dict). This avoids mismatches
174
+ # when HF's Trainer / collators pass around items in different formats.
175
+ return self.fps[idx]
176
+
177
+ def collate_batch(batch):
178
+ """
179
+ Collate a batch of fingerprint tensors into:
180
+ - z: [B, L] long, masked/corrupted input tokens (values 0,1, or MASK_TOKEN_ID)
181
+ - labels_z: [B, L] long, with -100 for unselected positions and 0/1 for masked positions (targets)
182
+ - attention_mask: [B, L] bool (all True here since fixed length)
183
+
184
+ This collator is defensive: it accepts
185
+ - list of torch.Tensors
186
+ - list of dicts containing key 'fp'
187
+ - HF-style list of dict-like items where a tensor-like value is present
188
+ """
189
+ B = len(batch)
190
+ if B == 0:
191
+ return {"z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
192
+ "labels_z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
193
+ "attention_mask": torch.zeros((0, FP_LENGTH), dtype=torch.bool)}
194
+
195
+ # Normalize items -> list of tensors
196
+ tensors = []
197
+ for item in batch:
198
+ if isinstance(item, torch.Tensor):
199
+ tensors.append(item)
200
+ elif isinstance(item, dict):
201
+ # Prefer 'fp' if present
202
+ if "fp" in item:
203
+ val = item["fp"]
204
+ if not isinstance(val, torch.Tensor):
205
+ val = torch.tensor(val, dtype=torch.long)
206
+ tensors.append(val)
207
+ else:
208
+ # Try to find any tensor-like value inside dict
209
+ found = None
210
+ for v in item.values():
211
+ if isinstance(v, torch.Tensor):
212
+ found = v
213
+ break
214
+ elif isinstance(v, np.ndarray):
215
+ found = torch.tensor(v, dtype=torch.long)
216
+ break
217
+ elif isinstance(v, list):
218
+ # possible list of ints
219
+ try:
220
+ found = torch.tensor(v, dtype=torch.long)
221
+ break
222
+ except Exception:
223
+ continue
224
+ if found is None:
225
+ raise KeyError("collate_batch: couldn't find 'fp' tensor in dataset item; item keys: {}".format(list(item.keys())))
226
+ tensors.append(found)
227
+ else:
228
+ # fallback: try to convert numpy/sequence to tensor
229
+ try:
230
+ tensors.append(torch.tensor(item, dtype=torch.long))
231
+ except Exception:
232
+ raise TypeError(f"collate_batch: unsupported batch item type: {type(item)}")
233
+
234
+ # Stack into [B, L]
235
+ all_inputs = torch.stack(tensors, dim=0).long() # [B, L], long (0/1)
236
+ device = all_inputs.device
237
+
238
+ # Prepare masks and labels
239
+ labels_z = torch.full_like(all_inputs, fill_value=-100, dtype=torch.long) # -100 ignored by CE
240
+ z_masked = all_inputs.clone()
241
+
242
+ for i in range(B):
243
+ z = all_inputs[i] # [L]
244
+ n_positions = z.size(0)
245
+ # select positions to supervise (mask) with probability P_MASK
246
+ is_selected = torch.rand(n_positions) < P_MASK
247
+
248
+ # ensure not all selected
249
+ if is_selected.all():
250
+ is_selected[torch.randint(0, n_positions, (1,))] = False
251
+
252
+ sel_idx = torch.nonzero(is_selected).squeeze(-1)
253
+ if sel_idx.numel() > 0:
254
+ labels_z[i, sel_idx] = z[sel_idx] # store true bit labels
255
+
256
+ # BERT-style corruption per selected position
257
+ probs = torch.rand(sel_idx.size(0))
258
+ mask_choice = probs < 0.8
259
+ rand_choice = (probs >= 0.8) & (probs < 0.9)
260
+ # keep_choice = probs >= 0.9
261
+
262
+ if mask_choice.any():
263
+ z_masked[i, sel_idx[mask_choice]] = MASK_TOKEN_ID # mask token id
264
+
265
+ if rand_choice.any():
266
+ # replace with random 0 or 1
267
+ rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long)
268
+ z_masked[i, sel_idx[rand_choice]] = rand_bits
269
+
270
+ # keep_choice -> leave original bit
271
+
272
+ attention_mask = torch.ones_like(all_inputs, dtype=torch.bool) # full attention (fixed length)
273
+
274
+ return {"z": z_masked, "labels_z": labels_z, "attention_mask": attention_mask}
275
+
276
+ train_dataset = FingerprintDataset(train_fps)
277
+ val_dataset = FingerprintDataset(val_fps)
278
+ train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_batch, drop_last=False)
279
+ val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_batch, drop_last=False)
280
+
281
+ # ---------------------------
282
+ # 4. Model Definition (Fingerprint Encoder + MLM head)
283
+ # ---------------------------
284
+
285
+ class FingerprintEncoder(nn.Module):
286
+ """
287
+ Simple encoder for fingerprint token sequences:
288
+ - token embedding (vocab size VOCAB_SIZE)
289
+ - positional embedding
290
+ - Transformer encoder stack
291
+ - returns per-position embeddings [B, L, hidden_dim]
292
+ """
293
+ def __init__(self, vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, seq_len=FP_LENGTH,
294
+ num_layers=TRANSFORMER_NUM_LAYERS, nhead=TRANSFORMER_NHEAD, dim_feedforward=TRANSFORMER_FF,
295
+ dropout=DROPOUT):
296
+ super().__init__()
297
+ self.token_emb = nn.Embedding(vocab_size, hidden_dim)
298
+ self.pos_emb = nn.Embedding(seq_len, hidden_dim)
299
+ encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
300
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
301
+ self.hidden_dim = hidden_dim
302
+ self.seq_len = seq_len
303
+
304
+ def forward(self, input_ids, attention_mask=None):
305
+ """
306
+ input_ids: [B, L] long (values 0,1, or MASK_TOKEN_ID)
307
+ attention_mask: [B, L] bool (True for valid positions)
308
+ returns: embeddings [B, L, hidden_dim]
309
+ """
310
+ B, L = input_ids.shape
311
+ x = self.token_emb(input_ids) # [B, L, hidden]
312
+ # positional indices 0..L-1 broadcast to batch
313
+ pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
314
+ x = x + self.pos_emb(pos_ids)
315
+ # transformer expects batch_first=True (we set that)
316
+ if attention_mask is not None:
317
+ # transformer encoder in PyTorch doesn't use attention_mask in same way as HF; provide key_padding_mask
318
+ key_padding_mask = ~attention_mask # True where to mask
319
+ else:
320
+ key_padding_mask = None
321
+
322
+ out = self.transformer(x, src_key_padding_mask=key_padding_mask)
323
+ return out # [B, L, hidden_dim]
324
+
325
+
326
+ class MaskedFingerprintModel(nn.Module):
327
+ """
328
+ Encoder + MLM head for fingerprint masked language modeling.
329
+ MLM head predicts over VOCAB_SIZE (0,1,MASK) like a token classification over the small vocab.
330
+ Loss is standard CrossEntropyLoss (ignore_index=-100) computed only on masked positions,
331
+ matching the "MLM with CrossEntropy" behavior used in the DebertaV2ForMaskedLM setup.
332
+ """
333
+ def __init__(self, hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE):
334
+ super().__init__()
335
+ self.encoder = FingerprintEncoder(vocab_size=vocab_size, hidden_dim=hidden_dim)
336
+ # MLM head: predict logits over the small token vocabulary {0,1,MASK}
337
+ self.mlm_head = nn.Linear(hidden_dim, vocab_size)
338
+
339
+ def forward(self, z, attention_mask=None, labels_z=None):
340
+ """
341
+ z: [B, L] long inputs (0/1/MASK_TOKEN_ID)
342
+ labels_z: [B, L] long with -100 for unselected positions, else 0/1 targets
343
+ Returns:
344
+ - if labels_z provided -> loss (scalar tensor)
345
+ - else -> logits [B, L, VOCAB_SIZE]
346
+ """
347
+ embeddings = self.encoder(z, attention_mask=attention_mask) # [B, L, hidden]
348
+ logits = self.mlm_head(embeddings) # [B, L, VOCAB_SIZE]
349
+
350
+ if labels_z is not None:
351
+ mask = labels_z != -100 # [B, L]
352
+ if mask.sum() == 0:
353
+ # return zero loss tensor on same device
354
+ return torch.tensor(0.0, device=z.device)
355
+
356
+ logits_masked = logits[mask] # [M, VOCAB_SIZE]
357
+ labels_masked = labels_z[mask] # [M] values in {0,1}
358
+
359
+ # standard cross-entropy over the vocabulary (no class weighting, matching previous Deberta MLM behavior)
360
+ # labels_masked must be long
361
+ labels_masked = labels_masked.long()
362
+ loss_z = F.cross_entropy(logits_masked, labels_masked)
363
+
364
+ return loss_z
365
+
366
+ # inference -> return logits
367
+ return logits
368
+
369
+ # instantiate model using MLM-style head and standard cross-entropy loss (no learned weighting/class-weights)
370
+ model = MaskedFingerprintModel(hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE)
371
+
372
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
373
+ model.to(device)
374
+
375
+ # ---------------------------
376
+ # 5. Training Setup (Hugging Face Trainer)
377
+ # ---------------------------
378
+ training_args = TrainingArguments(
379
+ output_dir=OUTPUT_DIR,
380
+ overwrite_output_dir=True,
381
+ num_train_epochs=NUM_EPOCHS,
382
+ per_device_train_batch_size=TRAIN_BATCH_SIZE,
383
+ per_device_eval_batch_size=EVAL_BATCH_SIZE,
384
+ eval_accumulation_steps=1000, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
385
+ eval_strategy="epoch",
386
+ logging_steps=500,
387
+ learning_rate=LEARNING_RATE,
388
+ weight_decay=WEIGHT_DECAY,
389
+ fp16=torch.cuda.is_available(),
390
+ save_strategy="no", # callback will save best model
391
+ disable_tqdm=False,
392
+ logging_first_step=True,
393
+ report_to=[],
394
+ # NOTE: set to 0 to avoid DataLoader worker pickling/collate issues in some environments.
395
+ dataloader_num_workers=0,
396
+ )
397
+
398
+ class ValLossCallback(TrainerCallback):
399
+ def __init__(self, trainer_ref=None):
400
+ self.best_val_loss = float("inf")
401
+ self.epochs_no_improve = 0
402
+ self.patience = 10
403
+ self.best_epoch = None
404
+ self.trainer_ref = trainer_ref
405
+
406
+ def on_epoch_end(self, args, state, control, **kwargs):
407
+ epoch_num = int(state.epoch)
408
+ train_loss = next((x["loss"] for x in reversed(state.log_history) if "loss" in x), None)
409
+ print(f"\n=== Epoch {epoch_num}/{args.num_train_epochs} ===")
410
+ if train_loss is not None:
411
+ print(f"Train Loss: {train_loss:.4f}")
412
+
413
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
414
+ epoch_num = int(state.epoch) + 1
415
+
416
+ if self.trainer_ref is None:
417
+ print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
418
+ return
419
+
420
+ metric_val_loss = None
421
+ if metrics is not None:
422
+ metric_val_loss = metrics.get("eval_loss")
423
+
424
+ model_eval = self.trainer_ref.model
425
+ model_eval.eval()
426
+
427
+ device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")
428
+
429
+ preds_bits = []
430
+ true_bits = []
431
+ total_loss = 0.0
432
+ n_batches = 0
433
+
434
+ logits_masked_list = []
435
+ labels_masked_list = []
436
+
437
+ with torch.no_grad():
438
+ for batch in val_loader:
439
+ z = batch["z"].to(device_local) # [B, L]
440
+ labels_z = batch["labels_z"].to(device_local)
441
+ attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device_local)
442
+
443
+ # compute loss if possible (model returns scalar loss when labels_z provided)
444
+ try:
445
+ loss = model_eval(z, attention_mask=attention_mask, labels_z=labels_z)
446
+ except Exception as e:
447
+ loss = None
448
+
449
+ if isinstance(loss, torch.Tensor):
450
+ total_loss += loss.item()
451
+ n_batches += 1
452
+
453
+ logits = model_eval(z, attention_mask=attention_mask) # [B, L, VOCAB_SIZE]
454
+
455
+ mask = labels_z != -100
456
+ if mask.sum().item() == 0:
457
+ continue
458
+
459
+ logits_masked_list.append(logits[mask])
460
+ labels_masked_list.append(labels_z[mask])
461
+
462
+ pred_bits = torch.argmax(logits[mask], dim=-1)
463
+ true_b = labels_z[mask]
464
+
465
+ preds_bits.extend(pred_bits.cpu().tolist())
466
+ true_bits.extend(true_b.cpu().tolist())
467
+
468
+ avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))
469
+
470
+ accuracy = accuracy_score(true_bits, preds_bits) if len(true_bits) > 0 else 0.0
471
+ f1 = f1_score(true_bits, preds_bits, average="weighted") if len(true_bits) > 0 else 0.0
472
+
473
+ # perplexity from masked-token cross-entropy (computed over masked positions only)
474
+ if len(logits_masked_list) > 0:
475
+ all_logits_masked = torch.cat(logits_masked_list, dim=0)
476
+ all_labels_masked = torch.cat(labels_masked_list, dim=0)
477
+ # match previous MLM: standard cross-entropy over the vocabulary
478
+ loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked.long())
479
+ try:
480
+ perplexity = float(torch.exp(loss_z_all).cpu().item())
481
+ except Exception:
482
+ perplexity = float(np.exp(float(loss_z_all.cpu().item())))
483
+ else:
484
+ perplexity = float("nan")
485
+
486
+ print(f"\n--- Evaluation after Epoch {epoch_num} ---")
487
+ print(f"Validation Loss: {avg_val_loss:.4f}")
488
+ print(f"Validation Accuracy: {accuracy:.4f}")
489
+ print(f"Validation F1 (weighted): {f1:.4f}")
490
+ print(f"Validation Perplexity (classification head): {perplexity:.4f}")
491
+
492
+ # Check for improvement
493
+ if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
494
+ self.best_val_loss = avg_val_loss
495
+ self.best_epoch = int(state.epoch)
496
+ self.epochs_no_improve = 0
497
+ os.makedirs(BEST_MODEL_DIR, exist_ok=True)
498
+ try:
499
+ torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
500
+ print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
501
+ except Exception as e:
502
+ print(f"Failed to save best model at epoch {epoch_num}: {e}")
503
+ else:
504
+ self.epochs_no_improve += 1
505
+
506
+ if self.epochs_no_improve >= self.patience:
507
+ print(f"Early stopping after {self.patience} epochs with no improvement.")
508
+ control.should_training_stop = True
509
+
510
+
511
+ # Create callback and Trainer
512
+ callback = ValLossCallback()
513
+ trainer = Trainer(
514
+ model=model,
515
+ args=training_args,
516
+ train_dataset=train_dataset,
517
+ eval_dataset=val_dataset,
518
+ data_collator=collate_batch,
519
+ callbacks=[callback]
520
+ )
521
+ callback.trainer_ref = trainer
522
+
523
+ # ---------------------------
524
+ # 6. Run training
525
+ # ---------------------------
526
+ start_time = time.time()
527
+ trainer.train()
528
+ total_time = time.time() - start_time
529
+
530
+ # ---------------------------
531
+ # 7. Final Evaluation (evaluate best saved model on validation set)
532
+ # ---------------------------
533
+
534
+ best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
535
+ if os.path.exists(best_model_path):
536
+ try:
537
+ model.load_state_dict(torch.load(best_model_path, map_location=device))
538
+ print(f"\nLoaded best model from {best_model_path}")
539
+ except Exception as e:
540
+ print(f"\nFailed to load best model from {best_model_path}: {e}")
541
+
542
+ model.eval()
543
+ preds_bits_all = []
544
+ true_bits_all = []
545
+ logits_masked_final = []
546
+ labels_masked_final = []
547
+
548
+ with torch.no_grad():
549
+ for batch in val_loader:
550
+ z = batch["z"].to(device)
551
+ labels_z = batch["labels_z"].to(device)
552
+ attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device)
553
+
554
+ logits = model(z, attention_mask=attention_mask) # [B, L, VOCAB_SIZE]
555
+
556
+ mask = labels_z != -100
557
+ if mask.sum().item() == 0:
558
+ continue
559
+
560
+ logits_masked_final.append(logits[mask])
561
+ labels_masked_final.append(labels_z[mask])
562
+
563
+ pred_bits = torch.argmax(logits[mask], dim=-1)
564
+ true_b = labels_z[mask]
565
+
566
+ preds_bits_all.extend(pred_bits.cpu().tolist())
567
+ true_bits_all.extend(true_b.cpu().tolist())
568
+
569
+ accuracy = accuracy_score(true_bits_all, preds_bits_all) if len(true_bits_all) > 0 else 0.0
570
+ f1 = f1_score(true_bits_all, preds_bits_all, average="weighted") if len(true_bits_all) > 0 else 0.0
571
+
572
+ if len(logits_masked_final) > 0:
573
+ all_logits_masked_final = torch.cat(logits_masked_final, dim=0)
574
+ all_labels_masked_final = torch.cat(labels_masked_final, dim=0)
575
+ loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final.long())
576
+ try:
577
+ perplexity_final = float(torch.exp(loss_z_final).cpu().item())
578
+ except Exception:
579
+ perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
580
+ else:
581
+ perplexity_final = float("nan")
582
+
583
+ best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
584
+ best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
585
+
586
+ print(f"\n=== Final Results (evaluated on best saved model) ===")
587
+ print(f"Total Training Time (s): {total_time:.2f}")
588
+ if best_epoch_num is not None:
589
+ print(f"Best Epoch (1-based): {best_epoch_num}")
590
+ else:
591
+ print("Best Epoch: (none saved)")
592
+
593
+ print(f"Best Validation Loss: {best_val_loss:.4f}")
594
+ print(f"Validation Accuracy: {accuracy:.4f}")
595
+ print(f"Validation F1 (weighted): {f1:.4f}")
596
+ print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
597
+
598
+ total_params = sum(p.numel() for p in model.parameters())
599
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
600
+ non_trainable_params = total_params - trainable_params
601
+ print(f"Total Parameters: {total_params}")
602
+ print(f"Trainable Parameters: {trainable_params}")
603
+ print(f"Non-trainable Parameters: {non_trainable_params}")