Spaces:
Sleeping
Sleeping
File size: 3,723 Bytes
e620469 8b5c1f3 e620469 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import os
import shutil
from pathlib import Path
from typing import Callable, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field
# Directory containing this module; checkpoint paths are resolved against it.
PROJECT_DIR = Path(__file__).resolve().parent

# Hugging Face Hub repository holding the trained checkpoints. HF_TOKEN is
# read from the environment and is None when the variable is unset.
HF_REPO_ID = "asif45/LLM-PROP"
HF_TOKEN = os.getenv("HF_TOKEN")

# Value passed to the predictor as ``threshold=`` on every /predict call
# (presumably the cutoff for the is_gap_direct classifier — confirm in predict_all).
THRESHOLD = 0.33

# Predictor entry point; None until the startup hook binds it.
predict_all_fn: Optional[Callable[..., Dict[str, object]]] = None

# Local path (relative to PROJECT_DIR) -> file path inside HF_REPO_ID.
# Each checkpoint lives under the same relative path in both places, so the
# mapping is built from one list of paths.
CHECKPOINT_FILES = {
    checkpoint: checkpoint
    for checkpoint in (
        "checkpoints/samples/classification/best_checkpoint_for_is_gap_direct.pt",
        "checkpoints/samples/regression/best_checkpoint_for_band_gap.pt",
        "checkpoints/samples/regression/best_checkpoint_for_energy_per_atom.pt",
        "checkpoints/samples/regression/best_checkpoint_for_e_above_hull.pt",
        "checkpoints/samples/regression/best_checkpoint_for_fepa.pt",
        "checkpoints/samples/regression/best_checkpoint_for_volume.pt",
    )
}
class PredictRequest(BaseModel):
    """Request body accepted by ``POST /predict``."""

    # Required (ellipsis default) free-text description of the crystal.
    text: str = Field(default=..., description="Crystal description text")
class PredictResponse(BaseModel):
    """Response body returned by ``POST /predict``.

    Each field holds one property copied out of the predictor's result dict.
    Fields are serialized in the order they are declared here.
    """

    # Output of the classification checkpoint; declared as ``str`` — TODO
    # confirm the exact values predict_all emits for this key.
    is_gap_direct: str
    # Outputs of the regression checkpoints. Units are whatever the models
    # were trained on — not established in this file.
    energy_per_atom: float
    formation_energy_per_atom: float
    band_gap: float
    e_above_hull: float
    volume: float
app = FastAPI(title="Crystal Property Predictor API")

# Cross-origin browser access is limited to the local development front-ends
# on ports 8080 and 5173, addressed both by hostname and by loopback IP.
# Credentials are allowed, and every method and header is accepted from them.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=[
        "http://localhost:8080",
        "http://127.0.0.1:8080",
        "http://localhost:5173",
        "http://127.0.0.1:5173",
    ],
)
def ensure_checkpoint_files() -> None:
    """Make sure every checkpoint listed in ``CHECKPOINT_FILES`` exists locally.

    For each ``local_relative_path -> repo_file_path`` entry, a path that
    already exists at ``PROJECT_DIR / local_relative_path`` is left alone.
    Otherwise the file is fetched from the Hugging Face Hub repo ``HF_REPO_ID``
    (passing ``HF_TOKEN``, which may be None) and copied into place.

    The copy is first written to a ``<name>.part`` sibling and then moved over
    the final path with ``os.replace``. An interrupted or failed copy
    therefore never leaves a truncated checkpoint at the final path, which
    later runs would otherwise skip as if it were complete.

    Raises:
        Whatever ``hf_hub_download`` or the filesystem operations raise;
        nothing is swallowed.
    """
    for local_relative_path, repo_file_path in CHECKPOINT_FILES.items():
        local_path = PROJECT_DIR / local_relative_path
        if local_path.exists():
            continue
        local_path.parent.mkdir(parents=True, exist_ok=True)
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=repo_file_path,
            repo_type="model",
            token=HF_TOKEN,
        )
        partial_path = local_path.with_name(local_path.name + ".part")
        try:
            shutil.copy2(downloaded_path, partial_path)
            # Atomic within one filesystem: the final path either does not
            # exist or holds the complete file.
            os.replace(partial_path, local_path)
        finally:
            # After a successful replace this is a no-op. After a failure it
            # removes the partial copy so it cannot accumulate.
            partial_path.unlink(missing_ok=True)
@app.on_event("startup")
def load_model_once() -> None:
    """Fetch any missing checkpoints, then bind the predictor entry point.

    ``predict_all`` is imported lazily, only after the checkpoint files are on
    disk, so that it loads the local files once (the intent stated by the
    original author). The callable is published through the module-level
    ``predict_all_fn``, which the ``/health`` and ``/predict`` handlers read.
    """
    global predict_all_fn
    ensure_checkpoint_files()
    from predict_all import predict_all as loaded_predictor

    predict_all_fn = loaded_predictor
@app.get("/health")
def health() -> Dict[str, object]:
    """Report that the service is up and whether the predictor has been bound."""
    model_ready = predict_all_fn is not None
    return dict(status="ok", model_loaded=model_ready)
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run every property predictor on the submitted crystal description.

    Raises:
        HTTPException(503): the startup hook has not bound the predictor yet.
        HTTPException(400): the text is empty after stripping whitespace.
    """
    if predict_all_fn is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    cleaned = payload.text.strip()
    if not cleaned:
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    raw = predict_all_fn(cleaned, threshold=THRESHOLD)

    # Keep only the properties exposed by PredictResponse, in declaration order.
    response_keys = (
        "is_gap_direct",
        "energy_per_atom",
        "formation_energy_per_atom",
        "band_gap",
        "e_above_hull",
        "volume",
    )
    return PredictResponse(**{key: raw[key] for key in response_keys})
|