Spaces:
Sleeping
Sleeping
File size: 2,658 Bytes
5e82fdf e934d38 5e82fdf e48262c 5e82fdf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from typing import Dict
from src.tests.config import API_TITLE, API_VERSION, API_DESCRIPTION
from src.tests.inference import predictor
# Application object; title/version/description come from the project config module.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
)
# Templates (e.g. "index.html") are looked up relative to the process's
# current working directory, not relative to this file.
templates = Jinja2Templates(directory=".")
class PredictionRequest(BaseModel):
    """Request body: a mapping of feature name to numeric value."""

    # Every expected feature must be present; the endpoints reject missing ones.
    features: Dict[str, float]
class PredictionResponse(BaseModel):
    """Response body returned by the /predict endpoint."""

    # Predicted label as produced by the predictor.
    prediction: str
    # Number of feature entries that were submitted in the request.
    features_used: int
class ProbabilityResponse(BaseModel):
    """Response body returned by the /predict_proba endpoint."""

    # presumably class label -> probability; confirm against predictor.predict_proba
    probabilities: Dict[str, float]
    # Number of feature entries that were submitted in the request.
    features_used: int
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page with the model's full and top-10 feature lists."""
    # Load before touching any predictor state. Previously get_feature_names()
    # ran first, so on a cold predictor it could be reached before loading.
    predictor.load_model()
    feature_names = predictor.get_feature_names()
    # NOTE(review): the top-10 list is presumably populated from features.json
    # by load_model(); a missing key raises KeyError (HTTP 500) — confirm.
    top_10_features = predictor.features['top_10_features']
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "features": {
                "all_features": feature_names,
                "top_10_features": top_10_features,
            },
        },
    )
@app.get("/health")
async def health():
    """Report liveness plus whether the model has been loaded and is available."""
    # Reads the predictor's private load flag directly; there is no public accessor here.
    loaded_flag = predictor._model_loaded
    ready_flag = predictor.model is not None
    return dict(status="healthy", model_loaded=loaded_flag, model_ready=ready_flag)
@app.get("/features")
async def get_features():
    """List the feature names the predictor expects in a request."""
    names = predictor.get_feature_names()
    return {"features": names}
def _check_required_features(features: Dict[str, float]) -> None:
    """Raise HTTP 400 if *features* lacks any feature the predictor expects.

    Only absent expected names are reported; extra keys are not rejected here.
    """
    expected = set(predictor.get_feature_names())
    # set.difference accepts any iterable; iterating a dict yields its keys.
    missing = expected.difference(features)
    if missing:
        # Sorted so the message is deterministic (set iteration order is not).
        raise HTTPException(
            status_code=400, detail=f"Missing features: {sorted(missing)}"
        )


# NOTE(review): predictor.predict / predict_proba are called synchronously
# inside `async def` handlers, so a slow model blocks the event loop. Consider
# plain `def` endpoints (run in FastAPI's threadpool) if inference is slow.


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Return the predicted label for the submitted feature values."""
    _check_required_features(request.features)
    prediction = predictor.predict(request.features)
    return PredictionResponse(
        prediction=prediction,
        # NOTE(review): counts every submitted key, including any extras the
        # model may ignore — confirm this is the intended meaning.
        features_used=len(request.features),
    )


@app.post("/predict_proba", response_model=ProbabilityResponse)
async def predict_proba(request: PredictionRequest):
    """Return per-class probabilities for the submitted feature values."""
    _check_required_features(request.features)
    probabilities = predictor.predict_proba(request.features)
    return ProbabilityResponse(
        probabilities=probabilities,
        # Same counting rule as /predict: all submitted keys.
        features_used=len(request.features),
    )
if __name__ == "__main__":
    # Imported lazily so uvicorn is only required when run as a script.
    import uvicorn

    # The listen port may be supplied via the PORT env var; 7860 is the fallback.
    listen_port = int(os.environ.get("PORT", 7860))
    # Bind all interfaces so the server is reachable from outside the host/container.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|