manpreet88 commited on
Commit
bcbfd4b
·
1 Parent(s): ebd0f34

Update Transformer.py

Browse files
Files changed (1) hide show
  1. PolyFusion/Transformer.py +275 -366
PolyFusion/Transformer.py CHANGED
@@ -1,12 +1,16 @@
1
- # fingerprint_mlm_training.py
 
 
 
2
  import os
3
  import json
4
  import time
5
- import shutil
6
  import sys
7
  import csv
 
 
8
 
9
- # Increase max CSV field size limit (some fingerprint fields can be long)
10
  csv.field_size_limit(sys.maxsize)
11
 
12
  import torch
@@ -20,149 +24,112 @@ from torch.utils.data import Dataset, DataLoader
20
  from transformers import TrainingArguments, Trainer
21
  from transformers.trainer_callback import TrainerCallback
22
  from sklearn.metrics import accuracy_score, f1_score
23
- from typing import List
24
 
25
  # ---------------------------
26
  # Configuration / Constants
27
  # ---------------------------
28
- # MLM mask probability
29
  P_MASK = 0.15
30
 
31
- # Fingerprint specifics
32
- FINGERPRINT_KEY = "morgan_r3_bits" # inside the JSON stored under 'fingerprints' column
33
- FP_LENGTH = 2048 # expected fingerprint vector length (bits)
34
- # Vocabulary: {0, 1, MASK} where 0/1 are real bits and MASK token id = 2 used as masked input
35
  MASK_TOKEN_ID = 2
36
  VOCAB_SIZE = 3
37
 
38
- # Model / encoder hyperparams
39
  HIDDEN_DIM = 256
40
  TRANSFORMER_NUM_LAYERS = 4
41
  TRANSFORMER_NHEAD = 8
42
  TRANSFORMER_FF = 1024
43
  DROPOUT = 0.1
44
 
45
- # Training / data hyperparams
46
- TRAIN_BATCH_SIZE = 16 # number of molecules per batch
47
  EVAL_BATCH_SIZE = 8
48
  GRADIENT_ACCUMULATION_STEPS = 4
49
  NUM_EPOCHS = 25
50
  LEARNING_RATE = 1e-4
51
  WEIGHT_DECAY = 0.01
52
 
53
- # File locations (changed as requested)
54
- CSV_PATH = "Polymer_Foundational_Model/polymer_structures_unified_processed.csv"
55
- OUTPUT_DIR = "./fingerprint_mlm_output_5M"
56
- BEST_MODEL_DIR = os.path.join(OUTPUT_DIR, "best")
57
- os.makedirs(OUTPUT_DIR, exist_ok=True)
58
 
59
- # ---------------------------
60
- # 1. Load Data (chunked to avoid OOM) - read fingerprints column
61
- # ---------------------------
62
- TARGET_ROWS = 5000000
63
- CHUNKSIZE = 50000
64
-
65
- fp_lists: List[List[int]] = []
66
- rows_read = 0
67
-
68
- # Expect 'fingerprints' column value to be a JSON string we can json.loads()
69
- # that contains e.g. {"morgan_r3_bits": ["0","1","0",...]}
70
- for chunk in pd.read_csv(CSV_PATH, engine="python", chunksize=CHUNKSIZE):
71
- # some files might already have parsed JSON-like dicts; ensure we handle strings
72
- fps_chunk = chunk["fingerprints"]
73
- for fpval in fps_chunk:
74
- if pd.isna(fpval):
75
- # skip or use zeros
76
- fp_lists.append([0] * FP_LENGTH)
77
- continue
78
-
79
- # If it's already a dict-like object, use directly; else parse JSON string
80
- if isinstance(fpval, str):
81
- try:
82
- fp_json = json.loads(fpval)
83
- except Exception:
84
- # fallback: try to fix common quoting issues
 
 
85
  try:
86
- fp_json = json.loads(fpval.replace("'", '"'))
87
  except Exception:
88
- # as last fallback, treat the string as a comma separated "0,1,0,..."
89
- parts = [p.strip().strip('"').strip("'") for p in fpval.split(",")]
90
- bits = [1 if p in ("1", "True", "true") else 0 for p in parts[:FP_LENGTH]]
91
- if len(bits) < FP_LENGTH:
92
- bits += [0] * (FP_LENGTH - len(bits))
93
- fp_lists.append(bits)
94
- continue
95
- elif isinstance(fpval, dict):
96
- fp_json = fpval
97
- else:
98
- # Unknown type, zero pad
99
- fp_lists.append([0] * FP_LENGTH)
100
- continue
101
-
102
- # Extract the fingerprint bit list
103
- bits = fp_json.get(FINGERPRINT_KEY, None)
104
- if bits is None:
105
- # fallback if top-level is already list
106
- if isinstance(fp_json, list):
107
- bits = fp_json
108
  else:
109
- # default zero vector
110
- bits = [0] * FP_LENGTH
111
-
112
- # bits may be list of strings "0"/"1" or ints
113
- # normalize to ints and ensure length
114
- normalized = []
115
- for b in bits:
116
- if isinstance(b, str):
117
- b_clean = b.strip().strip('"').strip("'")
118
- normalized.append(1 if b_clean in ("1", "True", "true") else 0)
119
- elif isinstance(b, (int, np.integer)):
120
- normalized.append(1 if int(b) != 0 else 0)
121
- else:
122
- normalized.append(0)
123
- if len(normalized) >= FP_LENGTH:
124
- break
125
 
126
- if len(normalized) < FP_LENGTH:
127
- # pad with zeros
128
- normalized.extend([0] * (FP_LENGTH - len(normalized)))
 
 
 
129
 
130
- fp_lists.append(normalized[:FP_LENGTH])
 
 
 
 
 
 
 
 
 
 
131
 
132
- rows_read += len(chunk)
133
- if rows_read >= TARGET_ROWS:
134
- break
135
 
136
- print(f"Loaded {len(fp_lists)} fingerprint vectors (using FP_LENGTH={FP_LENGTH}).")
137
 
138
- # ---------------------------
139
- # 2. Train/Val Split
140
- # ---------------------------
141
- train_idx, val_idx = train_test_split(list(range(len(fp_lists))), test_size=0.2, random_state=42)
142
- train_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx]
143
- val_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx]
144
 
145
- # ---------------------------
146
- # Compute class weights (for weighted CE to mitigate bit imbalance)
147
- # (we compute but will not apply them to match previous MLM-style loss behavior)
148
- # ---------------------------
149
- # We'll compute weights for classes {0,1} only (targets).
150
- counts = np.ones((2,), dtype=np.float64) # initialize with 1 to avoid zero division
151
- for fp in train_fps:
152
- vals = fp.cpu().numpy().astype(int)
153
- counts[0] += np.sum(vals == 0)
154
- counts[1] += np.sum(vals == 1)
155
-
156
- freq = counts / counts.sum()
157
- inv_freq = 1.0 / (freq + 1e-12)
158
- class_weights_arr = inv_freq / inv_freq.mean()
159
- class_weights = torch.tensor(class_weights_arr, dtype=torch.float) # shape [2]
160
- print("Class weights (for bit 0 and bit 1):", class_weights.numpy())
161
 
162
- # ---------------------------
163
- # 3. Dataset and Collator (fingerprint MLM)
164
- # ---------------------------
165
  class FingerprintDataset(Dataset):
 
166
  def __init__(self, fps: List[torch.Tensor]):
167
  self.fps = fps
168
 
@@ -170,42 +137,35 @@ class FingerprintDataset(Dataset):
170
  return len(self.fps)
171
 
172
  def __getitem__(self, idx):
173
- # Return the tensor directly (not wrapped in a dict). This avoids mismatches
174
- # when HF's Trainer / collators pass around items in different formats.
175
  return self.fps[idx]
176
 
 
177
  def collate_batch(batch):
178
  """
179
- Collate a batch of fingerprint tensors into:
180
- - z: [B, L] long, masked/corrupted input tokens (values 0,1, or MASK_TOKEN_ID)
181
- - labels_z: [B, L] long, with -100 for unselected positions and 0/1 for masked positions (targets)
182
- - attention_mask: [B, L] bool (all True here since fixed length)
183
-
184
- This collator is defensive: it accepts
185
- - list of torch.Tensors
186
- - list of dicts containing key 'fp'
187
- - HF-style list of dict-like items where a tensor-like value is present
188
  """
189
  B = len(batch)
190
  if B == 0:
191
- return {"z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
192
- "labels_z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
193
- "attention_mask": torch.zeros((0, FP_LENGTH), dtype=torch.bool)}
 
 
194
 
195
- # Normalize items -> list of tensors
196
  tensors = []
197
  for item in batch:
198
  if isinstance(item, torch.Tensor):
199
  tensors.append(item)
200
  elif isinstance(item, dict):
201
- # Prefer 'fp' if present
202
  if "fp" in item:
203
  val = item["fp"]
204
  if not isinstance(val, torch.Tensor):
205
  val = torch.tensor(val, dtype=torch.long)
206
  tensors.append(val)
207
  else:
208
- # Try to find any tensor-like value inside dict
209
  found = None
210
  for v in item.values():
211
  if isinstance(v, torch.Tensor):
@@ -215,193 +175,106 @@ def collate_batch(batch):
215
  found = torch.tensor(v, dtype=torch.long)
216
  break
217
  elif isinstance(v, list):
218
- # possible list of ints
219
  try:
220
  found = torch.tensor(v, dtype=torch.long)
221
  break
222
  except Exception:
223
  continue
224
  if found is None:
225
- raise KeyError("collate_batch: couldn't find 'fp' tensor in dataset item; item keys: {}".format(list(item.keys())))
226
  tensors.append(found)
227
  else:
228
- # fallback: try to convert numpy/sequence to tensor
229
- try:
230
- tensors.append(torch.tensor(item, dtype=torch.long))
231
- except Exception:
232
- raise TypeError(f"collate_batch: unsupported batch item type: {type(item)}")
233
-
234
- # Stack into [B, L]
235
- all_inputs = torch.stack(tensors, dim=0).long() # [B, L], long (0/1)
236
- device = all_inputs.device
237
 
238
- # Prepare masks and labels
239
- labels_z = torch.full_like(all_inputs, fill_value=-100, dtype=torch.long) # -100 ignored by CE
240
  z_masked = all_inputs.clone()
241
 
242
  for i in range(B):
243
- z = all_inputs[i] # [L]
244
  n_positions = z.size(0)
245
- # select positions to supervise (mask) with probability P_MASK
246
  is_selected = torch.rand(n_positions) < P_MASK
247
-
248
- # ensure not all selected
249
  if is_selected.all():
250
  is_selected[torch.randint(0, n_positions, (1,))] = False
251
 
252
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
253
  if sel_idx.numel() > 0:
254
- labels_z[i, sel_idx] = z[sel_idx] # store true bit labels
255
 
256
- # BERT-style corruption per selected position
257
  probs = torch.rand(sel_idx.size(0))
258
  mask_choice = probs < 0.8
259
  rand_choice = (probs >= 0.8) & (probs < 0.9)
260
- # keep_choice = probs >= 0.9
261
 
262
  if mask_choice.any():
263
- z_masked[i, sel_idx[mask_choice]] = MASK_TOKEN_ID # mask token id
264
-
265
  if rand_choice.any():
266
- # replace with random 0 or 1
267
  rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long)
268
  z_masked[i, sel_idx[rand_choice]] = rand_bits
269
 
270
- # keep_choice -> leave original bit
271
-
272
- attention_mask = torch.ones_like(all_inputs, dtype=torch.bool) # full attention (fixed length)
273
-
274
  return {"z": z_masked, "labels_z": labels_z, "attention_mask": attention_mask}
