iremrit commited on
Commit
24409e1
·
verified ·
1 Parent(s): e48262c

Delete src/tests/app.py

Browse files
Files changed (1) hide show
  1. src/tests/app.py +0 -94
src/tests/app.py DELETED
@@ -1,94 +0,0 @@
1
- import os
2
- from fastapi import FastAPI, HTTPException, Request
3
- from fastapi.responses import HTMLResponse
4
- from fastapi.templating import Jinja2Templates
5
- from pydantic import BaseModel
6
- from typing import Dict
7
- from config import API_TITLE, API_VERSION, API_DESCRIPTION
8
- from inference import predictor
9
-
10
- app = FastAPI(title=API_TITLE, version=API_VERSION, description=API_DESCRIPTION)
11
-
12
- templates = Jinja2Templates(directory="src/templates")
13
-
14
-
15
- class PredictionRequest(BaseModel):
16
- features: Dict[str, float]
17
-
18
-
19
- class PredictionResponse(BaseModel):
20
- prediction: str
21
- features_used: int
22
-
23
-
24
- class ProbabilityResponse(BaseModel):
25
- probabilities: Dict[str, float]
26
- features_used: int
27
-
28
-
29
- @app.get("/", response_class=HTMLResponse)
30
- async def home(request: Request):
31
- feature_names = predictor.get_feature_names()
32
- # Load top 10 features from features.json
33
- predictor.load_model() # Ensure features are loaded
34
- top_10_features = predictor.features['top_10_features']
35
- return templates.TemplateResponse(
36
- "index.html",
37
- {"request": request, "features": {"all_features": feature_names, "top_10_features": top_10_features}}
38
- )
39
-
40
-
41
- @app.get("/health")
42
- async def health():
43
- return {
44
- "status": "healthy",
45
- "model_loaded": predictor._model_loaded,
46
- "model_ready": predictor.model is not None
47
- }
48
-
49
-
50
- @app.get("/features")
51
- async def get_features():
52
- return {"features": predictor.get_feature_names()}
53
-
54
-
55
- @app.post("/predict", response_model=PredictionResponse)
56
- async def predict(request: PredictionRequest):
57
-
58
- # Validate missing features
59
- expected = set(predictor.get_feature_names())
60
- incoming = set(request.features.keys())
61
-
62
- missing = expected - incoming
63
- if missing:
64
- raise HTTPException(400, f"Missing features: {missing}")
65
- prediction = predictor.predict(request.features)
66
-
67
- return PredictionResponse(
68
- prediction=prediction,
69
- features_used=len(request.features)
70
- )
71
-
72
-
73
- @app.post("/predict_proba", response_model=ProbabilityResponse)
74
- async def predict_proba(request: PredictionRequest):
75
-
76
- # Validate missing features
77
- expected = set(predictor.get_feature_names())
78
- incoming = set(request.features.keys())
79
-
80
- missing = expected - incoming
81
- if missing:
82
- raise HTTPException(400, f"Missing features: {missing}")
83
- probabilities = predictor.predict_proba(request.features)
84
-
85
- return ProbabilityResponse(
86
- probabilities=probabilities,
87
- features_used=len(request.features)
88
- )
89
-
90
-
91
- if __name__ == "__main__":
92
- import uvicorn
93
- port = int(os.environ.get("PORT", 8000))
94
- uvicorn.run(app, host="0.0.0.0", port=port)