ACA050 commited on
Commit
d9aad77
·
verified ·
1 Parent(s): c4c8f23

Update backend/api/main.py

Browse files
Files changed (1) hide show
  1. backend/api/main.py +58 -28
backend/api/main.py CHANGED
@@ -1,21 +1,29 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
  import pandas as pd
 
 
3
  from backend.core.orchestrator import Orchestrator
 
4
 
5
  app = FastAPI()
6
  orchestrator = Orchestrator()
7
 
 
 
 
8
  @app.post("/analyze")
9
- async def analyze_dataset(file: UploadFile = File(...), target_column: str = "target"):
 
 
 
10
  try:
11
  df = pd.read_csv(file.file)
12
  result = orchestrator.run(df, target_column)
13
 
14
- # Format response for frontend
15
  dataset_info = result.get("dataset_info", {})
16
  strategy = result.get("strategy", {})
17
 
18
- response = {
19
  "columns": list(df.columns),
20
  "dataTypes": dataset_info.get("data_types", {}),
21
  "risks": dataset_info.get("risks", []),
@@ -23,62 +31,84 @@ async def analyze_dataset(file: UploadFile = File(...), target_column: str = "ta
23
  "confidence": strategy.get("confidence", 0),
24
  "strategy": strategy
25
  }
26
- return response
27
  except Exception as e:
28
  raise HTTPException(status_code=400, detail=str(e))
29
 
 
30
  @app.post("/train")
31
- async def train_model(file: UploadFile = File(...), target_column: str = "target"):
 
 
 
32
  try:
33
  df = pd.read_csv(file.file)
34
  result = orchestrator.run(df, target_column, train=True)
35
 
36
- # Ensure strategy is included in the response
37
- strategy = result.get("strategy", {})
38
- response = {
39
- "strategy": strategy,
40
- "metrics": result.get("metrics", {}),
41
- "model_path": result.get("model_path", "/path/to/model.pkl"),
42
- "training_time": result.get("training_time", 0),
43
- "model_id": result.get("model_id", "trained_model_123")
44
  }
45
- return response
46
  except Exception as e:
47
  raise HTTPException(status_code=400, detail=str(e))
48
 
 
49
  @app.post("/explain")
50
- async def explain_model(file: UploadFile = File(...), target_column: str = "target"):
 
 
 
51
  try:
52
  df = pd.read_csv(file.file)
53
  result = orchestrator.run(df, target_column, train=True)
 
54
  return {
55
  "strategy_explanation": result.get("strategy_explanation"),
56
- "metrics": result.get("metrics", {}),
57
- "feature_importance": result.get("feature_importance", [])
58
  }
 
59
  except Exception as e:
60
  raise HTTPException(status_code=400, detail=str(e))
61
 
 
62
  @app.post("/predict")
63
- async def predict(data: dict):
64
  try:
65
- # Load the trained model
66
- model = orchestrator.model_io.load("exports/models/trained_model.pkl")
67
- # Prepare data for prediction
68
- df = pd.DataFrame([data])
69
- preds = model.predict(df)
70
- return {"prediction": preds.tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  except Exception as e:
72
  raise HTTPException(status_code=400, detail=str(e))
73
 
 
74
  @app.get("/")
75
  def root():
76
  return {"message": "ModelSmith AI API", "status": "running"}
77
 
 
78
  @app.get("/health")
79
  def health():
80
  return {"status": "ok"}
81
-
82
-
83
-
84
-
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
  import pandas as pd
3
+ import os
4
+
5
  from backend.core.orchestrator import Orchestrator
6
+ from backend.api.schemas import PredictRequest
7
 
8
  app = FastAPI()
9
  orchestrator = Orchestrator()
10
 
11
+ MODEL_PATH = "exports/models/trained_model.pkl"
12
+
13
+
14
  @app.post("/analyze")
15
+ async def analyze_dataset(
16
+ file: UploadFile = File(...),
17
+ target_column: str = "target"
18
+ ):
19
  try:
20
  df = pd.read_csv(file.file)
21
  result = orchestrator.run(df, target_column)
22
 
 
23
  dataset_info = result.get("dataset_info", {})
24
  strategy = result.get("strategy", {})
25
 
26
+ return {
27
  "columns": list(df.columns),
28
  "dataTypes": dataset_info.get("data_types", {}),
29
  "risks": dataset_info.get("risks", []),
 
31
  "confidence": strategy.get("confidence", 0),
32
  "strategy": strategy
33
  }
34
+
35
  except Exception as e:
36
  raise HTTPException(status_code=400, detail=str(e))
37
 
38
+
39
  @app.post("/train")
40
+ async def train_model(
41
+ file: UploadFile = File(...),
42
+ target_column: str = "target"
43
+ ):
44
  try:
45
  df = pd.read_csv(file.file)
46
  result = orchestrator.run(df, target_column, train=True)
47
 
48
+ return {
49
+ "strategy": result.get("strategy"),
50
+ "metrics": result.get("metrics"),
51
+ "model_path": MODEL_PATH,
52
+ "model_id": "trained_model",
 
 
 
53
  }
54
+
55
  except Exception as e:
56
  raise HTTPException(status_code=400, detail=str(e))
57
 
58
+
59
  @app.post("/explain")
60
+ async def explain_model(
61
+ file: UploadFile = File(...),
62
+ target_column: str = "target"
63
+ ):
64
  try:
65
  df = pd.read_csv(file.file)
66
  result = orchestrator.run(df, target_column, train=True)
67
+
68
  return {
69
  "strategy_explanation": result.get("strategy_explanation"),
70
+ "metrics": result.get("metrics"),
71
+ "feature_importance": result.get("feature_importance"),
72
  }
73
+
74
  except Exception as e:
75
  raise HTTPException(status_code=400, detail=str(e))
76
 
77
+
78
  @app.post("/predict")
79
+ async def predict(request: PredictRequest):
80
  try:
81
+ if not os.path.exists(MODEL_PATH):
82
+ raise HTTPException(
83
+ status_code=400,
84
+ detail="No trained model found. Train a model first."
85
+ )
86
+
87
+ model = orchestrator.model_io.load(MODEL_PATH)
88
+
89
+ df = pd.DataFrame(request.instances)
90
+
91
+ if df.empty:
92
+ raise HTTPException(
93
+ status_code=400,
94
+ detail="Prediction data is empty."
95
+ )
96
+
97
+ predictions = model.predict(df)
98
+
99
+ return {
100
+ "predictions": predictions.tolist()
101
+ }
102
+
103
  except Exception as e:
104
  raise HTTPException(status_code=400, detail=str(e))
105
 
106
+
107
  @app.get("/")
108
  def root():
109
  return {"message": "ModelSmith AI API", "status": "running"}
110
 
111
+
112
  @app.get("/health")
113
  def health():
114
  return {"status": "ok"}