File size: 3,723 Bytes
e620469
 
 
 
 
 
 
 
 
 
 
 
8b5c1f3
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import shutil
from pathlib import Path
from typing import Callable, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field

# Decision threshold forwarded to the predictor as its ``threshold=`` argument.
# NOTE(review): presumably the cut-off for the is_gap_direct classifier; confirm in predict_all.
THRESHOLD = 0.33

# Bound by the startup hook once the predictor module has been imported; None until then.
predict_all_fn: Optional[Callable[..., Dict[str, object]]] = None

# Directory containing this file; local checkpoint paths are resolved against it.
PROJECT_DIR = Path(__file__).resolve().parent

# Hugging Face Hub repository hosting the checkpoints, plus an optional access
# token taken from the environment (None when HF_TOKEN is not set).
HF_REPO_ID = "asif45/LLM-PROP"
HF_TOKEN = os.getenv("HF_TOKEN")

# Every checkpoint needed at startup. Each file sits at the same relative path
# locally (under PROJECT_DIR) and inside the Hub repo, so the mapping
# "local relative path -> repo file path" is the identity.
_CHECKPOINT_PATHS = (
    "checkpoints/samples/classification/best_checkpoint_for_is_gap_direct.pt",
    "checkpoints/samples/regression/best_checkpoint_for_band_gap.pt",
    "checkpoints/samples/regression/best_checkpoint_for_energy_per_atom.pt",
    "checkpoints/samples/regression/best_checkpoint_for_e_above_hull.pt",
    "checkpoints/samples/regression/best_checkpoint_for_fepa.pt",
    "checkpoints/samples/regression/best_checkpoint_for_volume.pt",
)
CHECKPOINT_FILES: Dict[str, str] = {path: path for path in _CHECKPOINT_PATHS}


class PredictRequest(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # Free-text crystal description; the field is required (no default).
    text: str = Field(default=..., description="Crystal description text")


class PredictResponse(BaseModel):
    """JSON body returned by ``POST /predict``: one entry per predicted property."""

    # Output of the classification checkpoint, serialized as a string.
    is_gap_direct: str
    # Outputs of the regression checkpoints.
    # NOTE(review): formation_energy_per_atom presumably comes from the "fepa" checkpoint; units not stated here.
    energy_per_atom: float
    formation_energy_per_atom: float
    band_gap: float
    e_above_hull: float
    volume: float


# Browser origins permitted to call the API cross-origin: local hosts on
# ports 8080 and 5173, addressed either as localhost or 127.0.0.1.
_ALLOWED_ORIGINS = [
    "http://localhost:8080",
    "http://127.0.0.1:8080",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

app = FastAPI(title="Crystal Property Predictor API")

# Credentials are allowed, which is why origins are listed explicitly rather
# than wildcarded; methods and headers are unrestricted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def ensure_checkpoint_files() -> None:
    """Make sure every checkpoint listed in ``CHECKPOINT_FILES`` exists locally.

    For each entry, the local file is ``PROJECT_DIR / local_relative_path``.
    Files that already exist are left untouched. Missing ones are fetched from
    the Hugging Face Hub model repo ``HF_REPO_ID`` with ``hf_hub_download``
    (passing ``HF_TOKEN`` as the token) and copied to the local path.

    The copy is first written to a sibling ``<name>.part`` file and then moved
    into place with ``os.replace``. As a result, an interrupted copy never
    leaves a truncated checkpoint at the final path. Such a file would
    otherwise pass the existence check on the next start and never be
    re-downloaded.

    Raises:
        Any error from ``hf_hub_download`` or from the filesystem operations
        propagates unchanged. A partially written ``.part`` file is removed
        first.
    """
    for local_relative_path, repo_file_path in CHECKPOINT_FILES.items():
        local_path = PROJECT_DIR / local_relative_path
        if local_path.exists():
            continue

        local_path.parent.mkdir(parents=True, exist_ok=True)
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=repo_file_path,
            repo_type="model",
            token=HF_TOKEN,
        )

        # Stage next to the destination so os.replace stays on one filesystem
        # (and is therefore an atomic rename).
        partial_path = local_path.with_name(local_path.name + ".part")
        try:
            shutil.copy2(downloaded_path, partial_path)
            os.replace(partial_path, local_path)
        except BaseException:
            # Don't leave a half-written staging file behind, even on Ctrl-C.
            partial_path.unlink(missing_ok=True)
            raise


@app.on_event("startup")
def load_model_once() -> None:
    """Bind the predictor when the application starts.

    The function first ensures the checkpoint files are on disk, then imports
    ``predict_all``. Presumably that module loads the checkpoints at import
    time, so the import must come after the download step. The imported
    callable is stored in the module-level ``predict_all_fn``.
    """
    # NOTE(review): FastAPI deprecates on_event in favor of lifespan handlers.
    global predict_all_fn

    ensure_checkpoint_files()

    # Deferred on purpose: importing at module top would run before the download.
    from predict_all import predict_all as _loaded_predictor

    predict_all_fn = _loaded_predictor


@app.get("/health")
def health() -> Dict[str, object]:
    """Report liveness and whether the predictor has been bound yet."""
    model_ready = predict_all_fn is not None
    return dict(status="ok", model_loaded=model_ready)


@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run all property predictions for the submitted crystal description.

    Returns 503 until the startup hook has bound the predictor. Returns 400
    when the text is empty or whitespace-only.
    """
    if predict_all_fn is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    description = payload.text.strip()
    if not description:
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    raw = predict_all_fn(description, threshold=THRESHOLD)

    # Copy only the properties the response exposes: extra keys from the
    # predictor are dropped, and a missing key raises KeyError.
    response_keys = (
        "is_gap_direct",
        "energy_per_atom",
        "formation_energy_per_atom",
        "band_gap",
        "e_above_hull",
        "volume",
    )
    return PredictResponse(**{key: raw[key] for key in response_keys})