"""
Crystal-Embedder-Orb Microservice

Computes Orb-v3 force field embeddings (1792-dim) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only).

API:
    POST /embed
        Request:  {"cif": "<CIF file contents>"}
        Response: {
            "vectors": {"orb": [...]},
            "dims": {"orb": 1792}
        }
    GET /health   - readiness / model status
    GET /         - service info
"""

import logging
import os
import threading

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from pymatgen.core import Structure

# The ASGI server (e.g. uvicorn) typically configures only its own loggers;
# without a root handler the INFO progress messages below would be dropped
# (they were previously plain print() calls).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SERVICE_NAME = "Crystal-Embedder-Orb"
SERVICE_VERSION = "2.0.0"
SERVICE_DESCRIPTION = "Orb-v3 force field embeddings for crystal structures"

# Length of the embedding vector this service advertises and returns.
ORB_DIM = 1792

app = FastAPI(
    title=SERVICE_NAME,
    description=SERVICE_DESCRIPTION,
    version=SERVICE_VERSION,
)

# Global model state, populated by load_models() at startup.
orb_model = None
models_loaded = False

# Serializes featurizer calls. /embed runs in FastAPI's worker threadpool, and
# the featurizer's thread-safety is not established here, so only one request
# uses the model at a time (as when the endpoint ran on the event loop).
_model_lock = threading.Lock()


class CifRequest(BaseModel):
    """Request body for POST /embed."""

    # Raw CIF text; parsed with pymatgen's Structure.from_str(..., fmt="cif").
    cif: str


class EmbedResponse(BaseModel):
    """Response body for POST /embed."""

    # Embedding vectors keyed by model name (currently only "orb").
    vectors: dict[str, list[float]]
    # Vector length per model name.
    dims: dict[str, int]


class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str  # "healthy" once the model is loaded, otherwise "loading"
    models_loaded: bool
    vector_dims: dict[str, int]


def load_models() -> None:
    """Load the Orb-v3 featurizer into the module-level globals.

    Sets ``orb_model`` and flips ``models_loaded`` to True on success.

    Raises:
        Exception: whatever importing or constructing the featurizer raises;
            it is logged with a traceback and re-raised so startup fails fast.
    """
    global orb_model, models_loaded
    logger.info("Loading Orb-v3 (%d-dim)...", ORB_DIM)
    try:
        # Imported inside the try so an import failure is logged below too.
        from mattervial.featurizers import ORBFeaturizer

        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")
    except Exception:
        logger.exception("Error loading Orb-v3")
        raise
    models_loaded = True
    logger.info("Orb-v3 loaded successfully!")


# NOTE(review): on_event is deprecated in recent FastAPI releases in favour of
# a lifespan handler; kept for compatibility with the deployed FastAPI version.
@app.on_event("startup")
async def startup_event() -> None:
    """Load models when the server starts."""
    load_models()


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Report whether the model is loaded and the advertised vector sizes."""
    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        vector_dims={"orb": ORB_DIM},
    )


def _compute_orb_embedding(struct: Structure) -> np.ndarray:
    """Return the L2-normalized Orb-v3 embedding of *struct* as a 1-D array.

    NaN features are replaced with 0.0 before normalization.

    Raises:
        ValueError: if the featurizer output is not a vector of length ORB_DIM.
        Exception: anything raised by the featurizer itself.
    """
    with _model_lock:
        features = orb_model.get_features(pd.Series([struct]))
    # One input structure was passed, so only the first row is used.
    vec = np.asarray(features.values[0], dtype=float)
    if vec.shape != (ORB_DIM,):
        raise ValueError(f"expected a {ORB_DIM}-dim vector, got shape {vec.shape}")
    vec = np.nan_to_num(vec, nan=0.0)
    # normalize() expects 2-D input; an all-zero row is left as zeros.
    return normalize(vec.reshape(1, -1))[0]


@app.post("/embed", response_model=EmbedResponse)
def embed_structure(req: CifRequest) -> EmbedResponse:
    """
    Compute Orb-v3 embedding for a CIF structure.

    Returns an L2-normalized 1792-dim force field embedding as a named vector.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool: CIF
    parsing and featurization are CPU-bound and would otherwise block the
    event loop (including /health) for the duration of the request.

    Errors: 503 while the model is still loading, 400 if the CIF cannot be
    parsed, 500 if featurization fails or yields a wrongly-sized vector.
    """
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, please retry")

    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:  # parse failures are not limited to one exception type
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    logger.info("Computing Orb-v3 features...")
    try:
        vec_orb = _compute_orb_embedding(struct)
    except Exception as e:
        logger.exception("Orb-v3 embedding failed")
        raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") from e

    logger.info("Embedding complete: orb=%d", len(vec_orb))
    return EmbedResponse(
        vectors={"orb": vec_orb.tolist()},
        dims={"orb": ORB_DIM},
    )


@app.get("/")
async def root() -> dict:
    """Root endpoint with API info."""
    return {
        "service": SERVICE_NAME,
        "version": SERVICE_VERSION,
        "description": SERVICE_DESCRIPTION,
        "endpoints": {
            "/embed": "POST - Compute Orb-v3 embedding from CIF",
            "/health": "GET - Health check",
        },
        "vector_dimensions": {"orb": ORB_DIM},
    }