Spaces:
Sleeping
Sleeping
| """ | |
| Crystal-Embedder Microservice | |
| Computes Tri-Fusion embeddings (Orb-v3 + l-MM + l-OFM) for crystal structures. | |
| Deployed on Hugging Face Spaces (CPU-only, 16GB RAM). | |
| API: | |
| POST /embed | |
| Request: {"cif": "<CIF content>"} | |
| Response: {"vector": [...], "dims": 2738} | |
| """ | |
| import os | |
| # CRITICAL: Set TensorFlow to use legacy Keras before any imports | |
| # Required for MEGNet compatibility with TensorFlow 2.16+ | |
| os.environ["TF_USE_LEGACY_KERAS"] = "1" | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.preprocessing import normalize | |
| from pymatgen.core import Structure | |
# ASGI application object; the title/description/version below are passed
# straight to FastAPI.
app = FastAPI(
    title="Crystal-Embedder",
    description="Tri-Fusion physics embeddings for crystal structures",
    version="1.0.0",
)

# Module-level featurizer handles. They stay None until load_models() has
# assigned them; models_loaded flips to True only after all three are set.
orb_model = mm_model = ofm_model = None
models_loaded = False
# Request body accepted by embed_structure().
# (Documented with comments rather than a class docstring so the generated
# model schema is left exactly as it was.)
class CifRequest(BaseModel):
    # Raw CIF text; embed_structure() hands it to Structure.from_str(fmt="cif").
    cif: str
# Response body produced by embed_structure().
# (Documented with comments rather than a class docstring so the generated
# model schema is left exactly as it was.)
class EmbedResponse(BaseModel):
    # Concatenated embedding as a flat list of floats.
    vector: list[float]
    # Number of entries in `vector`.
    dims: int
# Response body produced by health_check().
# (Documented with comments rather than a class docstring so the generated
# model schema is left exactly as it was.)
class HealthResponse(BaseModel):
    # "healthy" once the models are loaded, "loading" before that.
    status: str
    # Mirror of the module-level models_loaded flag.
    models_loaded: bool
    # Advertised length of the embedding vector.
    vector_dims: int
def load_models():
    """Instantiate the three featurizers and publish them as module globals.

    Assigns ``orb_model``, ``mm_model`` and ``ofm_model`` in that order and
    sets ``models_loaded`` to True only after all three constructors have
    returned. Progress is reported with ``print``.

    Raises:
        Exception: whatever the import or a featurizer constructor raises is
            printed and then re-raised unchanged. ``models_loaded`` stays
            False in that case; handles assigned before the failure keep
            their values.
    """
    global orb_model, mm_model, ofm_model, models_loaded

    print("Loading Physics Engines (Orb-v3 + l-MM + l-OFM)...")
    print("This may take a few minutes on first load...")

    try:
        # Imported here rather than at module import time, so an import
        # failure is reported through the same except-branch as a load failure.
        from mattervial.featurizers import (
            ORBFeaturizer,
            DescriptorMEGNetFeaturizer,
        )

        # Orb-v3 featurizer, pinned to CPU.
        print(" Loading Orb-v3 (1792-dim)...")
        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")

        # l-MM descriptor model.
        print(" Loading l-MM (758-dim)...")
        mm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-MM_v1')

        # l-OFM descriptor model.
        print(" Loading l-OFM (188-dim)...")
        ofm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-OFM_v1')

        models_loaded = True
        print("All models loaded successfully!")
    except Exception as exc:
        print(f"Error loading models: {exc}")
        raise
@app.on_event("startup")
async def startup_event():
    """Load models when the server starts.

    Registered as a FastAPI startup handler; without the registration
    nothing ever calls load_models(), models_loaded stays False and the
    embed endpoint can only answer 503.

    load_models() is synchronous, so this handler blocks until all three
    featurizers are constructed. Any exception it raises propagates.
    """
    load_models()
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint (GET /health).

    Returns:
        HealthResponse whose ``status`` is "healthy" once the module-level
        ``models_loaded`` flag is set and "loading" otherwise, together with
        the flag itself and the advertised embedding length (2738).
    """
    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        vector_dims=2738,
    )
@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute the Tri-Fusion embedding for a CIF structure (POST /embed).

    The embedding is a concatenation of three L2-normalized vectors
    (nominal sizes as advertised by this service):
    - Orb-v3: 1792-dim (force field features)
    - l-MM: 758-dim (electronic structure features)
    - l-OFM: 188-dim (orbital field matrix features)
    Total: 2738 dimensions

    Args:
        req: Request body; ``req.cif`` holds the raw CIF text.

    Returns:
        EmbedResponse with the concatenated vector and its actual length.

    Raises:
        HTTPException: 503 while the models are not loaded; 400 when the CIF
            text is empty or cannot be parsed; 500 when featurization or
            post-processing fails.
    """
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Models still loading, please retry")

    # 1. Parse CIF to pymatgen Structure.
    # Done outside the featurization try-block so that a malformed payload is
    # reported as a client error (400) instead of being folded into the
    # generic 500 below.
    if not req.cif.strip():
        raise HTTPException(status_code=400, detail="Invalid CIF: empty input")
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        # The featurizers are called with a pandas Series of structures.
        s_series = pd.Series([struct])

        # 2-4. Run the featurizers one after another (sequential on CPU),
        # clean each result and L2-normalize it on its own so that no single
        # model dominates the concatenated vector by magnitude.
        featurizers = (
            ("Orb-v3", orb_model),
            ("l-MM", mm_model),
            ("l-OFM", ofm_model),
        )
        parts = []
        for label, model in featurizers:
            print(f"Computing {label} features...")
            raw = model.get_features(s_series).values[0]
            # Zero out NaN *and* +/-inf. Left at its defaults, nan_to_num maps
            # inf to the largest finite float, which overflows the L2 norm
            # and collapses the whole sub-vector to zeros.
            clean = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)
            parts.append(normalize([clean])[0])

        # 5. Concatenate
        final_vector = np.concatenate(parts)
        print(f"Embedding complete: {len(final_vector)} dimensions")
        return EmbedResponse(
            vector=final_vector.tolist(),
            dims=len(final_vector),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") from e
@app.get("/")
async def root():
    """Root endpoint with API info (GET /).

    Returns:
        A static dict describing the service: its name, version, a short
        description, the available endpoints and the advertised
        per-model / total embedding dimensions.
    """
    return {
        "service": "Crystal-Embedder",
        "version": "1.0.0",
        "description": "Tri-Fusion physics embeddings for crystal structures",
        "endpoints": {
            "/embed": "POST - Compute embedding from CIF",
            "/health": "GET - Health check",
        },
        "vector_dimensions": {
            "orb_v3": 1792,
            "l_mm": 758,
            "l_ofm": 188,
            "total": 2738,
        },
    }