Spaces:
Sleeping
Sleeping
| import os | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import BaseModel | |
| from typing import Dict | |
| from src.tests.config import API_TITLE, API_VERSION, API_DESCRIPTION | |
| from src.tests.inference import predictor | |
# ASGI application object; its metadata is taken from the project config
# constants imported above.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
)

# Template loader rooted at "." -- a relative path, so presumably resolved
# against the process working directory (confirm for the deployment layout).
templates = Jinja2Templates(directory=".")
class PredictionRequest(BaseModel):
    """Request body for the prediction handlers."""

    # Feature values keyed by feature name.
    features: Dict[str, float]
class PredictionResponse(BaseModel):
    """Response body carrying a single predicted label."""

    # Label produced by the predictor.
    prediction: str
    # Count of feature entries associated with this prediction.
    features_used: int
class ProbabilityResponse(BaseModel):
    """Response body carrying a probability per outcome."""

    # Probability values keyed by a string identifier (presumably the class
    # label -- confirm against the predictor implementation).
    probabilities: Dict[str, float]
    # Count of feature entries associated with this prediction.
    features_used: int
# NOTE(review): the handler was not registered on ``app`` in the reviewed
# source; the "/" path is assumed for the landing page -- confirm.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` with the model's feature lists.

    Args:
        request: Incoming request, passed through to the template context.

    Returns:
        The template response for ``index.html``. Its context exposes
        ``features`` with two entries: ``all_features`` and
        ``top_10_features``.

    Raises:
        KeyError: If the loaded feature data has no ``top_10_features``
            entry (assuming ``predictor.features`` is a dict).
    """
    # Load before reading anything from the predictor: the original code asked
    # for the feature names first and only then ensured the data was loaded.
    predictor.load_model()

    all_features = predictor.get_feature_names()
    top_10_features = predictor.features['top_10_features']

    context = {
        "request": request,
        "features": {
            "all_features": all_features,
            "top_10_features": top_10_features,
        },
    }
    return templates.TemplateResponse("index.html", context)
# NOTE(review): the handler was not registered on ``app`` in the reviewed
# source; the "/health" path is inferred from the function name -- confirm.
@app.get("/health")
async def health():
    """Report service status and whether the model is available.

    Returns:
        A dict with a constant ``status`` of ``"healthy"``, the predictor's
        ``_model_loaded`` flag as ``model_loaded``, and ``model_ready`` set to
        whether ``predictor.model`` is not ``None``.
    """
    model_ready = predictor.model is not None
    return {
        "status": "healthy",
        # Reads a private attribute of the predictor; kept because no public
        # accessor is visible from this module.
        "model_loaded": predictor._model_loaded,
        "model_ready": model_ready,
    }
# NOTE(review): the handler was not registered on ``app`` in the reviewed
# source; the "/features" path is inferred from the function name -- confirm.
@app.get("/features")
async def get_features():
    """Return the feature names reported by the predictor.

    Returns:
        A dict with a single ``features`` key holding the result of
        ``predictor.get_feature_names()``.
    """
    feature_names = predictor.get_feature_names()
    return {"features": feature_names}
# NOTE(review): the handler was not registered on ``app`` in the reviewed
# source; the "/predict" path is inferred from the function name -- confirm.
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict a label from the submitted feature values.

    Args:
        request: Body holding feature values keyed by feature name.

    Returns:
        A ``PredictionResponse`` with the predictor's output and the number
        of feature entries supplied in the request.

    Raises:
        HTTPException: 400 when any feature name reported by the predictor
            is absent from the request.
    """
    # Reject the request if it lacks any feature the predictor reports.
    required = set(predictor.get_feature_names())
    missing = required.difference(request.features)
    if missing:
        # Sort so the error text is stable; a raw set renders in arbitrary order.
        raise HTTPException(
            status_code=400,
            detail=f"Missing features: {sorted(missing)}",
        )

    prediction = predictor.predict(request.features)
    return PredictionResponse(
        prediction=prediction,
        features_used=len(request.features),
    )
# NOTE(review): the handler was not registered on ``app`` in the reviewed
# source; the "/predict_proba" path is inferred from the function name -- confirm.
@app.post("/predict_proba", response_model=ProbabilityResponse)
async def predict_proba(request: PredictionRequest):
    """Return outcome probabilities for the submitted feature values.

    Args:
        request: Body holding feature values keyed by feature name.

    Returns:
        A ``ProbabilityResponse`` with the predictor's probability mapping
        and the number of feature entries supplied in the request.

    Raises:
        HTTPException: 400 when any feature name reported by the predictor
            is absent from the request.
    """
    # Reject the request if it lacks any feature the predictor reports.
    required = set(predictor.get_feature_names())
    missing = required.difference(request.features)
    if missing:
        # Sort so the error text is stable; a raw set renders in arbitrary order.
        raise HTTPException(
            status_code=400,
            detail=f"Missing features: {sorted(missing)}",
        )

    probabilities = predictor.predict_proba(request.features)
    return ProbabilityResponse(
        probabilities=probabilities,
        features_used=len(request.features),
    )
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn.
    import uvicorn

    # PORT from the environment wins; otherwise fall back to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    # Bind on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)