manpreet88 commited on
Commit
5da1f0d
·
1 Parent(s): 35a954c

Update CL.py

Browse files
Files changed (1) hide show
  1. PolyFusion/CL.py +3 -3
PolyFusion/CL.py CHANGED
@@ -335,7 +335,7 @@ def prepare_or_load_data_streaming(
335
  s = row.get("psmiles", "")
336
  psmiles_raw = "" if s is None else str(s)
337
 
338
- # Require at least 2 modalities to keep sample (same logic as your original)
339
  modalities_present = sum(
340
  [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
341
  )
@@ -1328,7 +1328,7 @@ class VerboseTrainingCallback(TrainerCallback):
1328
  class CLTrainer(Trainer):
1329
  """
1330
  Custom Trainer:
1331
- - evaluate(): merges HF eval with contrastive evaluator (same behavior)
1332
  - _save(): saves a state_dict under pytorch_model.bin
1333
  - _load_best_model(): loads best pytorch_model.bin
1334
  """
@@ -1544,7 +1544,7 @@ def main():
1544
 
1545
  tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
1546
 
1547
- global train_loader, val_loader, multimodal_model, device, tokenizer # kept for callback references (same behavior)
1548
  tokenizer = tokenizer_local
1549
  device = device_local
1550
 
 
335
  s = row.get("psmiles", "")
336
  psmiles_raw = "" if s is None else str(s)
337
 
338
+ # Require at least 2 modalities to keep sample
339
  modalities_present = sum(
340
  [1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
341
  )
 
1328
  class CLTrainer(Trainer):
1329
  """
1330
  Custom Trainer:
1331
+ - evaluate(): merges HF eval with contrastive evaluator
1332
  - _save(): saves a state_dict under pytorch_model.bin
1333
  - _load_best_model(): loads best pytorch_model.bin
1334
  """
 
1544
 
1545
  tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
1546
 
1547
+ global train_loader, val_loader, multimodal_model, device, tokenizer
1548
  tokenizer = tokenizer_local
1549
  device = device_local
1550