kaurm43 commited on
Commit
983d53f
·
verified ·
1 Parent(s): b5ed4b6

Update PolyFusion/CL.py

Browse files
Files changed (1):
  1. PolyFusion/CL.py +9 -13
PolyFusion/CL.py CHANGED
@@ -335,7 +335,7 @@ def prepare_or_load_data_streaming(
335
  s = row.get("psmiles", "")
336
  psmiles_raw = "" if s is None else str(s)
337
 
338
- # Require at least 2 modalities to keep sample
339
  modalities_present = sum(
340
  [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
341
  )
@@ -352,7 +352,7 @@ def prepare_or_load_data_streaming(
352
  torch.save(sample, sample_path)
353
  except Exception as save_e:
354
  print("Warning: failed to torch.save sample:", save_e)
355
- # fallback JSON for debugging (kept from your original)
356
  try:
357
  with open(sample_path + ".json", "w") as fjson:
358
  json.dump(sample, fjson)
@@ -396,7 +396,7 @@ class LazyMultimodalDataset(Dataset):
396
  def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
397
  sample_path = self.files[idx]
398
 
399
- # prefer torch.load if .pt, else try json (kept behavior)
400
  if sample_path.endswith(".pt"):
401
  sample = torch.load(sample_path, map_location="cpu")
402
  else:
@@ -512,7 +512,6 @@ def multimodal_collate(batch_list: List[Dict[str, Dict[str, torch.Tensor]]]) ->
512
  ei_offset = g["edge_index"] + node_offset
513
  all_edge_index.append(ei_offset)
514
 
515
- # REUSED helper from GINE.py
516
  ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
517
  all_edge_attr.append(ea)
518
 
@@ -685,7 +684,7 @@ class MultimodalContrastiveModel(nn.Module):
685
 
686
  def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
687
  """
688
- Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss (if any labels exist).
689
  """
690
  device = next(self.parameters()).device
691
  embs = self.encode(batch_mods)
@@ -949,7 +948,6 @@ def mask_batch_for_modality(batch: dict, modality: str, tokenizer, p_mask: float
949
  def mm_batch_to_model_input(masked_batch: dict) -> dict:
950
  """
951
  Normalize the masked batch dict into the exact structure expected by MultimodalContrastiveModel.
952
- (Kept identical semantics.)
953
  """
954
  mm = {}
955
  if "gine" in masked_batch:
@@ -1027,7 +1025,7 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader: DataLoade
1027
  acc = (preds == labels).float().mean().item()
1028
  acc_sum += acc * B
1029
 
1030
- # Weighted F1 over instance IDs (kept as in your prior logic)
1031
  try:
1032
  labels_np = labels.cpu().numpy()
1033
  preds_np = preds.cpu().numpy()
@@ -1158,8 +1156,6 @@ class ContrastiveDataCollator:
1158
  class VerboseTrainingCallback(TrainerCallback):
1159
  """
1160
  Console-first training callback with early stopping on eval_loss.
1161
-
1162
- Behavior is kept consistent with your original callback; changes are comment/structure only.
1163
  """
1164
 
1165
  def __init__(self, patience: int = 10):
@@ -1578,7 +1574,7 @@ def main():
1578
  trainer.get_train_dataloader = lambda dataset=None: train_loader
1579
  trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
1580
 
1581
- # Optimizer (kept as in original script)
1582
  _optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
1583
 
1584
  total_params = sum(p.numel() for p in multimodal_model.parameters())
@@ -1597,12 +1593,12 @@ def main():
1597
  except Exception:
1598
  pass
1599
 
1600
- # ---- train ----
1601
  training_start_time = time.time()
1602
  trainer.train()
1603
  training_end_time = time.time()
1604
 
1605
- # ---- save best ----
1606
  best_dir = os.path.join(OUTPUT_DIR, "best")
1607
  os.makedirs(best_dir, exist_ok=True)
1608
 
@@ -1616,7 +1612,7 @@ def main():
1616
  except Exception as e:
1617
  print("Warning: failed to load/save best model from Trainer:", e)
1618
 
1619
- # ---- final evaluation ----
1620
  final_metrics = {}
1621
  try:
1622
  if trainer.state.best_model_checkpoint:
 
335
  s = row.get("psmiles", "")
336
  psmiles_raw = "" if s is None else str(s)
337
 
338
+ # Require at least 2 modalities
339
  modalities_present = sum(
340
  [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
341
  )
 
352
  torch.save(sample, sample_path)
353
  except Exception as save_e:
354
  print("Warning: failed to torch.save sample:", save_e)
355
+ # fallback JSON for debugging
356
  try:
357
  with open(sample_path + ".json", "w") as fjson:
358
  json.dump(sample, fjson)
 
396
  def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
397
  sample_path = self.files[idx]
398
 
399
+ # prefer torch.load if .pt, else try json
400
  if sample_path.endswith(".pt"):
401
  sample = torch.load(sample_path, map_location="cpu")
402
  else:
 
512
  ei_offset = g["edge_index"] + node_offset
513
  all_edge_index.append(ei_offset)
514
 
 
515
  ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
516
  all_edge_attr.append(ea)
517
 
 
684
 
685
  def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
686
  """
687
+ Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss
688
  """
689
  device = next(self.parameters()).device
690
  embs = self.encode(batch_mods)
 
948
  def mm_batch_to_model_input(masked_batch: dict) -> dict:
949
  """
950
  Normalize the masked batch dict into the exact structure expected by MultimodalContrastiveModel.
 
951
  """
952
  mm = {}
953
  if "gine" in masked_batch:
 
1025
  acc = (preds == labels).float().mean().item()
1026
  acc_sum += acc * B
1027
 
1028
+ # Weighted F1 over instance IDs
1029
  try:
1030
  labels_np = labels.cpu().numpy()
1031
  preds_np = preds.cpu().numpy()
 
1156
  class VerboseTrainingCallback(TrainerCallback):
1157
  """
1158
  Console-first training callback with early stopping on eval_loss.
 
 
1159
  """
1160
 
1161
  def __init__(self, patience: int = 10):
 
1574
  trainer.get_train_dataloader = lambda dataset=None: train_loader
1575
  trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
1576
 
1577
+ # Optimizer
1578
  _optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
1579
 
1580
  total_params = sum(p.numel() for p in multimodal_model.parameters())
 
1593
  except Exception:
1594
  pass
1595
 
1596
+ # ---- Train ----
1597
  training_start_time = time.time()
1598
  trainer.train()
1599
  training_end_time = time.time()
1600
 
1601
+ # ---- Save best ----
1602
  best_dir = os.path.join(OUTPUT_DIR, "best")
1603
  os.makedirs(best_dir, exist_ok=True)
1604
 
 
1612
  except Exception as e:
1613
  print("Warning: failed to load/save best model from Trainer:", e)
1614
 
1615
+ # ---- Final Evaluation ----
1616
  final_metrics = {}
1617
  try:
1618
  if trainer.state.best_model_checkpoint: