Spaces:
Sleeping
Sleeping
File size: 4,885 Bytes
668ae63 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
Crystal-Embedder Microservice
Computes Tri-Fusion embeddings (Orb-v3 + l-MM + l-OFM) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only, 16GB RAM).
API:
POST /embed
Request: {"cif": "<CIF content>"}
Response: {"vector": [...], "dims": 2738}
"""
import os
# CRITICAL: Set TensorFlow to use legacy Keras before any imports
# Required for MEGNet compatibility with TensorFlow 2.16+
os.environ["TF_USE_LEGACY_KERAS"] = "1"
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from pymatgen.core import Structure
# FastAPI application object; the route and startup decorators in this
# module register their handlers on it.
app = FastAPI(
    title="Crystal-Embedder",
    description="Tri-Fusion physics embeddings for crystal structures",
    version="1.0.0",
)

# Module-level featurizer handles. All three start as None and are assigned
# by load_models(); models_loaded flips to True only after every one of them
# has been constructed.
orb_model = mm_model = ofm_model = None
models_loaded = False
class CifRequest(BaseModel):
    # Request body accepted by POST /embed.
    # cif: the CIF document as text; the handler parses it with
    # Structure.from_str(..., fmt="cif").
    cif: str
class EmbedResponse(BaseModel):
    # Response body returned by POST /embed.
    # vector: the concatenated embedding values.
    vector: list[float]
    # dims: number of entries in `vector`.
    dims: int
class HealthResponse(BaseModel):
    # Response body returned by GET /health.
    # status: "healthy" once the models are loaded, "loading" before that.
    status: str
    # models_loaded: mirrors the module-level models_loaded flag.
    models_loaded: bool
    # vector_dims: length of the embedding the service produces.
    vector_dims: int
def load_models():
    """Construct the three featurizers and publish them as module globals.

    Assigns ``orb_model``, ``mm_model`` and ``ofm_model`` in that order and
    sets ``models_loaded`` to True once all three exist. Progress is reported
    with ``print``.

    Raises:
        Exception: whatever the import or any featurizer constructor raises;
            the error is printed and then re-raised unchanged.
    """
    global orb_model, mm_model, ofm_model, models_loaded

    print("Loading Physics Engines (Orb-v3 + l-MM + l-OFM)...")
    print("This may take a few minutes on first load...")

    try:
        # Imported here rather than at module top, so the MatterVial import
        # only happens when the models are actually loaded.
        from mattervial.featurizers import (
            DescriptorMEGNetFeaturizer,
            ORBFeaturizer,
        )

        # Orb-v3 featurizer, pinned to CPU.
        print(" Loading Orb-v3 (1792-dim)...")
        orb_model = ORBFeaturizer(model_name="ORB_v3", device="cpu")

        # l-MM descriptor featurizer.
        print(" Loading l-MM (758-dim)...")
        mm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-MM_v1')

        # l-OFM descriptor featurizer.
        print(" Loading l-OFM (188-dim)...")
        ofm_model = DescriptorMEGNetFeaturizer(base_descriptor='l-OFM_v1')

        # Only flag readiness after every featurizer was built.
        models_loaded = True
        print("All models loaded successfully!")
    except Exception as err:
        print(f"Error loading models: {err}")
        raise
@app.on_event("startup")
async def startup_event():
    """Startup hook: call load_models() so the featurizers are ready."""
    load_models()
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the models are loaded.

    Returns:
        HealthResponse with status "healthy" when ``models_loaded`` is true
        and "loading" otherwise, the flag itself, and the fixed embedding
        length of 2738.
    """
    if models_loaded:
        state = "healthy"
    else:
        state = "loading"

    return HealthResponse(
        status=state,
        models_loaded=models_loaded,
        vector_dims=2738,
    )
@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute Tri-Fusion embedding for a CIF structure.

    The embedding is a concatenation of three L2-normalized vectors:
    - Orb-v3: 1792-dim (force field features)
    - l-MM: 758-dim (electronic structure features)
    - l-OFM: 188-dim (orbital field matrix features)
    Total: 2738 dimensions

    Args:
        req: request body whose ``cif`` field holds the CIF text.

    Returns:
        EmbedResponse with the concatenated vector and its length.

    Raises:
        HTTPException: 503 while the models are not loaded yet, 400 when the
            CIF text cannot be parsed into a structure, 500 when
            featurization or post-processing fails.
    """
    if not models_loaded:
        raise HTTPException(
            status_code=503,
            detail="Models still loading, please retry",
        )

    # 1. Parse CIF to pymatgen Structure. Parsing gets its own try block so
    #    that malformed input is reported as a client error (400) instead of
    #    being folded into the generic 500 below.
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        s_series = pd.Series([struct])

        # 2-4. For each featurizer in turn (sequential on CPU): compute the
        #      features, scrub non-finite values, then L2-normalize so no
        #      single model dominates the concatenated vector by magnitude.
        # NOTE(review): these calls run synchronously inside an async
        # handler, so other requests wait while an embedding is computed.
        featurizers = (
            ("Orb-v3", orb_model),
            ("l-MM", mm_model),
            ("l-OFM", ofm_model),
        )
        parts = []
        for label, model in featurizers:
            print(f"Computing {label} features...")
            raw = model.get_features(s_series).values[0]
            # NaN and +/-inf all become 0. Leaving inf to nan_to_num's
            # default (largest finite float) makes the L2 norm overflow,
            # which would scale the whole sub-vector down to zeros.
            cleaned = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)
            parts.append(normalize([cleaned])[0])

        # 5. Concatenate
        final_vector = np.concatenate(parts)
        print(f"Embedding complete: {len(final_vector)} dimensions")
        return EmbedResponse(
            vector=final_vector.tolist(),
            dims=len(final_vector),
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Embedding failed: {e}",
        ) from e
@app.get("/")
async def root():
    """Describe the service: name, version, endpoints and vector sizes."""
    endpoints = {
        "/embed": "POST - Compute embedding from CIF",
        "/health": "GET - Health check",
    }
    vector_dimensions = {
        "orb_v3": 1792,
        "l_mm": 758,
        "l_ofm": 188,
        "total": 2738,
    }
    return {
        "service": "Crystal-Embedder",
        "version": "1.0.0",
        "description": "Tri-Fusion physics embeddings for crystal structures",
        "endpoints": endpoints,
        "vector_dimensions": vector_dimensions,
    }
|