ACA050 commited on
Commit
ac573e4
·
verified ·
1 Parent(s): 521bf0e

Update backend/api/main.py

Browse files
Files changed (1) hide show
  1. backend/api/main.py +84 -80
backend/api/main.py CHANGED
@@ -1,80 +1,84 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- import pandas as pd
3
- from backend.core.orchestrator import Orchestrator
4
-
5
- app = FastAPI()
6
- orchestrator = Orchestrator()
7
-
8
- @app.post("/analyze")
9
- async def analyze_dataset(file: UploadFile = File(...), target_column: str = "target"):
10
- try:
11
- df = pd.read_csv(file.file)
12
- result = orchestrator.run(df, target_column)
13
-
14
- # Format response for frontend
15
- dataset_info = result.get("dataset_info", {})
16
- strategy = result.get("strategy", {})
17
-
18
- response = {
19
- "columns": list(df.columns),
20
- "dataTypes": dataset_info.get("data_types", {}),
21
- "risks": dataset_info.get("risks", []),
22
- "problemType": result.get("problem_type"),
23
- "confidence": strategy.get("confidence", 0),
24
- "strategy": strategy
25
- }
26
- return response
27
- except Exception as e:
28
- raise HTTPException(status_code=400, detail=str(e))
29
-
30
- @app.post("/train")
31
- async def train_model(file: UploadFile = File(...), target_column: str = "target"):
32
- try:
33
- df = pd.read_csv(file.file)
34
- result = orchestrator.run(df, target_column, train=True)
35
-
36
- # Ensure strategy is included in the response
37
- strategy = result.get("strategy", {})
38
- response = {
39
- "strategy": strategy,
40
- "metrics": result.get("metrics", {}),
41
- "model_path": result.get("model_path", "/path/to/model.pkl"),
42
- "training_time": result.get("training_time", 0),
43
- "model_id": result.get("model_id", "trained_model_123")
44
- }
45
- return response
46
- except Exception as e:
47
- raise HTTPException(status_code=400, detail=str(e))
48
-
49
- @app.post("/explain")
50
- async def explain_model(file: UploadFile = File(...), target_column: str = "target"):
51
- try:
52
- df = pd.read_csv(file.file)
53
- result = orchestrator.run(df, target_column, train=True)
54
- return {
55
- "strategy_explanation": result.get("strategy_explanation"),
56
- "metrics": result.get("metrics", {}),
57
- "feature_importance": result.get("feature_importance", [])
58
- }
59
- except Exception as e:
60
- raise HTTPException(status_code=400, detail=str(e))
61
-
62
- @app.post("/predict")
63
- async def predict(data: dict):
64
- try:
65
- # Load the trained model
66
- model = orchestrator.model_io.load("exports/models/trained_model.pkl")
67
- # Prepare data for prediction
68
- df = pd.DataFrame([data])
69
- preds = model.predict(df)
70
- return {"prediction": preds.tolist()}
71
- except Exception as e:
72
- raise HTTPException(status_code=400, detail=str(e))
73
-
74
- @app.get("/health")
75
- def health():
76
- return {"status": "ok"}
77
-
78
-
79
-
80
-
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ import pandas as pd
3
+ from backend.core.orchestrator import Orchestrator
4
+
5
+ app = FastAPI()
6
+ orchestrator = Orchestrator()
7
+
8
+ @app.post("/analyze")
9
+ async def analyze_dataset(file: UploadFile = File(...), target_column: str = "target"):
10
+ try:
11
+ df = pd.read_csv(file.file)
12
+ result = orchestrator.run(df, target_column)
13
+
14
+ # Format response for frontend
15
+ dataset_info = result.get("dataset_info", {})
16
+ strategy = result.get("strategy", {})
17
+
18
+ response = {
19
+ "columns": list(df.columns),
20
+ "dataTypes": dataset_info.get("data_types", {}),
21
+ "risks": dataset_info.get("risks", []),
22
+ "problemType": result.get("problem_type"),
23
+ "confidence": strategy.get("confidence", 0),
24
+ "strategy": strategy
25
+ }
26
+ return response
27
+ except Exception as e:
28
+ raise HTTPException(status_code=400, detail=str(e))
29
+
30
+ @app.post("/train")
31
+ async def train_model(file: UploadFile = File(...), target_column: str = "target"):
32
+ try:
33
+ df = pd.read_csv(file.file)
34
+ result = orchestrator.run(df, target_column, train=True)
35
+
36
+ # Ensure strategy is included in the response
37
+ strategy = result.get("strategy", {})
38
+ response = {
39
+ "strategy": strategy,
40
+ "metrics": result.get("metrics", {}),
41
+ "model_path": result.get("model_path", "/path/to/model.pkl"),
42
+ "training_time": result.get("training_time", 0),
43
+ "model_id": result.get("model_id", "trained_model_123")
44
+ }
45
+ return response
46
+ except Exception as e:
47
+ raise HTTPException(status_code=400, detail=str(e))
48
+
49
+ @app.post("/explain")
50
+ async def explain_model(file: UploadFile = File(...), target_column: str = "target"):
51
+ try:
52
+ df = pd.read_csv(file.file)
53
+ result = orchestrator.run(df, target_column, train=True)
54
+ return {
55
+ "strategy_explanation": result.get("strategy_explanation"),
56
+ "metrics": result.get("metrics", {}),
57
+ "feature_importance": result.get("feature_importance", [])
58
+ }
59
+ except Exception as e:
60
+ raise HTTPException(status_code=400, detail=str(e))
61
+
62
+ @app.post("/predict")
63
+ async def predict(data: dict):
64
+ try:
65
+ # Load the trained model
66
+ model = orchestrator.model_io.load("exports/models/trained_model.pkl")
67
+ # Prepare data for prediction
68
+ df = pd.DataFrame([data])
69
+ preds = model.predict(df)
70
+ return {"prediction": preds.tolist()}
71
+ except Exception as e:
72
+ raise HTTPException(status_code=400, detail=str(e))
73
+
74
+ @app.get("/")
75
+ def root():
76
+ return {"message": "ModelSmith AI API", "status": "running"}
77
+
78
+ @app.get("/health")
79
+ def health():
80
+ return {"status": "ok"}
81
+
82
+
83
+
84
+