"""Crystal-Embedder microservice.

Computes Tri-Fusion embeddings (Orb-v3 + l-MM + l-OFM) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only, 16GB RAM).

API:
    POST /embed
    Request:  {"cif": "<CIF file contents>"}
    Response: {"vector": [...], "dims": 2738}
"""

import os

# CRITICAL: Set TensorFlow to use legacy Keras before any imports.
# Required for MEGNet compatibility with TensorFlow 2.16+.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

from fastapi import FastAPI, HTTPException  # noqa: E402
from pydantic import BaseModel  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from sklearn.preprocessing import normalize  # noqa: E402
from pymatgen.core import Structure  # noqa: E402

# Advertised size of each sub-embedding. Kept in one place so /health, / and
# the log messages cannot drift apart.
ORB_DIMS = 1792
MM_DIMS = 758
OFM_DIMS = 188
TOTAL_DIMS = ORB_DIMS + MM_DIMS + OFM_DIMS  # 2738

app = FastAPI(
    title="Crystal-Embedder",
    description="Tri-Fusion physics embeddings for crystal structures",
    version="1.0.0",
)

# Global model instances (populated by load_models() at startup).
orb_model = None
mm_model = None
ofm_model = None
models_loaded = False


class CifRequest(BaseModel):
    """Request body for POST /embed: the raw text of a CIF file."""

    cif: str


class EmbedResponse(BaseModel):
    """Response body for POST /embed: the embedding and its length."""

    vector: list[float]
    dims: int


class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str
    models_loaded: bool
    vector_dims: int


def load_models():
    """Load all three featurizer models into the module-level globals.

    Sets ``models_loaded`` to True only after all three featurizers have been
    constructed.

    Raises:
        Exception: whatever the import or a featurizer constructor raised; it
            is printed and re-raised unchanged.
    """
    global orb_model, mm_model, ofm_model, models_loaded

    print("Loading Physics Engines (Orb-v3 + l-MM + l-OFM)...")
    print("This may take a few minutes on first load...")

    try:
        # Imported lazily so the heavy MatterVial stack is only pulled in when
        # the models are actually loaded.
        from mattervial.featurizers import (
            ORBFeaturizer,
            DescriptorMEGNetFeaturizer,
        )

        # Orb-v3 (PyTorch, CPU)
        print(f"  Loading Orb-v3 ({ORB_DIMS}-dim)...")
        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")

        # l-MM (TensorFlow/MEGNet)
        print(f"  Loading l-MM ({MM_DIMS}-dim)...")
        mm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-MM_v1')

        # l-OFM (TensorFlow/MEGNet)
        print(f"  Loading l-OFM ({OFM_DIMS}-dim)...")
        ofm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-OFM_v1')

        models_loaded = True
        print("All models loaded successfully!")
    except Exception as e:
        print(f"Error loading models: {e}")
        raise


@app.on_event("startup")
async def startup_event():
    """Load models when the server starts."""
    load_models()


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports "healthy" once the models are loaded and "loading" before that.
    """
    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        vector_dims=TOTAL_DIMS,
    )


def _normalized_features(model, structures, label):
    """Return one model's L2-normalized feature vector for the first structure.

    Args:
        model: featurizer exposing ``get_features(series)``.
        structures: pandas Series of pymatgen structures; only the first row
            of the result is used.
        label: human-readable model name used in the progress message.
    """
    print(f"Computing {label} features...")
    vec = model.get_features(structures).values[0]

    # Zero out every non-finite entry. By default nan_to_num maps +/-inf to
    # the largest finite float, which would overflow the L2 norm below and
    # wipe out the rest of the vector, so infinities are zeroed as well.
    vec = np.nan_to_num(vec, nan=0.0, posinf=0.0, neginf=0.0)

    # L2-normalize so that no single sub-embedding dominates by magnitude.
    return normalize([vec])[0]


@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """Compute the Tri-Fusion embedding for a CIF structure.

    The embedding is a concatenation of three L2-normalized vectors:
    - Orb-v3: 1792-dim (force field features)
    - l-MM: 758-dim (electronic structure features)
    - l-OFM: 188-dim (orbital field matrix features)
    Total: 2738 dimensions

    Raises:
        HTTPException: 503 while the models are still loading, 400 when the
            CIF is empty or cannot be parsed, 500 when featurization fails.
    """
    # NOTE(review): the featurizer calls below are synchronous and CPU-bound,
    # so they run on the event loop for the duration of the request; confirm
    # the models are thread-safe before moving this work to a thread pool.
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Models still loading, please retry")

    # 1. Parse CIF to a pymatgen Structure. Bad input is the caller's fault,
    #    so it is reported as 400 rather than as an internal error.
    if not req.cif.strip():
        raise HTTPException(status_code=400, detail="Invalid CIF: input is empty")
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        s_series = pd.Series([struct])

        # 2-4. Compute, sanitize and normalize each sub-embedding
        #      (sequential on CPU).
        v1 = _normalized_features(orb_model, s_series, "Orb-v3")  # ~10s
        v2 = _normalized_features(mm_model, s_series, "l-MM")  # ~2s
        v3 = _normalized_features(ofm_model, s_series, "l-OFM")  # ~0.5s

        # 5. Concatenate
        final_vector = np.concatenate([v1, v2, v3])
        print(f"Embedding complete: {len(final_vector)} dimensions")

        return EmbedResponse(
            vector=final_vector.tolist(),
            dims=len(final_vector),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {str(e)}") from e


@app.get("/")
async def root():
    """Root endpoint with API info."""
    return {
        "service": "Crystal-Embedder",
        "version": "1.0.0",
        "description": "Tri-Fusion physics embeddings for crystal structures",
        "endpoints": {
            "/embed": "POST - Compute embedding from CIF",
            "/health": "GET - Health check",
        },
        "vector_dimensions": {
            "orb_v3": ORB_DIMS,
            "l_mm": MM_DIMS,
            "l_ofm": OFM_DIMS,
            "total": TOTAL_DIMS,
        },
    }