# ACA050's picture
# Update backend/api/main.py
# d02e833 verified
import logging
import os

import pandas as pd
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool

from backend.core.orchestrator import Orchestrator
from backend.api.schemas import PredictRequest
# Model artefact checked and loaded by /predict and echoed back by /train.
# This is a relative path, so it resolves against the process working directory.
MODEL_PATH = "exports/models/trained_model.pkl"

# One shared orchestrator instance, used by the analyze/train/explain/predict
# endpoints below.
orchestrator = Orchestrator()

# Metadata shown in the generated OpenAPI schema and docs UI.
_APP_METADATA = {
    "title": "ModelSmith AI",
    "description": "Automated ML platform for tabular data",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# -----------------------------
# Analyze Dataset
# -----------------------------
@app.post("/analyze")
async def analyze_dataset(
    file: UploadFile = File(...),
    target_column: str = "target",
):
    """Analyse an uploaded CSV and return the orchestrator's dataset profile.

    Args:
        file: Uploaded CSV file, parsed with ``pandas.read_csv``.
        target_column: Name of the column the orchestrator should treat as the
            target (defaults to ``"target"``).

    Returns:
        A dict with:
            ``columns``: the CSV header order.
            ``dataTypes`` and ``risks``: taken from the result's
                ``dataset_info``.
            ``problemType``: the inferred problem type.
            ``confidence``: taken from the strategy, 0 if absent.
            ``strategy``: the full strategy.

    Raises:
        HTTPException: 400 carrying the underlying error message if parsing,
            analysis or response assembly fails.
    """
    try:
        # CSV parsing and the orchestrator are synchronous and can be slow.
        # Calling them directly inside an ``async def`` endpoint would block
        # the event loop (and every other request, /health included), so they
        # run in the threadpool instead.
        df = await run_in_threadpool(pd.read_csv, file.file)
        result = await run_in_threadpool(orchestrator.run, df, target_column)

        dataset_info = result.get("dataset_info", {})
        strategy = result.get("strategy", {})
        return {
            "columns": list(df.columns),
            "dataTypes": dataset_info.get("data_types", {}),
            "risks": dataset_info.get("risks", []),
            "problemType": result.get("problem_type"),
            "confidence": strategy.get("confidence", 0),
            "strategy": strategy,
        }
    except Exception as e:
        # Preserve the existing 400 contract. Also record the traceback
        # server-side and chain the cause, so genuine server bugs are not
        # silently reported as client errors.
        logging.getLogger(__name__).exception("Dataset analysis failed")
        raise HTTPException(status_code=400, detail=str(e)) from e
# -----------------------------
# Train Model
# -----------------------------
@app.post("/train")
async def train_model(
    file: UploadFile = File(...),
    target_column: str = "target",
):
    """Train a model on an uploaded CSV via the orchestrator.

    Args:
        file: Uploaded CSV file, parsed with ``pandas.read_csv``.
        target_column: Name of the column to train against (defaults to
            ``"target"``).

    Returns:
        A dict with:
            ``strategy`` and ``metrics``: taken from the orchestrator result.
            ``model_path``: the fixed ``MODEL_PATH``.
            ``model_id``: the constant id ``"trained_model"``.

    Raises:
        HTTPException: 400 carrying the underlying error message if parsing or
            training fails.
    """
    try:
        # Training is CPU-bound and synchronous. Run it (and the CSV parse)
        # in the threadpool so the event loop keeps serving other requests.
        df = await run_in_threadpool(pd.read_csv, file.file)
        result = await run_in_threadpool(
            orchestrator.run, df, target_column, train=True
        )

        # NOTE(review): model_path/model_id are constants, not taken from
        # ``result``. This presumes the orchestrator saves to MODEL_PATH, which
        # should be confirmed in backend.core.orchestrator.
        return {
            "strategy": result.get("strategy"),
            "metrics": result.get("metrics"),
            "model_path": MODEL_PATH,
            "model_id": "trained_model",
        }
    except Exception as e:
        # Keep the 400 contract, but log the traceback and chain the cause.
        logging.getLogger(__name__).exception("Model training failed")
        raise HTTPException(status_code=400, detail=str(e)) from e
# -----------------------------
# Explain Model
# -----------------------------
@app.post("/explain")
async def explain_model(
    file: UploadFile = File(...),
    target_column: str = "target",
):
    """Return the orchestrator's strategy explanation, metrics and feature importance.

    Args:
        file: Uploaded CSV file, parsed with ``pandas.read_csv``.
        target_column: Name of the target column (defaults to ``"target"``).

    Returns:
        A dict with ``strategy_explanation``, ``metrics`` and
        ``feature_importance`` taken from the orchestrator result (``None``
        for any key the result lacks).

    Raises:
        HTTPException: 400 carrying the underlying error message if parsing or
            the orchestrator run fails.
    """
    try:
        # Blocking parse + training run are moved off the event loop.
        df = await run_in_threadpool(pd.read_csv, file.file)
        # NOTE(review): this runs the orchestrator with train=True, i.e. it
        # retrains on every /explain call and may overwrite the model that
        # /predict loads. Confirm that this is intended.
        result = await run_in_threadpool(
            orchestrator.run, df, target_column, train=True
        )

        return {
            "strategy_explanation": result.get("strategy_explanation"),
            "metrics": result.get("metrics"),
            "feature_importance": result.get("feature_importance"),
        }
    except Exception as e:
        # Keep the 400 contract, but log the traceback and chain the cause.
        logging.getLogger(__name__).exception("Model explanation failed")
        raise HTTPException(status_code=400, detail=str(e)) from e
# -----------------------------
# Predict
# -----------------------------
def _predict_sync(instances):
    """Load the persisted model and score *instances*.

    This is blocking work (disk I/O and inference), so the ``/predict``
    endpoint runs it in the threadpool.

    Returns:
        The predictions as a plain list.

    Raises:
        HTTPException: 400 in any of these cases:
            - no model file exists at ``MODEL_PATH``;
            - *instances* produces an empty DataFrame;
            - required feature columns are missing.
    """
    # 1. A model must have been trained (and saved) first.
    if not os.path.exists(MODEL_PATH):
        raise HTTPException(
            status_code=400,
            detail="No trained model found. Train a model first.",
        )

    # 2. Build the input frame and reject empty payloads before paying for a
    #    model load from disk.
    df = pd.DataFrame(instances)
    if df.empty:
        raise HTTPException(
            status_code=400,
            detail="Prediction data is empty.",
        )

    # 3. Load the model.
    # NOTE(review): if model_io.load unpickles, MODEL_PATH must only ever point
    # at files this service wrote itself. Confirm in the orchestrator.
    model = orchestrator.model_io.load(MODEL_PATH)

    # 4. When the model records the feature names it was fitted on, require
    #    all of them. Then select them in training order, which also drops any
    #    extra columns.
    fitted_features = getattr(model, "feature_names_in_", None)
    if fitted_features is not None:
        expected_features = list(fitted_features)
        missing = set(expected_features) - set(df.columns)
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required features: {sorted(missing)}",
            )
        df = df[expected_features]

    # 5. Predict. Array-like results are converted to JSON-friendly Python
    #    values; a plain sequence is passed through as a list.
    predictions = model.predict(df)
    if hasattr(predictions, "tolist"):
        return predictions.tolist()
    return list(predictions)


@app.post("/predict")
async def predict(request: PredictRequest):
    """Score ``request.instances`` with the most recently saved model.

    Returns:
        ``{"predictions": [...]}``.

    Raises:
        HTTPException: 400 for a missing model, empty input or missing feature
            columns (raised by the helper), and 400 carrying the underlying
            message for any other failure.
    """
    try:
        # Disk load + inference are blocking, so they run off the event loop.
        predictions = await run_in_threadpool(_predict_sync, request.instances)
    except HTTPException:
        # Deliberate client-facing errors from the helper pass through as-is.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"predictions": predictions}
# -----------------------------
# Root & Health
# -----------------------------
@app.get("/")
def root():
    """Identify the service and report that it is running."""
    banner = dict(message="ModelSmith AI API", status="running")
    return banner
@app.get("/health")
def health():
    """Liveness probe; always answers with status "ok"."""
    state = "ok"
    return dict(status=state)