Spaces:
Sleeping
Sleeping
# Smoke test for this Space's local copy of MolecularDiffusion: import the
# vendored package, load the bundled checkpoint on CPU, and draw one short
# DDPM sample. Prints a traceback and exits with status 1 on any failure.
import inspect
import os
import pickle
import sys
import traceback

import torch

# Directory containing this script; the local MolecularDiffusion copy and the
# model/ folder are expected to live next to it.
SPACE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(SPACE_DIR, "model", "edm_chem.pkl")
STAT_PATH = os.path.join(SPACE_DIR, "model", "edm_stat.pkl")

# Sentinel for inspect.getattr_static so that a `device` attribute whose value
# is None is still distinguished from a missing attribute.
_MISSING = object()


def _force_device(task, device):
    """Make ``task`` and all its submodules report ``device``; drop edge caches.

    Args:
        task: the loaded model; must provide ``named_modules()`` (as
            ``torch.nn.Module`` does).
        device: the ``torch.device`` every ``device`` attribute should report.

    If ``task.device`` cannot be assigned (a read-only property), the property
    is replaced on ``type(task)``, which affects every instance of that class.
    Failures to set ``device`` on individual submodules are printed, not raised.
    """
    print("DEBUG: Monkeypatching task.device to return CPU")
    try:
        task.device = device
    except AttributeError:
        # Read-only property: override it on the class instead of the instance.
        type(task).device = property(lambda self: device)
    print(f"DEBUG: task.device prop AFTER PATCH: {task.device}")

    print("DEBUG: Cleaning up module caches and device attributes...")
    for name, module in task.named_modules():
        # Clear the _edges_dict cache in EGNN_dynamics; it presumably holds
        # tensors built on whatever device was active when it was filled.
        if hasattr(module, "_edges_dict"):
            print(f"DEBUG: Clearing _edges_dict for {name}")
            module._edges_dict = {}

        # Static lookup: unlike hasattr(), this never executes a property
        # getter, which could raise something other than AttributeError (e.g.
        # StopIteration from next(self.parameters()) on a parameterless module).
        attr = inspect.getattr_static(module, "device", _MISSING)
        if attr is _MISSING or isinstance(attr, property):
            continue
        try:
            print(f"DEBUG: Setting device attribute for {name} to {device}")
            module.device = device
        except Exception as e:
            # Best effort: report and continue with the remaining modules.
            print(f"DEBUG: Failed to set device for {name}: {e}")


def _sample_with_steps(task, nodesxsample, steps):
    """Run ``task.sample`` in DDPM mode with ``task.model.T`` set to ``steps``.

    The previous ``task.model.T`` is restored afterwards, even if sampling
    raises.

    Returns:
        Whatever ``task.sample`` returns (the caller unpacks it as
        ``one_hot, charges, x, node_mask``).
    """
    old_T = task.model.T
    task.model.T = steps
    try:
        return task.sample(
            nodesxsample=nodesxsample,
            mode="ddpm",
            n_frames=0,
            fix_noise=False,
        )
    finally:
        task.model.T = old_T


def main():
    """Run the import, checkpoint-loading and inference checks; exit 1 on failure."""
    # Point to our LOCAL copy of MolecularDiffusion. This is crucial: we are
    # testing whether the copied code works, so it must take precedence over
    # any installed copy of the package.
    sys.path.insert(0, SPACE_DIR)
    print(f"Testing environment in: {SPACE_DIR}")
    try:
        # 1. Test imports from the local copy.
        import MolecularDiffusion

        print(f"Imported MolecularDiffusion from: {MolecularDiffusion.__file__}")
        from MolecularDiffusion.core import Engine

        # Imported only to verify it resolves; not used below.
        from MolecularDiffusion.utils import seed_everything  # noqa: F401

        # 2. Test model loading.
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Missing model file: {MODEL_PATH}")
        print(f"Loading checkpoint: {MODEL_PATH}")

        # Force CPU to adhere to the Space's constraints.
        device = torch.device("cpu")
        engine = Engine(None, None, None, None, None)
        # This call triggers the pickle load, which requires all class
        # definitions referenced by the checkpoint to be importable.
        # NOTE(review): keyword is spelled `interference_mode`; confirm it
        # matches Engine.load_from_checkpoint's signature.
        engine = engine.load_from_checkpoint(MODEL_PATH, interference_mode=True)
        task = engine.model

        # Optional stats file: stats["node"] becomes task.node_dist_model.
        # pickle is acceptable only because this is a trusted, bundled file.
        if os.path.exists(STAT_PATH):
            with open(STAT_PATH, "rb") as f:
                stats = pickle.load(f)
            task.node_dist_model = stats["node"]
        else:
            print(
                f"WARNING: Stats file not found, node_dist_model left unset: "
                f"{STAT_PATH}"
            )

        # Atom vocabulary is assigned manually; its order presumably must match
        # the model's atom-type encoding (confirm against training config).
        task.atom_vocab = [
            "H", "B", "C", "N", "O", "F", "Al", "Si", "P",
            "S", "Cl", "As", "Se", "Br", "I", "Hg", "Bi",
        ]

        # Move everything to CPU, then make every `device` attribute agree.
        task.to(device)
        task.eval()
        _force_device(task, device)
        print("SUCCESS: Model loaded from checkpoint!")

        # 3. Test inference: one entry in nodesxsample -> one molecule of 5 nodes.
        print("Running sample inference (1 molecule, 5 steps)...")
        nodesxsample = torch.tensor([5], dtype=torch.long, device=device)
        _one_hot, _charges, x, node_mask = _sample_with_steps(
            task, nodesxsample, steps=5
        )
        print(f"Generated molecule with {int(node_mask.sum())} atoms.")
        print(f"Positions shape: {x.shape}")
        print("SUCCESS: Inference pipeline is functional!")
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit status.
        print(f"\nCRITICAL FAILURE: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()