Spaces:
Sleeping
Sleeping
"""
Crystal-Embedder-Orb Microservice

Computes Orb-v3 force field embeddings (1792-dim) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only).

API:
    POST /embed
    Request: {"cif": "<CIF content>"}
    Response: {
        "vectors": {"orb": [...]},
        "dims": {"orb": 1792}
    }
"""
| import os | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.preprocessing import normalize | |
| from pymatgen.core import Structure | |
# FastAPI application instance for the embedding service.
app = FastAPI(
    version="2.0.0",
    title="Crystal-Embedder-Orb",
    description="Orb-v3 force field embeddings for crystal structures",
)
# Module-level state shared by the loader and the request handlers.
ORB_DIM = 1792  # length of the Orb-v3 embedding vector
orb_model = None  # featurizer instance; stays None until load_models() succeeds
models_loaded = False  # set to True by load_models() once the featurizer exists
class CifRequest(BaseModel):
    """Request body for the embed endpoint: one crystal structure as CIF text."""

    cif: str  # raw CIF content; parsed with pymatgen's Structure.from_str
class EmbedResponse(BaseModel):
    """Response body for the embed endpoint.

    Both mappings are keyed by vector name; the only key this service
    produces is ``"orb"``.
    """

    vectors: dict[str, list[float]]  # vector name -> embedding values
    dims: dict[str, int]  # vector name -> embedding length
class HealthResponse(BaseModel):
    """Response body for the health check endpoint."""

    status: str  # "healthy" once the model is loaded, otherwise "loading"
    models_loaded: bool  # mirrors the module-level models_loaded flag
    vector_dims: dict[str, int]  # vector name -> embedding length
def load_models():
    """Instantiate the Orb-v3 featurizer and publish it via module globals.

    On success ``orb_model`` holds the featurizer and ``models_loaded`` is
    True. Any failure is printed and then re-raised to the caller, leaving
    both globals untouched.
    """
    global orb_model, models_loaded

    print("Loading Orb-v3 (1792-dim)...")
    try:
        # Imported inside the guarded region so that an import failure is
        # reported through the same path as a construction failure.
        from mattervial.featurizers import ORBFeaturizer

        featurizer = ORBFeaturizer(model_name="ORB_v3", device="cpu")
        orb_model, models_loaded = featurizer, True
        print("Orb-v3 loaded successfully!")
    except Exception as exc:
        print(f"Error loading Orb-v3: {exc}")
        raise
# Registered as a startup handler: without the registration nothing ever
# calls load_models(), so models_loaded stays False and /embed answers 503.
@app.on_event("startup")
async def startup_event():
    """Load models when the server starts.

    Raises:
        Exception: whatever load_models() re-raises when loading fails.
    """
    load_models()
# Route registration for the "/health" endpoint advertised by root().
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Returns:
        HealthResponse whose ``status`` is "healthy" when the model has been
        loaded and "loading" otherwise, together with the raw flag and the
        embedding dimension.
    """
    return HealthResponse(
        status="healthy" if models_loaded else "loading",
        models_loaded=models_loaded,
        vector_dims={"orb": ORB_DIM},
    )
# Route registration for the "POST /embed" endpoint described in the module
# docstring and advertised by root().
@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute Orb-v3 embedding for a CIF structure.

    Args:
        req: Request body carrying the CIF text in ``req.cif``.

    Returns:
        EmbedResponse with the L2-normalized embedding under
        ``vectors["orb"]`` and ``ORB_DIM`` under ``dims["orb"]``.

    Raises:
        HTTPException: 503 while the model is not loaded; 400 when the CIF
            text is empty or cannot be parsed; 500 when featurization fails.
    """
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, please retry")

    # Bad input is the client's fault: report it as 400 rather than 500.
    if not req.cif.strip():
        raise HTTPException(status_code=400, detail="CIF content is empty")
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        # The featurizer is called with a pandas Series of structures.
        s_series = pd.Series([struct])

        print("Computing Orb-v3 features...")
        # NOTE(review): this call runs synchronously inside the coroutine, so
        # the event loop is blocked for its duration - confirm acceptable.
        vec_orb = orb_model.get_features(s_series).values[0]
        # Replace NaNs before normalizing so they cannot poison the norm.
        vec_orb = np.nan_to_num(vec_orb, nan=0.0)
        # sklearn's normalize defaults to the L2 norm and works row-wise.
        vec_orb = normalize([vec_orb])[0]

        print(f"Embedding complete: orb={len(vec_orb)}")

        return EmbedResponse(
            vectors={"orb": vec_orb.tolist()},
            dims={"orb": ORB_DIM},
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {str(e)}") from e
# Route registration for the service root.
@app.get("/")
async def root():
    """Root endpoint with API info.

    Returns:
        A dict naming the service, its version and description, the
        available endpoints, and the embedding dimension.
    """
    return {
        "service": "Crystal-Embedder-Orb",
        "version": "2.0.0",
        "description": "Orb-v3 force field embeddings for crystal structures",
        "endpoints": {
            "/embed": "POST - Compute Orb-v3 embedding from CIF",
            "/health": "GET - Health check",
        },
        "vector_dimensions": {"orb": ORB_DIM},
    }