Spaces:
Running
Running
manpreet88
commited on
Commit
·
5da1f0d
1
Parent(s):
35a954c
Update CL.py
Browse files- PolyFusion/CL.py +3 -3
PolyFusion/CL.py
CHANGED
|
@@ -335,7 +335,7 @@ def prepare_or_load_data_streaming(
|
|
| 335 |
s = row.get("psmiles", "")
|
| 336 |
psmiles_raw = "" if s is None else str(s)
|
| 337 |
|
| 338 |
-
# Require at least 2 modalities to keep sample
|
| 339 |
modalities_present = sum(
|
| 340 |
[1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
|
| 341 |
)
|
|
@@ -1328,7 +1328,7 @@ class VerboseTrainingCallback(TrainerCallback):
|
|
| 1328 |
class CLTrainer(Trainer):
|
| 1329 |
"""
|
| 1330 |
Custom Trainer:
|
| 1331 |
-
- evaluate(): merges HF eval with contrastive evaluator
|
| 1332 |
- _save(): saves a state_dict under pytorch_model.bin
|
| 1333 |
- _load_best_model(): loads best pytorch_model.bin
|
| 1334 |
"""
|
|
@@ -1544,7 +1544,7 @@ def main():
|
|
| 1544 |
|
| 1545 |
tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
| 1546 |
|
| 1547 |
-
global train_loader, val_loader, multimodal_model, device, tokenizer
|
| 1548 |
tokenizer = tokenizer_local
|
| 1549 |
device = device_local
|
| 1550 |
|
|
|
|
| 335 |
s = row.get("psmiles", "")
|
| 336 |
psmiles_raw = "" if s is None else str(s)
|
| 337 |
|
| 338 |
+
# Require at least 2 modalities to keep sample
|
| 339 |
modalities_present = sum(
|
| 340 |
[1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
|
| 341 |
)
|
|
|
|
| 1328 |
class CLTrainer(Trainer):
|
| 1329 |
"""
|
| 1330 |
Custom Trainer:
|
| 1331 |
+
- evaluate(): merges HF eval with contrastive evaluator
|
| 1332 |
- _save(): saves a state_dict under pytorch_model.bin
|
| 1333 |
- _load_best_model(): loads best pytorch_model.bin
|
| 1334 |
"""
|
|
|
|
| 1544 |
|
| 1545 |
tokenizer_local = build_psmiles_tokenizer(spm_path=SPM_MODEL, max_len=PSMILES_MAX_LEN)
|
| 1546 |
|
| 1547 |
+
global train_loader, val_loader, multimodal_model, device, tokenizer
|
| 1548 |
tokenizer = tokenizer_local
|
| 1549 |
device = device_local
|
| 1550 |
|