275
 
276
- train_dataset = FingerprintDataset(train_fps)
277
- val_dataset = FingerprintDataset(val_fps)
278
- train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_batch, drop_last=False)
279
- val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_batch, drop_last=False)
280
-
281
- # ---------------------------
282
- # 4. Model Definition (Fingerprint Encoder + MLM head)
283
- # ---------------------------
284
 
285
  class FingerprintEncoder(nn.Module):
286
- """
287
- Simple encoder for fingerprint token sequences:
288
- - token embedding (vocab size VOCAB_SIZE)
289
- - positional embedding
290
- - Transformer encoder stack
291
- - returns per-position embeddings [B, L, hidden_dim]
292
- """
293
  def __init__(self, vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, seq_len=FP_LENGTH,
294
  num_layers=TRANSFORMER_NUM_LAYERS, nhead=TRANSFORMER_NHEAD, dim_feedforward=TRANSFORMER_FF,
295
  dropout=DROPOUT):
296
  super().__init__()
297
  self.token_emb = nn.Embedding(vocab_size, hidden_dim)
298
  self.pos_emb = nn.Embedding(seq_len, hidden_dim)
299
- encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
 
 
 
 
 
 
300
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
301
- self.hidden_dim = hidden_dim
302
- self.seq_len = seq_len
303
 
304
  def forward(self, input_ids, attention_mask=None):
305
- """
306
- input_ids: [B, L] long (values 0,1, or MASK_TOKEN_ID)
307
- attention_mask: [B, L] bool (True for valid positions)
308
- returns: embeddings [B, L, hidden_dim]
309
- """
310
  B, L = input_ids.shape
311
- x = self.token_emb(input_ids) # [B, L, hidden]
312
- # positional indices 0..L-1 broadcast to batch
313
  pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
314
  x = x + self.pos_emb(pos_ids)
315
- # transformer expects batch_first=True (we set that)
316
- if attention_mask is not None:
317
- # transformer encoder in PyTorch doesn't use attention_mask in same way as HF; provide key_padding_mask
318
- key_padding_mask = ~attention_mask # True where to mask
319
- else:
320
- key_padding_mask = None
321
 
322
- out = self.transformer(x, src_key_padding_mask=key_padding_mask)
323
- return out # [B, L, hidden_dim]
324
 
325
 
326
  class MaskedFingerprintModel(nn.Module):
327
- """
328
- Encoder + MLM head for fingerprint masked language modeling.
329
- MLM head predicts over VOCAB_SIZE (0,1,MASK) like a token classification over the small vocab.
330
- Loss is standard CrossEntropyLoss (ignore_index=-100) computed only on masked positions,
331
- matching the "MLM with CrossEntropy" behavior used in the DebertaV2ForMaskedLM setup.
332
- """
333
  def __init__(self, hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE):
334
  super().__init__()
335
  self.encoder = FingerprintEncoder(vocab_size=vocab_size, hidden_dim=hidden_dim)
336
- # MLM head: predict logits over the small token vocabulary {0,1,MASK}
337
  self.mlm_head = nn.Linear(hidden_dim, vocab_size)
338
 
339
  def forward(self, z, attention_mask=None, labels_z=None):
340
- """
341
- z: [B, L] long inputs (0/1/MASK_TOKEN_ID)
342
- labels_z: [B, L] long with -100 for unselected positions, else 0/1 targets
343
- Returns:
344
- - if labels_z provided -> loss (scalar tensor)
345
- - else -> logits [B, L, VOCAB_SIZE]
346
- """
347
- embeddings = self.encoder(z, attention_mask=attention_mask) # [B, L, hidden]
348
- logits = self.mlm_head(embeddings) # [B, L, VOCAB_SIZE]
349
 
350
  if labels_z is not None:
351
- mask = labels_z != -100 # [B, L]
352
  if mask.sum() == 0:
353
- # return zero loss tensor on same device
354
  return torch.tensor(0.0, device=z.device)
355
 
356
- logits_masked = logits[mask] # [M, VOCAB_SIZE]
357
- labels_masked = labels_z[mask] # [M] values in {0,1}
358
-
359
- # standard cross-entropy over the vocabulary (no class weighting, matching previous Deberta MLM behavior)
360
- # labels_masked must be long
361
- labels_masked = labels_masked.long()
362
- loss_z = F.cross_entropy(logits_masked, labels_masked)
363
-
364
- return loss_z
365
 
366
- # inference -> return logits
367
  return logits
368
 
369
- # instantiate model using MLM-style head and standard cross-entropy loss (no learned weighting/class-weights)
370
- model = MaskedFingerprintModel(hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE)
371
-
372
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
373
- model.to(device)
374
-
375
- # ---------------------------
376
- # 5. Training Setup (Hugging Face Trainer)
377
- # ---------------------------
378
- training_args = TrainingArguments(
379
- output_dir=OUTPUT_DIR,
380
- overwrite_output_dir=True,
381
- num_train_epochs=NUM_EPOCHS,
382
- per_device_train_batch_size=TRAIN_BATCH_SIZE,
383
- per_device_eval_batch_size=EVAL_BATCH_SIZE,
384
- eval_accumulation_steps=1000, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
385
- eval_strategy="epoch",
386
- logging_steps=500,
387
- learning_rate=LEARNING_RATE,
388
- weight_decay=WEIGHT_DECAY,
389
- fp16=torch.cuda.is_available(),
390
- save_strategy="no", # callback will save best model
391
- disable_tqdm=False,
392
- logging_first_step=True,
393
- report_to=[],
394
- # NOTE: set to 0 to avoid DataLoader worker pickling/collate issues in some environments.
395
- dataloader_num_workers=0,
396
- )
397
 
398
  class ValLossCallback(TrainerCallback):
399
- def __init__(self, trainer_ref=None):
 
400
  self.best_val_loss = float("inf")
401
  self.epochs_no_improve = 0
402
- self.patience = 10
403
  self.best_epoch = None
404
  self.trainer_ref = trainer_ref
 
 
405
 
406
  def on_epoch_end(self, args, state, control, **kwargs):
407
  epoch_num = int(state.epoch)
@@ -412,45 +285,36 @@ class ValLossCallback(TrainerCallback):
412
 
413
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
414
  epoch_num = int(state.epoch) + 1
415
-
416
  if self.trainer_ref is None:
417
  print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
418
  return
419
 
420
- metric_val_loss = None
421
- if metrics is not None:
422
- metric_val_loss = metrics.get("eval_loss")
423
 
424
  model_eval = self.trainer_ref.model
425
  model_eval.eval()
 
426
 
427
- device_local = next(model_eval.parameters()).device if any(p.numel() > 0 for p in model_eval.parameters()) else torch.device("cpu")
428
-
429
- preds_bits = []
430
- true_bits = []
431
- total_loss = 0.0
432
- n_batches = 0
433
-
434
- logits_masked_list = []
435
- labels_masked_list = []
436
 
437
  with torch.no_grad():
438
- for batch in val_loader:
439
- z = batch["z"].to(device_local) # [B, L]
440
  labels_z = batch["labels_z"].to(device_local)
441
  attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device_local)
442
 
443
- # compute loss if possible (model returns scalar loss when labels_z provided)
444
  try:
445
  loss = model_eval(z, attention_mask=attention_mask, labels_z=labels_z)
446
- except Exception as e:
447
  loss = None
448
 
449
  if isinstance(loss, torch.Tensor):
450
  total_loss += loss.item()
451
  n_batches += 1
452
 
453
- logits = model_eval(z, attention_mask=attention_mask) # [B, L, VOCAB_SIZE]
454
 
455
  mask = labels_z != -100
456
  if mask.sum().item() == 0:
@@ -466,15 +330,12 @@ class ValLossCallback(TrainerCallback):
466
  true_bits.extend(true_b.cpu().tolist())
467
 
468
  avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))
469
-
470
  accuracy = accuracy_score(true_bits, preds_bits) if len(true_bits) > 0 else 0.0
471
  f1 = f1_score(true_bits, preds_bits, average="weighted") if len(true_bits) > 0 else 0.0
472
 
473
- # perplexity from masked-token cross-entropy (computed over masked positions only)
474
  if len(logits_masked_list) > 0:
475
  all_logits_masked = torch.cat(logits_masked_list, dim=0)
476
  all_labels_masked = torch.cat(labels_masked_list, dim=0)
477
- # match previous MLM: standard cross-entropy over the vocabulary
478
  loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked.long())
479
  try:
480
  perplexity = float(torch.exp(loss_z_all).cpu().item())
@@ -489,15 +350,14 @@ class ValLossCallback(TrainerCallback):
489
  print(f"Validation F1 (weighted): {f1:.4f}")
490
  print(f"Validation Perplexity (classification head): {perplexity:.4f}")
491
 
492
- # Check for improvement
493
  if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
494
  self.best_val_loss = avg_val_loss
495
  self.best_epoch = int(state.epoch)
496
  self.epochs_no_improve = 0
497
- os.makedirs(BEST_MODEL_DIR, exist_ok=True)
498
  try:
499
- torch.save(self.trainer_ref.model.state_dict(), os.path.join(BEST_MODEL_DIR, "pytorch_model.bin"))
500
- print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(BEST_MODEL_DIR, 'pytorch_model.bin')}")
501
  except Exception as e:
502
  print(f"Failed to save best model at epoch {epoch_num}: {e}")
503
  else:
@@ -508,96 +368,145 @@ class ValLossCallback(TrainerCallback):
508
  control.should_training_stop = True
509
 
510
 
