# MolCraftDiffusion-demo / verify_dependencies.py
# iflp1908sl's picture
# Initial commit with robust CPU inference
# 970ac6b
import sys
import os
import torch
import pickle
# Smoke test for the copy of MolecularDiffusion that sits next to this script:
# import the package, load the pickled checkpoint, pin everything to CPU, and
# run a very short sampling pass. Any failure prints a traceback and exits
# with status 1.
#
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been lost. In particular, confirm that the atom vocabulary
# is meant to be assigned whether or not the stats file exists.

# Point to our LOCAL copy of MolecularDiffusion
# This is crucial: we are testing if the copied code works
SPACE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SPACE_DIR)
print(f"Testing environment in: {SPACE_DIR}")

try:
    # 1. Test imports from local copy
    import MolecularDiffusion
    print(f"Imported MolecularDiffusion from: {MolecularDiffusion.__file__}")
    from MolecularDiffusion.core import Engine
    # Imported only to prove that the name resolves; it is not called here.
    from MolecularDiffusion.utils import seed_everything  # noqa: F401

    # 2. Test Model Loading
    MODEL_PATH = os.path.join(SPACE_DIR, "model", "edm_chem.pkl")
    STAT_PATH = os.path.join(SPACE_DIR, "model", "edm_stat.pkl")
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Missing model file: {MODEL_PATH}")

    print(f"Loading checkpoint: {MODEL_PATH}")
    # Force CPU to adhere to space constraints
    device = torch.device("cpu")
    engine = Engine(None, None, None, None, None)
    # This call triggers the pickle load which requires all class defs to be available
    # NOTE(review): `interference_mode` is passed exactly as the original
    # script spelled it; check Engine.load_from_checkpoint before renaming it.
    engine = engine.load_from_checkpoint(MODEL_PATH, interference_mode=True)
    task = engine.model

    # Load stats
    if os.path.exists(STAT_PATH):
        # SECURITY: pickle.load can execute arbitrary code from the file.
        # This is assumed to be a trusted file shipped with this script;
        # never point STAT_PATH at data from an untrusted source.
        with open(STAT_PATH, "rb") as f:
            stats = pickle.load(f)
        task.node_dist_model = stats["node"]
    else:
        # Previously skipped silently; surface it so a later sampling failure
        # can be traced back to the missing file.
        print(f"WARNING: Stats file not found, node_dist_model not loaded: {STAT_PATH}")

    # manually set vocab
    task.atom_vocab = ["H","B","C","N","O","F","Al","Si","P","S","Cl","As","Se","Br","I","Hg","Bi"]

    # Move everything to CPU at the end
    task.to(device)
    task.eval()

    # Monkeypatch device property to force CPU
    print("DEBUG: Monkeypatching task.device to return CPU")
    # If it's a property on the class, we need to patch the class or the instance's class
    # If it's just an attribute, we set it.
    try:
        task.device = device
    except AttributeError:
        # It's a property, patch the class
        type(task).device = property(lambda self: torch.device("cpu"))
    print(f"DEBUG: task.device prop AFTER PATCH: {task.device}")

    # Iterate over all modules to clear internal caches and fix metadata
    print("DEBUG: Cleaning up module caches and device attributes...")
    for name, module in task.named_modules():
        # 1. Clear _edges_dict cache in EGNN_dynamics
        if hasattr(module, "_edges_dict"):
            print(f"DEBUG: Clearing _edges_dict for {name}")
            module._edges_dict = {}

        # 2. Fix 'device' attribute if it exists and is not a property
        if hasattr(module, "device"):
            # A property is looked up on the class, so inspect the class
            # attribute; plain instance attributes are the ones we overwrite.
            is_property = isinstance(getattr(type(module), "device", None), property)
            if not is_property:
                # Best effort: a module that refuses the assignment is
                # reported and skipped rather than aborting the whole check.
                try:
                    print(f"DEBUG: Setting device attribute for {name} to {device}")
                    module.device = device
                except Exception as e:
                    print(f"DEBUG: Failed to set device for {name}: {e}")

    print("SUCCESS: Model loaded from checkpoint!")

    # 3. Test Inference
    print("Running sample inference (1 molecule, 5 steps)...")
    nodesxsample = torch.tensor([5], dtype=torch.long, device=device)

    # T override: temporarily shorten the run, and always put the original
    # value back, even if sampling raises.
    old_T = task.model.T
    task.model.T = 5
    try:
        one_hot, charges, x, node_mask = task.sample(
            nodesxsample=nodesxsample,
            mode="ddpm",
            n_frames=0,
            fix_noise=False,
        )
    finally:
        task.model.T = old_T

    print(f"Generated molecule with {int(node_mask.sum())} atoms.")
    print(f"Positions shape: {x.shape}")
    print("SUCCESS: Inference pipeline is functional!")
except Exception as e:
    # Top-level boundary of the script: report, dump the traceback, and
    # signal failure to the caller through the exit status.
    print(f"\nCRITICAL FAILURE: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)