ACA050 commited on
Commit
d02e833
·
verified ·
1 Parent(s): d9aad77

Update backend/api/main.py

Browse files
Files changed (1) hide show
  1. backend/api/main.py +53 -4
backend/api/main.py CHANGED
@@ -5,12 +5,20 @@ import os
5
  from backend.core.orchestrator import Orchestrator
6
  from backend.api.schemas import PredictRequest
7
 
8
- app = FastAPI()
 
 
 
 
 
9
  orchestrator = Orchestrator()
10
 
11
  MODEL_PATH = "exports/models/trained_model.pkl"
12
 
13
 
 
 
 
14
  @app.post("/analyze")
15
  async def analyze_dataset(
16
  file: UploadFile = File(...),
@@ -18,6 +26,7 @@ async def analyze_dataset(
18
  ):
19
  try:
20
  df = pd.read_csv(file.file)
 
21
  result = orchestrator.run(df, target_column)
22
 
23
  dataset_info = result.get("dataset_info", {})
@@ -36,6 +45,9 @@ async def analyze_dataset(
36
  raise HTTPException(status_code=400, detail=str(e))
37
 
38
 
 
 
 
39
  @app.post("/train")
40
  async def train_model(
41
  file: UploadFile = File(...),
@@ -43,19 +55,23 @@ async def train_model(
43
  ):
44
  try:
45
  df = pd.read_csv(file.file)
 
46
  result = orchestrator.run(df, target_column, train=True)
47
 
48
  return {
49
  "strategy": result.get("strategy"),
50
  "metrics": result.get("metrics"),
51
  "model_path": MODEL_PATH,
52
- "model_id": "trained_model",
53
  }
54
 
55
  except Exception as e:
56
  raise HTTPException(status_code=400, detail=str(e))
57
 
58
 
 
 
 
59
  @app.post("/explain")
60
  async def explain_model(
61
  file: UploadFile = File(...),
@@ -63,29 +79,36 @@ async def explain_model(
63
  ):
64
  try:
65
  df = pd.read_csv(file.file)
 
66
  result = orchestrator.run(df, target_column, train=True)
67
 
68
  return {
69
  "strategy_explanation": result.get("strategy_explanation"),
70
  "metrics": result.get("metrics"),
71
- "feature_importance": result.get("feature_importance"),
72
  }
73
 
74
  except Exception as e:
75
  raise HTTPException(status_code=400, detail=str(e))
76
 
77
 
 
 
 
78
  @app.post("/predict")
79
  async def predict(request: PredictRequest):
80
  try:
 
81
  if not os.path.exists(MODEL_PATH):
82
  raise HTTPException(
83
  status_code=400,
84
  detail="No trained model found. Train a model first."
85
  )
86
 
 
87
  model = orchestrator.model_io.load(MODEL_PATH)
88
 
 
89
  df = pd.DataFrame(request.instances)
90
 
91
  if df.empty:
@@ -94,19 +117,45 @@ async def predict(request: PredictRequest):
94
  detail="Prediction data is empty."
95
  )
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  predictions = model.predict(df)
98
 
99
  return {
100
  "predictions": predictions.tolist()
101
  }
102
 
 
 
103
  except Exception as e:
104
  raise HTTPException(status_code=400, detail=str(e))
105
 
106
 
 
 
 
107
  @app.get("/")
108
  def root():
109
- return {"message": "ModelSmith AI API", "status": "running"}
 
 
 
110
 
111
 
112
  @app.get("/health")
 
5
  from backend.core.orchestrator import Orchestrator
6
  from backend.api.schemas import PredictRequest
7
 
8
+ app = FastAPI(
9
+ title="ModelSmith AI",
10
+ description="Automated ML platform for tabular data",
11
+ version="1.0.0"
12
+ )
13
+
14
  orchestrator = Orchestrator()
15
 
16
  MODEL_PATH = "exports/models/trained_model.pkl"
17
 
18
 
19
+ # -----------------------------
20
+ # Analyze Dataset
21
+ # -----------------------------
22
  @app.post("/analyze")
23
  async def analyze_dataset(
24
  file: UploadFile = File(...),
 
26
  ):
27
  try:
28
  df = pd.read_csv(file.file)
29
+
30
  result = orchestrator.run(df, target_column)
31
 
32
  dataset_info = result.get("dataset_info", {})
 
45
  raise HTTPException(status_code=400, detail=str(e))
46
 
47
 
48
+ # -----------------------------
49
+ # Train Model
50
+ # -----------------------------
51
  @app.post("/train")
52
  async def train_model(
53
  file: UploadFile = File(...),
 
55
  ):
56
  try:
57
  df = pd.read_csv(file.file)
58
+
59
  result = orchestrator.run(df, target_column, train=True)
60
 
61
  return {
62
  "strategy": result.get("strategy"),
63
  "metrics": result.get("metrics"),
64
  "model_path": MODEL_PATH,
65
+ "model_id": "trained_model"
66
  }
67
 
68
  except Exception as e:
69
  raise HTTPException(status_code=400, detail=str(e))
70
 
71
 
72
+ # -----------------------------
73
+ # Explain Model
74
+ # -----------------------------
75
  @app.post("/explain")
76
  async def explain_model(
77
  file: UploadFile = File(...),
 
79
  ):
80
  try:
81
  df = pd.read_csv(file.file)
82
+
83
  result = orchestrator.run(df, target_column, train=True)
84
 
85
  return {
86
  "strategy_explanation": result.get("strategy_explanation"),
87
  "metrics": result.get("metrics"),
88
+ "feature_importance": result.get("feature_importance")
89
  }
90
 
91
  except Exception as e:
92
  raise HTTPException(status_code=400, detail=str(e))
93
 
94
 
95
+ # -----------------------------
96
+ # Predict
97
+ # -----------------------------
98
  @app.post("/predict")
99
  async def predict(request: PredictRequest):
100
  try:
101
+ # 1. Check model exists
102
  if not os.path.exists(MODEL_PATH):
103
  raise HTTPException(
104
  status_code=400,
105
  detail="No trained model found. Train a model first."
106
  )
107
 
108
+ # 2. Load model
109
  model = orchestrator.model_io.load(MODEL_PATH)
110
 
111
+ # 3. Convert input to DataFrame
112
  df = pd.DataFrame(request.instances)
113
 
114
  if df.empty:
 
117
  detail="Prediction data is empty."
118
  )
119
 
120
+ # 4. Validate feature columns
121
+ if hasattr(model, "feature_names_in_"):
122
+ expected_features = list(model.feature_names_in_)
123
+ received_features = list(df.columns)
124
+
125
+ missing = set(expected_features) - set(received_features)
126
+ extra = set(received_features) - set(expected_features)
127
+
128
+ if missing:
129
+ raise HTTPException(
130
+ status_code=400,
131
+ detail=f"Missing required features: {sorted(missing)}"
132
+ )
133
+
134
+ # Drop extra columns if any
135
+ df = df[expected_features]
136
+
137
+ # 5. Predict
138
  predictions = model.predict(df)
139
 
140
  return {
141
  "predictions": predictions.tolist()
142
  }
143
 
144
+ except HTTPException:
145
+ raise
146
  except Exception as e:
147
  raise HTTPException(status_code=400, detail=str(e))
148
 
149
 
150
+ # -----------------------------
151
+ # Root & Health
152
+ # -----------------------------
153
  @app.get("/")
154
  def root():
155
+ return {
156
+ "message": "ModelSmith AI API",
157
+ "status": "running"
158
+ }
159
 
160
 
161
  @app.get("/health")