iremrit commited on
Commit
5e82fdf
·
verified ·
1 Parent(s): 23349d3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, HTTPException, Request
3
+ from fastapi.responses import HTMLResponse
4
+ from fastapi.templating import Jinja2Templates
5
+ from pydantic import BaseModel
6
+ from typing import Dict
7
+ from src.tests.config import API_TITLE, API_VERSION, API_DESCRIPTION
8
+ from src.tests.inference import predictor
9
+
10
+ app = FastAPI(title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION)
11
+
12
+ templates = Jinja2Templates(directory="src/templates")
13
+
14
+
15
+ class PredictionRequest(BaseModel):
16
+ features: Dict[str, float]
17
+
18
+
19
+ class PredictionResponse(BaseModel):
20
+ prediction: str
21
+ features_used: int
22
+
23
+
24
+ class ProbabilityResponse(BaseModel):
25
+ probabilities: Dict[str, float]
26
+ features_used: int
27
+
28
+
29
+ @app.get("/", response_class=HTMLResponse)
30
+ async def home(request: Request):
31
+ feature_names = predictor.get_feature_names()
32
+ # Load top 10 features from features.json
33
+ predictor.load_model() # Ensure features are loaded
34
+ top_10_features = predictor.features['top_10_features']
35
+ return templates.TemplateResponse(
36
+ "index.html",
37
+ {"request": request, "features": {"all_features": feature_names, "top_10_features": top_10_features}}
38
+ )
39
+
40
+
41
+ @app.get("/health")
42
+ async def health():
43
+ return {
44
+ "status": "healthy",
45
+ "model_loaded": predictor._model_loaded,
46
+ "model_ready": predictor.model is not None
47
+ }
48
+
49
+
50
+ @app.get("/features")
51
+ async def get_features():
52
+ return {"features": predictor.get_feature_names()}
53
+
54
+
55
+ @app.post("/predict", response_model=PredictionResponse)
56
+ async def predict(request: PredictionRequest):
57
+
58
+ # Validate missing features
59
+ expected = set(predictor.get_feature_names())
60
+ incoming = set(request.features.keys())
61
+
62
+ missing = expected - incoming
63
+ if missing:
64
+ raise HTTPException(400, f"Missing features: {missing}")
65
+ prediction = predictor.predict(request.features)
66
+
67
+ return PredictionResponse(
68
+ prediction=prediction,
69
+ features_used=len(request.features)
70
+ )
71
+
72
+
73
+ @app.post("/predict_proba", response_model=ProbabilityResponse)
74
+ async def predict_proba(request: PredictionRequest):
75
+
76
+ # Validate missing features
77
+ expected = set(predictor.get_feature_names())
78
+ incoming = set(request.features.keys())
79
+
80
+ missing = expected - incoming
81
+ if missing:
82
+ raise HTTPException(400, f"Missing features: {missing}")
83
+ probabilities = predictor.predict_proba(request.features)
84
+
85
+ return ProbabilityResponse(
86
+ probabilities=probabilities,
87
+ features_used=len(request.features)
88
+ )
89
+
90
+
91
+ if __name__ == "__main__":
92
+ import uvicorn
93
+ port = int(os.environ.get("PORT", 8000))
94
+ uvicorn.run(app, host="0.0.0.0", port=port)