511
- # Create callback and Trainer
512
- callback = ValLossCallback()
513
- trainer = Trainer(
514
- model=model,
515
- args=training_args,
516
- train_dataset=train_dataset,
517
- eval_dataset=val_dataset,
518
- data_collator=collate_batch,
519
- callbacks=[callback]
520
- )
521
- callback.trainer_ref = trainer
522
-
523
- # ---------------------------
524
- # 6. Run training
525
- # ---------------------------
526
- start_time = time.time()
527
- trainer.train()
528
- total_time = time.time() - start_time
529
-
530
- # ---------------------------
531
- # 7. Final Evaluation (evaluate best saved model on validation set)
532
- # ---------------------------
533
-
534
- best_model_path = os.path.join(BEST_MODEL_DIR, "pytorch_model.bin")
535
- if os.path.exists(best_model_path):
536
- try:
537
- model.load_state_dict(torch.load(best_model_path, map_location=device))
538
- print(f"\nLoaded best model from {best_model_path}")
539
- except Exception as e:
540
- print(f"\nFailed to load best model from {best_model_path}: {e}")
541
-
542
- model.eval()
543
- preds_bits_all = []
544
- true_bits_all = []
545
- logits_masked_final = []
546
- labels_masked_final = []
547
-
548
- with torch.no_grad():
549
- for batch in val_loader:
550
- z = batch["z"].to(device)
551
- labels_z = batch["labels_z"].to(device)
552
- attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device)
553
-
554
- logits = model(z, attention_mask=attention_mask) # [B, L, VOCAB_SIZE]
555
-
556
- mask = labels_z != -100
557
- if mask.sum().item() == 0:
558
- continue
559
-
560
- logits_masked_final.append(logits[mask])
561
- labels_masked_final.append(labels_z[mask])
562
-
563
- pred_bits = torch.argmax(logits[mask], dim=-1)
564
- true_b = labels_z[mask]
565
-
566
- preds_bits_all.extend(pred_bits.cpu().tolist())
567
- true_bits_all.extend(true_b.cpu().tolist())
568
-
569
- accuracy = accuracy_score(true_bits_all, preds_bits_all) if len(true_bits_all) > 0 else 0.0
570
- f1 = f1_score(true_bits_all, preds_bits_all, average="weighted") if len(true_bits_all) > 0 else 0.0
571
-
572
- if len(logits_masked_final) > 0:
573
- all_logits_masked_final = torch.cat(logits_masked_final, dim=0)
574
- all_labels_masked_final = torch.cat(labels_masked_final, dim=0)
575
- loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final.long())
576
- try:
577
- perplexity_final = float(torch.exp(loss_z_final).cpu().item())
578
- except Exception:
579
- perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
580
- else:
581
- perplexity_final = float("nan")
582
-
583
- best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
584
- best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
585
-
586
- print(f"\n=== Final Results (evaluated on best saved model) ===")
587
- print(f"Total Training Time (s): {total_time:.2f}")
588
- if best_epoch_num is not None:
589
- print(f"Best Epoch (1-based): {best_epoch_num}")
590
- else:
591
- print("Best Epoch: (none saved)")
592
-
593
- print(f"Best Validation Loss: {best_val_loss:.4f}")
594
- print(f"Validation Accuracy: {accuracy:.4f}")
595
- print(f"Validation F1 (weighted): {f1:.4f}")
596
- print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
597
-
598
- total_params = sum(p.numel() for p in model.parameters())
599
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
600
- non_trainable_params = total_params - trainable_params
601
- print(f"Total Parameters: {total_params}")
602
- print(f"Trainable Parameters: {trainable_params}")
603
- print(f"Non-trainable Parameters: {non_trainable_params}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fingerprint masked language modeling (MLM) using a Transformer encoder.
3
+ """
4
+
5
  import os
6
  import json
7
  import time
 
8
  import sys
9
  import csv
10
+ import argparse
11
+ from typing import List
12
 
13
+ # Increase max CSV field size limit (fingerprints can be long)
14
  csv.field_size_limit(sys.maxsize)
15
 
16
  import torch
 
24
  from transformers import TrainingArguments, Trainer
25
  from transformers.trainer_callback import TrainerCallback
26
  from sklearn.metrics import accuracy_score, f1_score
 
27
 
28
  # ---------------------------
29
  # Configuration / Constants
30
  # ---------------------------
 
31
  P_MASK = 0.15
32
 
33
+ FINGERPRINT_KEY = "morgan_r3_bits"
34
+ FP_LENGTH = 2048
35
+
 
36
  MASK_TOKEN_ID = 2
37
  VOCAB_SIZE = 3
38
 
 
39
  HIDDEN_DIM = 256
40
  TRANSFORMER_NUM_LAYERS = 4
41
  TRANSFORMER_NHEAD = 8
42
  TRANSFORMER_FF = 1024
43
  DROPOUT = 0.1
44
 
45
+ TRAIN_BATCH_SIZE = 16
 
46
  EVAL_BATCH_SIZE = 8
47
  GRADIENT_ACCUMULATION_STEPS = 4
48
  NUM_EPOCHS = 25
49
  LEARNING_RATE = 1e-4
50
  WEIGHT_DECAY = 0.01
51
 
 
 
 
 
 
52
 
53
+ def parse_args() -> argparse.Namespace:
54
+ parser = argparse.ArgumentParser(description="Fingerprint MLM pretraining (Transformer).")
55
+ parser.add_argument(
56
+ "--csv_path",
57
+ type=str,
58
+ default="/path/to/polymer_structures_unified_processed.csv",
59
+ help="Processed CSV containing a JSON 'fingerprints' column.",
60
+ )
61
+ parser.add_argument("--target_rows", type=int, default=5_000_000, help="Max rows to parse.")
62
+ parser.add_argument("--chunksize", type=int, default=50_000, help="CSV chunksize.")
63
+ parser.add_argument("--output_dir", type=str, default="/path/to/fingerprint_mlm_output_5M", help="Training output directory.")
64
+ parser.add_argument("--num_workers", type=int, default=0, help="PyTorch DataLoader num workers (kept default 0).")
65
+ return parser.parse_args()
66
+
67
+
68
+ def load_fingerprints(csv_path: str, target_rows: int, chunksize: int) -> List[List[int]]:
69
+ """Stream CSV and parse fingerprint bits into fixed-length vectors of ints."""
70
+ fp_lists: List[List[int]] = []
71
+ rows_read = 0
72
+
73
+ for chunk in pd.read_csv(csv_path, engine="python", chunksize=chunksize):
74
+ fps_chunk = chunk["fingerprints"]
75
+ for fpval in fps_chunk:
76
+ if pd.isna(fpval):
77
+ fp_lists.append([0] * FP_LENGTH)
78
+ continue
79
+
80
+ if isinstance(fpval, str):
81
  try:
82
+ fp_json = json.loads(fpval)
83
  except Exception:
84
+ try:
85
+ fp_json = json.loads(fpval.replace("'", '"'))
86
+ except Exception:
87
+ parts = [p.strip().strip('"').strip("'") for p in fpval.split(",")]
88
+ bits = [1 if p in ("1", "True", "true") else 0 for p in parts[:FP_LENGTH]]
89
+ if len(bits) < FP_LENGTH:
90
+ bits += [0] * (FP_LENGTH - len(bits))
91
+ fp_lists.append(bits)
92
+ continue
93
+ elif isinstance(fpval, dict):
94
+ fp_json = fpval
 
 
 
 
 
 
 
 
 
95
  else:
96
+ fp_lists.append([0] * FP_LENGTH)
97
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ bits = fp_json.get(FINGERPRINT_KEY, None)
100
+ if bits is None:
101
+ if isinstance(fp_json, list):
102
+ bits = fp_json
103
+ else:
104
+ bits = [0] * FP_LENGTH
105
 
106
+ normalized = []
107
+ for b in bits:
108
+ if isinstance(b, str):
109
+ b_clean = b.strip().strip('"').strip("'")
110
+ normalized.append(1 if b_clean in ("1", "True", "true") else 0)
111
+ elif isinstance(b, (int, np.integer)):
112
+ normalized.append(1 if int(b) != 0 else 0)
113
+ else:
114
+ normalized.append(0)
115
+ if len(normalized) >= FP_LENGTH:
116
+ break
117
 
118
+ if len(normalized) < FP_LENGTH:
119
+ normalized.extend([0] * (FP_LENGTH - len(normalized)))
 
120
 
121
+ fp_lists.append(normalized[:FP_LENGTH])
122
 
123
+ rows_read += len(chunk)
124
+ if rows_read >= target_rows:
125
+ break
126
+
127
+ print(f"Loaded {len(fp_lists)} fingerprint vectors (using FP_LENGTH={FP_LENGTH}).")
128
+ return fp_lists
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
 
 
 
131
  class FingerprintDataset(Dataset):
132
+ """Dataset of fixed-length fingerprint bit vectors (stored as torch.long tensors)."""
133
  def __init__(self, fps: List[torch.Tensor]):
134
  self.fps = fps
135
 
 
137
  return len(self.fps)
138
 
139
  def __getitem__(self, idx):
 
 
140
  return self.fps[idx]
141
 
142
+
143
  def collate_batch(batch):
144
  """
145
+ MLM-style collation:
146
+ - Select positions with P_MASK
147
+ - Labels are true bits only on selected positions, else -100
148
+ - Inputs are corrupted with 80/10/10 mask/random/keep policy
 
 
 
 
 
149
  """
150
  B = len(batch)
151
  if B == 0:
152
+ return {
153
+ "z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
154
+ "labels_z": torch.zeros((0, FP_LENGTH), dtype=torch.long),
155
+ "attention_mask": torch.zeros((0, FP_LENGTH), dtype=torch.bool),
156
+ }
157
 
 
158
  tensors = []
159
  for item in batch:
160
  if isinstance(item, torch.Tensor):
161
  tensors.append(item)
162
  elif isinstance(item, dict):
 
163
  if "fp" in item:
164
  val = item["fp"]
165
  if not isinstance(val, torch.Tensor):
166
  val = torch.tensor(val, dtype=torch.long)
167
  tensors.append(val)
168
  else:
 
169
  found = None
170
  for v in item.values():
171
  if isinstance(v, torch.Tensor):
 
175
  found = torch.tensor(v, dtype=torch.long)
176
  break
177
  elif isinstance(v, list):
 
178
  try:
179
  found = torch.tensor(v, dtype=torch.long)
180
  break
181
  except Exception:
182
  continue
183
  if found is None:
184
+ raise KeyError(f"collate_batch: couldn't find tensor-like fp in item keys: {list(item.keys())}")
185
  tensors.append(found)
186
  else:
187
+ tensors.append(torch.tensor(item, dtype=torch.long))
 
 
 
 
 
 
 
 
188
 
189
+ all_inputs = torch.stack(tensors, dim=0).long()
190
+ labels_z = torch.full_like(all_inputs, fill_value=-100, dtype=torch.long)
191
  z_masked = all_inputs.clone()
192
 
193
  for i in range(B):
194
+ z = all_inputs[i]
195
  n_positions = z.size(0)
 
196
  is_selected = torch.rand(n_positions) < P_MASK
 
 
197
  if is_selected.all():
198
  is_selected[torch.randint(0, n_positions, (1,))] = False
199
 
200
  sel_idx = torch.nonzero(is_selected).squeeze(-1)
201
  if sel_idx.numel() > 0:
202
+ labels_z[i, sel_idx] = z[sel_idx]
203
 
 
204
  probs = torch.rand(sel_idx.size(0))
205
  mask_choice = probs < 0.8
206
  rand_choice = (probs >= 0.8) & (probs < 0.9)
 
207
 
208
  if mask_choice.any():
209
+ z_masked[i, sel_idx[mask_choice]] = MASK_TOKEN_ID
 
210
  if rand_choice.any():
 
211
  rand_bits = torch.randint(0, 2, (rand_choice.sum().item(),), dtype=torch.long)
212
  z_masked[i, sel_idx[rand_choice]] = rand_bits
213
 
214
+ attention_mask = torch.ones_like(all_inputs, dtype=torch.bool)
 
 
 
215
  return {"z": z_masked, "labels_z": labels_z, "attention_mask": attention_mask}
216
 
 
 
 
 
 
 
 
 
217
 
218
  class FingerprintEncoder(nn.Module):
219
+ """Transformer encoder over a length-FP_LENGTH token sequence with small vocab {0,1,MASK}."""
 
 
 
 
 
 
220
  def __init__(self, vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, seq_len=FP_LENGTH,
221
  num_layers=TRANSFORMER_NUM_LAYERS, nhead=TRANSFORMER_NHEAD, dim_feedforward=TRANSFORMER_FF,
222
  dropout=DROPOUT):
223
  super().__init__()
224
  self.token_emb = nn.Embedding(vocab_size, hidden_dim)
225
  self.pos_emb = nn.Embedding(seq_len, hidden_dim)
226
+ encoder_layer = nn.TransformerEncoderLayer(
227
+ d_model=hidden_dim,
228
+ nhead=nhead,
229
+ dim_feedforward=dim_feedforward,
230
+ dropout=dropout,
231
+ batch_first=True,
232
+ )
233
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
 
234
 
235
  def forward(self, input_ids, attention_mask=None):
 
 
 
 
 
236
  B, L = input_ids.shape
237
+ x = self.token_emb(input_ids)
 
238
  pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
239
  x = x + self.pos_emb(pos_ids)
 
 
 
 
 
 
240
 
241
+ key_padding_mask = (~attention_mask) if attention_mask is not None else None
242
+ return self.transformer(x, src_key_padding_mask=key_padding_mask)
243
 
244
 
245
  class MaskedFingerprintModel(nn.Module):
246
+ """Encoder + token classification head; returns scalar loss when labels_z provided."""
 
 
 
 
 
247
  def __init__(self, hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE):
248
  super().__init__()
249
  self.encoder = FingerprintEncoder(vocab_size=vocab_size, hidden_dim=hidden_dim)
 
250
  self.mlm_head = nn.Linear(hidden_dim, vocab_size)
251
 
252
  def forward(self, z, attention_mask=None, labels_z=None):
253
+ embeddings = self.encoder(z, attention_mask=attention_mask)
254
+ logits = self.mlm_head(embeddings)
 
 
 
 
 
 
 
255
 
256
  if labels_z is not None:
257
+ mask = labels_z != -100
258
  if mask.sum() == 0:
 
259
  return torch.tensor(0.0, device=z.device)
260
 
261
+ logits_masked = logits[mask]
262
+ labels_masked = labels_z[mask].long()
263
+ return F.cross_entropy(logits_masked, labels_masked)
 
 
 
 
 
 
264
 
 
265
  return logits
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  class ValLossCallback(TrainerCallback):
269
+ """Tracks best eval loss, prints metrics, saves best model, early-stops."""
270
+ def __init__(self, best_model_dir: str, val_loader: DataLoader, patience: int = 10, trainer_ref=None):
271
  self.best_val_loss = float("inf")
272
  self.epochs_no_improve = 0
273
+ self.patience = patience
274
  self.best_epoch = None
275
  self.trainer_ref = trainer_ref
276
+ self.best_model_dir = best_model_dir
277
+ self.val_loader = val_loader
278
 
279
  def on_epoch_end(self, args, state, control, **kwargs):
280
  epoch_num = int(state.epoch)
 
285
 
286
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
287
  epoch_num = int(state.epoch) + 1
 
288
  if self.trainer_ref is None:
289
  print(f"[Eval] Epoch {epoch_num} - metrics (trainer_ref missing): {metrics}")
290
  return
291
 
292
+ metric_val_loss = metrics.get("eval_loss") if metrics is not None else None
 
 
293
 
294
  model_eval = self.trainer_ref.model
295
  model_eval.eval()
296
+ device_local = next(model_eval.parameters()).device
297
 
298
+ preds_bits, true_bits = [], []
299
+ total_loss, n_batches = 0.0, 0
300
+ logits_masked_list, labels_masked_list = [], []
 
 
 
 
 
 
301
 
302
  with torch.no_grad():
303
+ for batch in self.val_loader:
304
+ z = batch["z"].to(device_local)
305
  labels_z = batch["labels_z"].to(device_local)
306
  attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device_local)
307
 
 
308
  try:
309
  loss = model_eval(z, attention_mask=attention_mask, labels_z=labels_z)
310
+ except Exception:
311
  loss = None
312
 
313
  if isinstance(loss, torch.Tensor):
314
  total_loss += loss.item()
315
  n_batches += 1
316
 
317
+ logits = model_eval(z, attention_mask=attention_mask)
318
 
319
  mask = labels_z != -100
320
  if mask.sum().item() == 0:
 
330
  true_bits.extend(true_b.cpu().tolist())
331
 
332
  avg_val_loss = metric_val_loss if metric_val_loss is not None else ((total_loss / n_batches) if n_batches > 0 else float("nan"))
 
333
  accuracy = accuracy_score(true_bits, preds_bits) if len(true_bits) > 0 else 0.0
334
  f1 = f1_score(true_bits, preds_bits, average="weighted") if len(true_bits) > 0 else 0.0
335
 
 
336
  if len(logits_masked_list) > 0:
337
  all_logits_masked = torch.cat(logits_masked_list, dim=0)
338
  all_labels_masked = torch.cat(labels_masked_list, dim=0)
 
339
  loss_z_all = F.cross_entropy(all_logits_masked, all_labels_masked.long())
340
  try:
341
  perplexity = float(torch.exp(loss_z_all).cpu().item())
 
350
  print(f"Validation F1 (weighted): {f1:.4f}")
351
  print(f"Validation Perplexity (classification head): {perplexity:.4f}")
352
 
 
353
  if avg_val_loss is not None and not (isinstance(avg_val_loss, float) and np.isnan(avg_val_loss)) and avg_val_loss < self.best_val_loss - 1e-6:
354
  self.best_val_loss = avg_val_loss
355
  self.best_epoch = int(state.epoch)
356
  self.epochs_no_improve = 0
357
+ os.makedirs(self.best_model_dir, exist_ok=True)
358
  try:
359
+ torch.save(self.trainer_ref.model.state_dict(), os.path.join(self.best_model_dir, "pytorch_model.bin"))
360
+ print(f"Saved new best model (epoch {epoch_num}) to {os.path.join(self.best_model_dir, 'pytorch_model.bin')}")
361
  except Exception as e:
362
  print(f"Failed to save best model at epoch {epoch_num}: {e}")
363
  else:
 
368
  control.should_training_stop = True
369
 
370
 
371
+ def train_and_eval(args: argparse.Namespace) -> None:
372
+ output_dir = args.output_dir
373
+ best_model_dir = os.path.join(output_dir, "best")
374
+ os.makedirs(output_dir, exist_ok=True)
375
+
376
+ fp_lists = load_fingerprints(args.csv_path, args.target_rows, args.chunksize)
377
+
378
+ train_idx, val_idx = train_test_split(list(range(len(fp_lists))), test_size=0.2, random_state=42)
379
+ train_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in train_idx]
380
+ val_fps = [torch.tensor(fp_lists[i], dtype=torch.long) for i in val_idx]
381
+
382
+ # Compute class weights
383
+ counts = np.ones((2,), dtype=np.float64)
384
+ for fp in train_fps:
385
+ vals = fp.cpu().numpy().astype(int)
386
+ counts[0] += np.sum(vals == 0)
387
+ counts[1] += np.sum(vals == 1)
388
+ freq = counts / counts.sum()
389
+ inv_freq = 1.0 / (freq + 1e-12)
390
+ class_weights_arr = inv_freq / inv_freq.mean()
391
+ class_weights = torch.tensor(class_weights_arr, dtype=torch.float)
392
+ print("Class weights (for bit 0 and bit 1):", class_weights.numpy())
393
+
394
+ train_dataset = FingerprintDataset(train_fps)
395
+ val_dataset = FingerprintDataset(val_fps)
396
+
397
+ train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_batch, drop_last=False, num_workers=args.num_workers)
398
+ val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_batch, drop_last=False, num_workers=args.num_workers)
399
+
400
+ model = MaskedFingerprintModel(hidden_dim=HIDDEN_DIM, vocab_size=VOCAB_SIZE)
401
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
402
+ model.to(device)
403
+
404
+ training_args = TrainingArguments(
405
+ output_dir=output_dir,
406
+ overwrite_output_dir=True,
407
+ num_train_epochs=NUM_EPOCHS,
408
+ per_device_train_batch_size=TRAIN_BATCH_SIZE,
409
+ per_device_eval_batch_size=EVAL_BATCH_SIZE,
410
+ eval_accumulation_steps=1000,
411
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
412
+ eval_strategy="epoch",
413
+ logging_steps=500,
414
+ learning_rate=LEARNING_RATE,
415
+ weight_decay=WEIGHT_DECAY,
416
+ fp16=torch.cuda.is_available(),
417
+ save_strategy="no",
418
+ disable_tqdm=False,
419
+ logging_first_step=True,
420
+ report_to=[],
421
+ dataloader_num_workers=args.num_workers,
422
+ )
423
+
424
+ callback = ValLossCallback(best_model_dir=best_model_dir, val_loader=val_loader, patience=10)
425
+ trainer = Trainer(
426
+ model=model,
427
+ args=training_args,
428
+ train_dataset=train_dataset,
429
+ eval_dataset=val_dataset,
430
+ data_collator=collate_batch,
431
+ callbacks=[callback],
432
+ )
433
+ callback.trainer_ref = trainer
434
+
435
+ start_time = time.time()
436
+ trainer.train()
437
+ total_time = time.time() - start_time
438
+
439
+ best_model_path = os.path.join(best_model_dir, "pytorch_model.bin")
440
+ if os.path.exists(best_model_path):
441
+ try:
442
+ model.load_state_dict(torch.load(best_model_path, map_location=device))
443
+ print(f"\nLoaded best model from {best_model_path}")
444
+ except Exception as e:
445
+ print(f"\nFailed to load best model from {best_model_path}: {e}")
446
+
447
+ # Final evaluation
448
+ model.eval()
449
+ preds_bits_all, true_bits_all = [], []
450
+ logits_masked_final, labels_masked_final = [], []
451
+
452
+ with torch.no_grad():
453
+ for batch in val_loader:
454
+ z = batch["z"].to(device)
455
+ labels_z = batch["labels_z"].to(device)
456
+ attention_mask = batch.get("attention_mask", torch.ones_like(z, dtype=torch.bool)).to(device)
457
+
458
+ logits = model(z, attention_mask=attention_mask)
459
+
460
+ mask = labels_z != -100
461
+ if mask.sum().item() == 0:
462
+ continue
463
+
464
+ logits_masked_final.append(logits[mask])
465
+ labels_masked_final.append(labels_z[mask])
466
+
467
+ pred_bits = torch.argmax(logits[mask], dim=-1)
468
+ true_b = labels_z[mask]
469
+
470
+ preds_bits_all.extend(pred_bits.cpu().tolist())
471
+ true_bits_all.extend(true_b.cpu().tolist())
472
+
473
+ accuracy = accuracy_score(true_bits_all, preds_bits_all) if len(true_bits_all) > 0 else 0.0
474
+ f1 = f1_score(true_bits_all, preds_bits_all, average="weighted") if len(true_bits_all) > 0 else 0.0
475
+
476
+ if len(logits_masked_final) > 0:
477
+ all_logits_masked_final = torch.cat(logits_masked_final, dim=0)
478
+ all_labels_masked_final = torch.cat(labels_masked_final, dim=0)
479
+ loss_z_final = F.cross_entropy(all_logits_masked_final, all_labels_masked_final.long())
480
+ try:
481
+ perplexity_final = float(torch.exp(loss_z_final).cpu().item())
482
+ except Exception:
483
+ perplexity_final = float(np.exp(float(loss_z_final.cpu().item())))
484
+ else:
485
+ perplexity_final = float("nan")
486
+
487
+ best_val_loss = callback.best_val_loss if hasattr(callback, "best_val_loss") else float("nan")
488
+ best_epoch_num = (int(callback.best_epoch) + 1) if callback.best_epoch is not None else None
489
+
490
+ print(f"\n=== Final Results (evaluated on best saved model) ===")
491
+ print(f"Total Training Time (s): {total_time:.2f}")
492
+ print(f"Best Epoch (1-based): {best_epoch_num}" if best_epoch_num is not None else "Best Epoch: (none saved)")
493
+ print(f"Best Validation Loss: {best_val_loss:.4f}")
494
+ print(f"Validation Accuracy: {accuracy:.4f}")
495
+ print(f"Validation F1 (weighted): {f1:.4f}")
496
+ print(f"Validation Perplexity (classification head): {perplexity_final:.4f}")
497
+
498
+ total_params = sum(p.numel() for p in model.parameters())
499
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
500
+ non_trainable_params = total_params - trainable_params
501
+ print(f"Total Parameters: {total_params}")
502
+ print(f"Trainable Parameters: {trainable_params}")
503
+ print(f"Non-trainable Parameters: {non_trainable_params}")
504
+
505
+
506
+ def main():
507
+ args = parse_args()
508
+ train_and_eval(args)
509
+
510
+
511
+ if __name__ == "__main__":
512
+ main()