File size: 4,885 Bytes
668ae63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Crystal-Embedder Microservice

Computes Tri-Fusion embeddings (Orb-v3 + l-MM + l-OFM) for crystal structures.
Deployed on Hugging Face Spaces (CPU-only, 16GB RAM).

API:
    POST /embed
    Request: {"cif": "<CIF content>"}
    Response: {"vector": [...], "dims": 2738}
"""

import os

# CRITICAL: Set TensorFlow to use legacy Keras before any imports.
# The flag is written to the process environment here, above the remaining
# import statements, so it is already in place when they run.
# Required for MEGNet compatibility with TensorFlow 2.16+
os.environ["TF_USE_LEGACY_KERAS"] = "1"

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pandas as pd
from sklearn.preprocessing import normalize
from pymatgen.core import Structure

# ASGI application object; the route and startup decorators in this module
# register their handlers on it.
app = FastAPI(
    version="1.0.0",
    title="Crystal-Embedder",
    description="Tri-Fusion physics embeddings for crystal structures",
)

# Module-level featurizer handles. All three stay None until load_models()
# assigns them.
orb_model = mm_model = ofm_model = None
# Set to True by load_models() only after every featurizer was constructed.
models_loaded = False


# Request body accepted by POST /embed.
# (Documented with comments rather than a class docstring so the generated
# schema stays exactly as it is.)
class CifRequest(BaseModel):
    # Raw CIF text; the /embed handler parses it with
    # Structure.from_str(..., fmt="cif").
    cif: str


# Response body returned by POST /embed.
class EmbedResponse(BaseModel):
    # Concatenation of the three per-model normalized feature vectors,
    # converted to a plain list of floats.
    vector: list[float]
    # Length of `vector` (the handler fills it with len(final_vector)).
    dims: int


# Response body returned by GET /health.
class HealthResponse(BaseModel):
    # "healthy" once the models are loaded, otherwise "loading".
    status: str
    # Mirrors the module-level `models_loaded` flag.
    models_loaded: bool
    # Embedding length reported by the service; the handler sends the
    # constant 2738.
    vector_dims: int


def load_models():
    """Build the Orb-v3, l-MM and l-OFM featurizers and store them in module globals.

    The featurizers are created one after another and assigned to
    ``orb_model``, ``mm_model`` and ``ofm_model``. ``models_loaded`` becomes
    True only after all three constructors returned.

    Raises:
        Exception: whatever the import or a featurizer constructor raised;
            the error is printed first and then re-raised unchanged.
    """
    global orb_model, mm_model, ofm_model, models_loaded

    print("Loading Physics Engines (Orb-v3 + l-MM + l-OFM)...")
    print("This may take a few minutes on first load...")

    try:
        # The MatterVial featurizer classes are imported lazily, inside the
        # guarded section, so an import failure is reported like any other
        # loading error.
        from mattervial.featurizers import ORBFeaturizer
        from mattervial.featurizers import DescriptorMEGNetFeaturizer

        # Orb-v3 featurizer, pinned to the CPU device.
        print("  Loading Orb-v3 (1792-dim)...")
        orb_model = ORBFeaturizer(device="cpu", model_name="ORB_v3")

        # l-MM descriptor featurizer.
        print("  Loading l-MM (758-dim)...")
        mm_model = DescriptorMEGNetFeaturizer(base_descriptor="l-MM_v1")

        # l-OFM descriptor featurizer.
        print("  Loading l-OFM (188-dim)...")
        ofm_model = DescriptorMEGNetFeaturizer(base_descriptor="l-OFM_v1")

        # Only reached when none of the steps above raised.
        models_loaded = True
        print("All models loaded successfully!")
    except Exception as exc:
        print(f"Error loading models: {exc}")
        raise


@app.on_event("startup")
async def startup_event():
    """Startup hook: run load_models() so the featurizers exist before requests are handled."""
    # load_models() is a plain synchronous call; an exception it raises
    # propagates out of this hook.
    load_models()


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    # The reported state depends only on the module-level models_loaded flag.
    if models_loaded:
        state = "healthy"
    else:
        state = "loading"

    return HealthResponse(
        status=state,
        models_loaded=models_loaded,
        vector_dims=2738,
    )


@app.post("/embed", response_model=EmbedResponse)
async def embed_structure(req: CifRequest):
    """
    Compute Tri-Fusion embedding for a CIF structure.

    The embedding is a concatenation of three L2-normalized vectors:
    - Orb-v3: 1792-dim (force field features)
    - l-MM: 758-dim (electronic structure features)
    - l-OFM: 188-dim (orbital field matrix features)

    Total: 2738 dimensions
    """
    # Maintainer notes are kept in comments because FastAPI publishes the
    # docstring above as the operation description.
    #
    # Args:
    #   req: request body whose `cif` field holds the CIF text.
    # Returns:
    #   EmbedResponse with the concatenated vector and its length.
    # Raises (as HTTPException):
    #   503 - the featurizers have not finished loading
    #   400 - the CIF text could not be parsed into a structure
    #   500 - a featurizer or the post-processing failed
    #
    # NOTE(review): the featurizer calls below are synchronous and run
    # directly inside this coroutine, so the event loop is busy while an
    # embedding is computed. Confirm that is acceptable before adding
    # concurrent traffic.
    if not models_loaded:
        raise HTTPException(status_code=503, detail="Models still loading, please retry")

    # 1. Parse CIF to pymatgen Structure. A failure here is caused by the
    #    request payload, so it is reported as a client error, not a 500.
    try:
        struct = Structure.from_str(req.cif, fmt="cif")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid CIF: {e}") from e

    try:
        s_series = pd.Series([struct])

        # 2-4. Featurize sequentially, sanitize, and L2-normalize each part
        #      (normalizing per model prevents magnitude dominance).
        featurizers = (
            ("Orb-v3", orb_model),
            ("l-MM", mm_model),
            ("l-OFM", ofm_model),
        )
        parts = []
        for label, model in featurizers:
            print(f"Computing {label} features...")
            raw = model.get_features(s_series).values[0]
            # Zero out NaN and +/-inf. Without posinf/neginf, nan_to_num maps
            # infinities to the largest finite float, which would swamp the
            # L2 norm of the whole sub-vector.
            clean = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)
            parts.append(normalize([clean])[0])

        # 5. Concatenate in the fixed order Orb-v3, l-MM, l-OFM.
        final_vector = np.concatenate(parts)

        print(f"Embedding complete: {len(final_vector)} dimensions")

        return EmbedResponse(
            vector=final_vector.tolist(),
            dims=len(final_vector)
        )

    except Exception as e:
        # Keep the original cause attached for server-side tracebacks.
        raise HTTPException(status_code=500, detail=f"Embedding failed: {str(e)}") from e


@app.get("/")
async def root():
    """Root endpoint with API info."""
    # Static description of the service; nothing here depends on model state.
    endpoint_help = {
        "/embed": "POST - Compute embedding from CIF",
        "/health": "GET - Health check",
    }
    dimension_table = {
        "orb_v3": 1792,
        "l_mm": 758,
        "l_ofm": 188,
        "total": 2738,
    }
    return {
        "service": "Crystal-Embedder",
        "version": "1.0.0",
        "description": "Tri-Fusion physics embeddings for crystal structures",
        "endpoints": endpoint_help,
        "vector_dimensions": dimension_table,
    }