# crystal-embedder / main.py
# Hafnium13's picture
# Initial deployment: Tri-Fusion Crystal Embedder
# 668ae63
"""
Crystal-Embedder Microservice
Computes Tri-Fusion embeddings (Orb-v3 + l-MM + l-OFM) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only, 16GB RAM).
API:
    POST /embed
        Request:  {"cif": "<CIF content>"}
        Response: {"vector": [...], "dims": 2738}
"""
import os
# CRITICAL: Set TensorFlow to use legacy Keras before any imports
# Required for MEGNet compatibility with TensorFlow 2.16+
os.environ["TF_USE_LEGACY_KERAS"] = "1"
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from pymatgen.core import Structure
# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI(
    title="Crystal-Embedder",
    description="Tri-Fusion physics embeddings for crystal structures",
    version="1.0.0",
)
# Featurizer handles start out empty; load_models() rebinds them at startup.
orb_model = mm_model = ofm_model = None
# Readiness flag: load_models() sets it to True only after all three
# featurizers have been constructed.
models_loaded = False
class CifRequest(BaseModel):
    """Request body accepted by POST /embed."""

    # CIF text of the structure; the /embed handler parses it with pymatgen.
    cif: str
class EmbedResponse(BaseModel):
    """Response body returned by POST /embed."""

    # Concatenated embedding, flattened to a plain list of floats.
    vector: list[float]
    # Length of ``vector``.
    dims: int
class HealthResponse(BaseModel):
    """Response body returned by GET /health."""

    # "healthy" once the models are loaded, "loading" before that.
    status: str
    # Mirrors the module-level ``models_loaded`` flag.
    models_loaded: bool
    # Width of the embedding produced by /embed.
    vector_dims: int
def load_models():
    """Instantiate the three featurizers and publish them as module globals.

    Orb-v3 is constructed first, then l-MM, then l-OFM, with progress printed
    to stdout along the way. ``models_loaded`` is set to True only after all
    three constructors have returned.

    Raises:
        Exception: whatever the import or a constructor raised; the error is
            printed and then re-raised unchanged.
    """
    global orb_model, mm_model, ofm_model, models_loaded

    print("Loading Physics Engines (Orb-v3 + l-MM + l-OFM)...")
    print("This may take a few minutes on first load...")
    try:
        # Deferred import: the MatterVial featurizers are only pulled in at
        # the moment the models are actually loaded.
        from mattervial.featurizers import DescriptorMEGNetFeaturizer, ORBFeaturizer

        # Orb-v3 featurizer, pinned to the CPU.
        print(" Loading Orb-v3 (1792-dim)...")
        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")

        # l-MM descriptor featurizer.
        print(" Loading l-MM (758-dim)...")
        mm_model = DescriptorMEGNetFeaturizer(base_descriptor="l-MM_v1")

        # l-OFM descriptor featurizer.
        print(" Loading l-OFM (188-dim)...")
        ofm_model = DescriptorMEGNetFeaturizer(base_descriptor="l-OFM_v1")

        models_loaded = True
        print("All models loaded successfully!")
    except Exception as exc:
        print(f"Error loading models: {exc}")
        raise
# NOTE: recent FastAPI releases deprecate on_event() in favour of lifespan
# handlers; the decorator is kept here so the registration stays unchanged.
@app.on_event("startup")
async def startup_event():
    """Startup hook: run load_models() when the application starts."""
    load_models()
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model readiness together with the embedding width."""
    if models_loaded:
        status = "healthy"
    else:
        status = "loading"
    return HealthResponse(
        status=status,
        models_loaded=models_loaded,
        vector_dims=2738,
    )
def _normalized_features(featurizer, structures):
    """Return the first feature row of *featurizer*, NaN-cleaned and L2-normalized."""
    raw = featurizer.get_features(structures).values[0]
    # NaNs are zeroed first so they cannot poison the norm.
    cleaned = np.nan_to_num(raw, nan=0.0)
    return normalize([cleaned])[0]


@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute Tri-Fusion embedding for a CIF structure.

    The embedding is a concatenation of three L2-normalized vectors:
    - Orb-v3: 1792-dim (force field features)
    - l-MM: 758-dim (electronic structure features)
    - l-OFM: 188-dim (orbital field matrix features)
    Total: 2738 dimensions

    Args:
        req: Request body carrying the CIF text in ``req.cif``.

    Returns:
        EmbedResponse holding the concatenated vector and its length.

    Raises:
        HTTPException: 503 while the models are not loaded, 400 when the CIF
            text cannot be parsed into a structure, 500 when featurization
            itself fails.
    """
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Models still loading, please retry")

    # 1. Parse CIF to pymatgen Structure. A parse failure is caused by the
    #    submitted payload, so it is reported as a client error rather than
    #    being folded into the generic 500 below.
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {exc}") from exc

    # NOTE(review): the featurizer calls below are synchronous, so this
    # coroutine occupies the event loop until they finish.
    try:
        s_series = pd.Series([struct])

        # 2. Compute features one model at a time; each vector is NaN-cleaned
        #    and L2-normalized on its own so no model dominates by magnitude.
        print("Computing Orb-v3 features...")
        v1 = _normalized_features(orb_model, s_series)
        print("Computing l-MM features...")
        v2 = _normalized_features(mm_model, s_series)
        print("Computing l-OFM features...")
        v3 = _normalized_features(ofm_model, s_series)

        # 3. Concatenate in the documented order: Orb-v3, l-MM, l-OFM.
        final_vector = np.concatenate([v1, v2, v3])
        print(f"Embedding complete: {len(final_vector)} dimensions")
        return EmbedResponse(
            vector=final_vector.tolist(),
            dims=len(final_vector),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {exc}") from exc
@app.get("/")
async def root():
    """Describe the service: name, version, endpoints and per-model vector sizes."""
    endpoint_help = {
        "/embed": "POST - Compute embedding from CIF",
        "/health": "GET - Health check",
    }
    dimension_table = {
        "orb_v3": 1792,
        "l_mm": 758,
        "l_ofm": 188,
        "total": 2738,
    }
    return {
        "service": "Crystal-Embedder",
        "version": "1.0.0",
        "description": "Tri-Fusion physics embeddings for crystal structures",
        "endpoints": endpoint_help,
        "vector_dimensions": dimension_table,
    }