Hafnium49's picture
New: ORB-v3 named vector endpoint
55ff391 verified
"""
Crystal-Embedder-Orb Microservice
Computes Orb-v3 force field embeddings (1792-dim) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only).
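API:
    POST /embed
        Request:  {"cif": "<CIF content>"}
        Response: {
            "vectors": {"orb": [...]},   (L2-normalized embedding)
            "dims": {"orb": 1792}
        }
    GET /health
        Response: {"status": "healthy" | "loading", "models_loaded": bool,
                   "vector_dims": {"orb": 1792}}
    GET /
        Service metadata and a listing of the endpoints above.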
"""
import asyncio
import os
import threading

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pymatgen.core import Structure
from sklearn.preprocessing import normalize
app = FastAPI(
    title="Crystal-Embedder-Orb",
    description="Orb-v3 force field embeddings for crystal structures",
    version="2.0.0",
)

# Length of the Orb-v3 embedding reported by /health, /embed and /.
ORB_DIM: int = 1792

# Shared featurizer state; written once by load_models() during startup.
orb_model = None  # ORBFeaturizer instance after a successful load
models_loaded: bool = False  # readiness flag checked by /health and /embed
# Request body for POST /embed: the raw text of a CIF file.
class CifRequest(BaseModel):
    cif: str
# Response body for POST /embed; both dicts are keyed by vector name ("orb").
class EmbedResponse(BaseModel):
    vectors: dict[str, list[float]]  # vector name -> embedding values
    dims: dict[str, int]  # vector name -> embedding length
# Response body for GET /health.
class HealthResponse(BaseModel):
    status: str  # "healthy" once the model is loaded, otherwise "loading"
    models_loaded: bool
    vector_dims: dict[str, int]  # vector name -> embedding length
def load_models():
    """Instantiate the Orb-v3 featurizer on CPU and mark the service ready.

    Sets the module globals ``orb_model`` and ``models_loaded``. Any failure
    is printed and then re-raised, so the caller sees the original exception.
    """
    global orb_model, models_loaded
    print("Loading Orb-v3 (1792-dim)...")
    try:
        # Imported lazily so the heavy dependency is only pulled in at startup.
        from mattervial.featurizers import ORBFeaturizer

        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")
    except Exception as exc:
        print(f"Error loading Orb-v3: {exc}")
        raise
    else:
        models_loaded = True
        print("Orb-v3 loaded successfully!")
@app.on_event("startup")
async def startup_event():
    """Populate the global Orb-v3 featurizer as the application starts.

    The load runs synchronously; any exception raised by load_models()
    propagates out of this hook unchanged.
    """
    # NOTE(review): `on_event` is deprecated in recent FastAPI in favour of a
    # lifespan handler passed to FastAPI(); migrating requires an app-level change.
    load_models()
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report readiness: "healthy" once Orb-v3 has loaded, "loading" before that."""
    ready = models_loaded
    status = "loading"
    if ready:
        status = "healthy"
    dims = {"orb": ORB_DIM}
    return HealthResponse(status=status, models_loaded=ready, vector_dims=dims)
# Serializes featurizer calls. Inference now runs on worker threads so the
# event loop stays responsive, but the shared orb_model is not known to be
# thread-safe. Keep one call at a time, as when inference ran on the loop.
_orb_lock = threading.Lock()


def _compute_orb_vector(struct: Structure) -> np.ndarray:
    """Featurize one structure with Orb-v3 and return an L2-normalized vector.

    Blocking and CPU-bound; intended to be run off the event loop.

    Raises:
        ValueError: if the featurizer does not yield an ORB_DIM-length vector.
    """
    with _orb_lock:
        # assumes get_features returns a DataFrame-like object with one row
        # per input structure (only .values[0] is used) — TODO confirm
        features = orb_model.get_features(pd.Series([struct]))
    vec = np.asarray(features.values[0], dtype=float)
    # The response advertises ORB_DIM; refuse to return anything else.
    if vec.shape != (ORB_DIM,):
        raise ValueError(
            f"expected a {ORB_DIM}-dim Orb-v3 vector, got shape {vec.shape}"
        )
    # Missing (NaN) features become 0 before normalization.
    vec = np.nan_to_num(vec, nan=0.0)
    return normalize([vec])[0]


@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute Orb-v3 embedding for a CIF structure.

    Returns an L2-normalized 1792-dim force field embedding as a named vector.

    Raises:
        HTTPException(503): the featurizer has not finished loading.
        HTTPException(400): the CIF text is empty or cannot be parsed.
        HTTPException(500): featurization failed or produced a vector of
            unexpected length.
    """
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, please retry")
    if not req.cif.strip():
        raise HTTPException(status_code=400, detail="Invalid CIF: empty input")

    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        # Malformed input is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        print("Computing Orb-v3 features...")
        # Run the blocking model call in a worker thread so /health and other
        # requests are still served while inference is in progress.
        vec_orb = await asyncio.to_thread(_compute_orb_vector, struct)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {e}") from e

    print(f"Embedding complete: orb={len(vec_orb)}")
    return EmbedResponse(
        vectors={"orb": vec_orb.tolist()},
        dims={"orb": ORB_DIM},
    )
@app.get("/")
async def root():
    """Describe the service: name, version, endpoints and vector dimensions.

    Name, version and description are read from the FastAPI app instance,
    so this payload cannot drift from the OpenAPI metadata set at construction.
    """
    return {
        "service": app.title,
        "version": app.version,
        "description": app.description,
        "endpoints": {
            "/embed": "POST - Compute Orb-v3 embedding from CIF",
            "/health": "GET - Health check",
        },
        "vector_dimensions": {"orb": ORB_DIM},
    }