Spaces:
Running
Running
Update PolyFusion/CL.py
Browse files
- PolyFusion/CL.py +9 -13
PolyFusion/CL.py
CHANGED
|
@@ -335,7 +335,7 @@ def prepare_or_load_data_streaming(
|
|
| 335 |
s = row.get("psmiles", "")
|
| 336 |
psmiles_raw = "" if s is None else str(s)
|
| 337 |
|
| 338 |
-
# Require at least 2 modalities
|
| 339 |
modalities_present = sum(
|
| 340 |
[1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
|
| 341 |
)
|
|
@@ -352,7 +352,7 @@ def prepare_or_load_data_streaming(
|
|
| 352 |
torch.save(sample, sample_path)
|
| 353 |
except Exception as save_e:
|
| 354 |
print("Warning: failed to torch.save sample:", save_e)
|
| 355 |
-
# fallback JSON for debugging
|
| 356 |
try:
|
| 357 |
with open(sample_path + ".json", "w") as fjson:
|
| 358 |
json.dump(sample, fjson)
|
|
@@ -396,7 +396,7 @@ class LazyMultimodalDataset(Dataset):
|
|
| 396 |
def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 397 |
sample_path = self.files[idx]
|
| 398 |
|
| 399 |
-
# prefer torch.load if .pt, else try json
|
| 400 |
if sample_path.endswith(".pt"):
|
| 401 |
sample = torch.load(sample_path, map_location="cpu")
|
| 402 |
else:
|
|
@@ -512,7 +512,6 @@ def multimodal_collate(batch_list: List[Dict[str, Dict[str, torch.Tensor]]]) ->
|
|
| 512 |
ei_offset = g["edge_index"] + node_offset
|
| 513 |
all_edge_index.append(ei_offset)
|
| 514 |
|
| 515 |
-
# REUSED helper from GINE.py
|
| 516 |
ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
|
| 517 |
all_edge_attr.append(ea)
|
| 518 |
|
|
@@ -685,7 +684,7 @@ class MultimodalContrastiveModel(nn.Module):
|
|
| 685 |
|
| 686 |
def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
|
| 687 |
"""
|
| 688 |
-
Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss
|
| 689 |
"""
|
| 690 |
device = next(self.parameters()).device
|
| 691 |
embs = self.encode(batch_mods)
|
|
@@ -949,7 +948,6 @@ def mask_batch_for_modality(batch: dict, modality: str, tokenizer, p_mask: float
|
|
| 949 |
def mm_batch_to_model_input(masked_batch: dict) -> dict:
|
| 950 |
"""
|
| 951 |
Normalize the masked batch dict into the exact structure expected by MultimodalContrastiveModel.
|
| 952 |
-
(Kept identical semantics.)
|
| 953 |
"""
|
| 954 |
mm = {}
|
| 955 |
if "gine" in masked_batch:
|
|
@@ -1027,7 +1025,7 @@ def evaluate_multimodal(model: MultimodalContrastiveModel, val_loader: DataLoade
|
|
| 1027 |
acc = (preds == labels).float().mean().item()
|
| 1028 |
acc_sum += acc * B
|
| 1029 |
|
| 1030 |
-
# Weighted F1 over instance IDs
|
| 1031 |
try:
|
| 1032 |
labels_np = labels.cpu().numpy()
|
| 1033 |
preds_np = preds.cpu().numpy()
|
|
@@ -1158,8 +1156,6 @@ class ContrastiveDataCollator:
|
|
| 1158 |
class VerboseTrainingCallback(TrainerCallback):
|
| 1159 |
"""
|
| 1160 |
Console-first training callback with early stopping on eval_loss.
|
| 1161 |
-
|
| 1162 |
-
Behavior is kept consistent with your original callback; changes are comment/structure only.
|
| 1163 |
"""
|
| 1164 |
|
| 1165 |
def __init__(self, patience: int = 10):
|
|
@@ -1578,7 +1574,7 @@ def main():
|
|
| 1578 |
trainer.get_train_dataloader = lambda dataset=None: train_loader
|
| 1579 |
trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
|
| 1580 |
|
| 1581 |
-
# Optimizer
|
| 1582 |
_optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
|
| 1583 |
|
| 1584 |
total_params = sum(p.numel() for p in multimodal_model.parameters())
|
|
@@ -1597,12 +1593,12 @@ def main():
|
|
| 1597 |
except Exception:
|
| 1598 |
pass
|
| 1599 |
|
| 1600 |
-
# ----
|
| 1601 |
training_start_time = time.time()
|
| 1602 |
trainer.train()
|
| 1603 |
training_end_time = time.time()
|
| 1604 |
|
| 1605 |
-
# ----
|
| 1606 |
best_dir = os.path.join(OUTPUT_DIR, "best")
|
| 1607 |
os.makedirs(best_dir, exist_ok=True)
|
| 1608 |
|
|
@@ -1616,7 +1612,7 @@ def main():
|
|
| 1616 |
except Exception as e:
|
| 1617 |
print("Warning: failed to load/save best model from Trainer:", e)
|
| 1618 |
|
| 1619 |
-
# ----
|
| 1620 |
final_metrics = {}
|
| 1621 |
try:
|
| 1622 |
if trainer.state.best_model_checkpoint:
|
|
|
|
| 335 |
s = row.get("psmiles", "")
|
| 336 |
psmiles_raw = "" if s is None else str(s)
|
| 337 |
|
| 338 |
+
# Require at least 2 modalities
|
| 339 |
modalities_present = sum(
|
| 340 |
[1 if x is not None else 0 for x in [gine_sample, schnet_sample, fp_sample, psmiles_raw]]
|
| 341 |
)
|
|
|
|
| 352 |
torch.save(sample, sample_path)
|
| 353 |
except Exception as save_e:
|
| 354 |
print("Warning: failed to torch.save sample:", save_e)
|
| 355 |
+
# fallback JSON for debugging
|
| 356 |
try:
|
| 357 |
with open(sample_path + ".json", "w") as fjson:
|
| 358 |
json.dump(sample, fjson)
|
|
|
|
| 396 |
def __getitem__(self, idx: int) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 397 |
sample_path = self.files[idx]
|
| 398 |
|
| 399 |
+
# prefer torch.load if .pt, else try json
|
| 400 |
if sample_path.endswith(".pt"):
|
| 401 |
sample = torch.load(sample_path, map_location="cpu")
|
| 402 |
else:
|
|
|
|
| 512 |
ei_offset = g["edge_index"] + node_offset
|
| 513 |
all_edge_index.append(ei_offset)
|
| 514 |
|
|
|
|
| 515 |
ea = match_edge_attr_to_index(g["edge_index"], g["edge_attr"], target_dim=3)
|
| 516 |
all_edge_attr.append(ea)
|
| 517 |
|
|
|
|
| 684 |
|
| 685 |
def forward(self, batch_mods: Dict[str, torch.Tensor], mask_target: str):
|
| 686 |
"""
|
| 687 |
+
Compute total loss = InfoNCE + REC_LOSS_WEIGHT * reconstruction_loss
|
| 688 |
"""
|
| 689 |
device = next(self.parameters()).device
|
| 690 |
embs = self.encode(batch_mods)
|
|
|
|
| 948 |
def mm_batch_to_model_input(masked_batch: dict) -> dict:
|
| 949 |
"""
|
| 950 |
Normalize the masked batch dict into the exact structure expected by MultimodalContrastiveModel.
|
|
|
|
| 951 |
"""
|
| 952 |
mm = {}
|
| 953 |
if "gine" in masked_batch:
|
|
|
|
| 1025 |
acc = (preds == labels).float().mean().item()
|
| 1026 |
acc_sum += acc * B
|
| 1027 |
|
| 1028 |
+
# Weighted F1 over instance IDs
|
| 1029 |
try:
|
| 1030 |
labels_np = labels.cpu().numpy()
|
| 1031 |
preds_np = preds.cpu().numpy()
|
|
|
|
| 1156 |
class VerboseTrainingCallback(TrainerCallback):
|
| 1157 |
"""
|
| 1158 |
Console-first training callback with early stopping on eval_loss.
|
|
|
|
|
|
|
| 1159 |
"""
|
| 1160 |
|
| 1161 |
def __init__(self, patience: int = 10):
|
|
|
|
| 1574 |
trainer.get_train_dataloader = lambda dataset=None: train_loader
|
| 1575 |
trainer.get_eval_dataloader = lambda eval_dataset=None: val_loader
|
| 1576 |
|
| 1577 |
+
# Optimizer
|
| 1578 |
_optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=training_args.learning_rate, weight_decay=training_args.weight_decay)
|
| 1579 |
|
| 1580 |
total_params = sum(p.numel() for p in multimodal_model.parameters())
|
|
|
|
| 1593 |
except Exception:
|
| 1594 |
pass
|
| 1595 |
|
| 1596 |
+
# ---- Train ----
|
| 1597 |
training_start_time = time.time()
|
| 1598 |
trainer.train()
|
| 1599 |
training_end_time = time.time()
|
| 1600 |
|
| 1601 |
+
# ---- Save best ----
|
| 1602 |
best_dir = os.path.join(OUTPUT_DIR, "best")
|
| 1603 |
os.makedirs(best_dir, exist_ok=True)
|
| 1604 |
|
|
|
|
| 1612 |
except Exception as e:
|
| 1613 |
print("Warning: failed to load/save best model from Trainer:", e)
|
| 1614 |
|
| 1615 |
+
# ---- Final Evaluation ----
|
| 1616 |
final_metrics = {}
|
| 1617 |
try:
|
| 1618 |
if trainer.state.best_model_checkpoint:
|