| |
import logging
from typing import List, Tuple, Dict

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import helper
import dialogue
import models.linear as linear_model
import models.logistic as logistic_model
import judge
|
|
# The single ASGI application object; every route in this module is
# registered on it via the @app.post decorators below.
app = FastAPI(
    description="Chat + prediction + judge endpoints",
    title="KindCare Chat API with ML",
)
|
|
# Request body accepted by POST /chat.
class ChatRequest(BaseModel):
    # The new message to be handled by dialogue.handle_message.
    message: str
    # Earlier conversation turns as pairs of strings; defaults to an empty
    # list when the client omits it.
    history: List[Tuple[str, str]] = []
|
|
# Response body returned by POST /chat.
class ChatResponse(BaseModel):
    # Reply text produced by dialogue.handle_message.
    reply: str
    # Conversation history as returned by dialogue.handle_message (pairs of
    # strings), so the client can send it back on the next request.
    history: List[Tuple[str, str]]
|
|
# Request body for POST /predict/linear; also the base of the logistic request.
class PredictRequest(BaseModel):
    # Named numeric model inputs, forwarded unchanged to the model's predict().
    features: Dict[str, float]
|
|
# Response body for POST /predict/linear.
class PredictLinearResponse(BaseModel):
    # Numeric output of linear_model.predict.
    prediction: float
    # Named float metrics returned alongside the prediction.
    metrics: Dict[str, float]
|
|
# Request body for POST /predict/logistic: the inherited `features` plus bins.
class PredictLogisticRequest(PredictRequest):
    # String labels passed to logistic_model.predict as its second argument;
    # presumably the candidate output classes — confirm in models.logistic.
    bins: List[str]
|
|
# Response body for POST /predict/logistic.
class PredictLogisticResponse(BaseModel):
    # String label returned by logistic_model.predict (presumably one of the
    # request's bins — TODO confirm).
    prediction: str
    # Named float metrics returned alongside the prediction.
    metrics: Dict[str, float]
|
|
# Request body for POST /judge/model.
class JudgeRequest(BaseModel):
    # Identifier of the model being judged; the accepted values are defined
    # by judge.judge_model, not here.
    model_type: str
    # Named float metrics for judge.judge_model to evaluate.
    metrics: Dict[str, float]
|
|
# Response body for POST /judge/model, built by unpacking the mapping that
# judge.judge_model returns (so that mapping must supply these two keys).
class JudgeResponse(BaseModel):
    # Overall verdict text.
    verdict: str
    # Individual remarks supporting the verdict.
    comments: List[str]
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    # Handle one chat turn: pass the message and prior history to
    # dialogue.handle_message, which returns (reply, updated_history).
    #
    # Raises HTTPException:
    #   - unchanged, if the dialogue layer raised one itself;
    #   - with status 500 and detail=str(error) for any other failure.
    try:
        reply, hist = dialogue.handle_message(req.message, req.history)
        return ChatResponse(reply=reply, history=hist)
    except HTTPException:
        # Deliberate HTTP errors from the dialogue layer keep their own status
        # code instead of being flattened into a generic 500.
        raise
    except Exception as e:
        # Record the full traceback server-side; the client only sees detail.
        logging.getLogger(__name__).exception("chat: handling message failed")
        # NOTE(review): str(e) is sent to the client verbatim and may expose
        # internal details; consider a generic message instead.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.post("/predict/linear", response_model=PredictLinearResponse)
async def predict_linear(req: PredictRequest):
    # Run the linear model on the request features. linear_model.predict must
    # return exactly (prediction, metrics); any exception propagates to FastAPI.
    value, scores = linear_model.predict(req.features)
    response = PredictLinearResponse(prediction=value, metrics=scores)
    return response
|
|
@app.post("/predict/logistic", response_model=PredictLogisticResponse)
async def predict_logistic(req: PredictLogisticRequest):
    # Run the logistic model on the request features and bins.
    # logistic_model.predict must return exactly (prediction, metrics); any
    # exception propagates to FastAPI.
    label, scores = logistic_model.predict(req.features, req.bins)
    response = PredictLogisticResponse(prediction=label, metrics=scores)
    return response
|
|
@app.post("/judge/model", response_model=JudgeResponse)
async def judge_endpoint(req: JudgeRequest):
    # Ask judge.judge_model to assess the supplied metrics. Its result is
    # unpacked as keyword arguments, so it must be a mapping with exactly the
    # JudgeResponse fields (verdict, comments).
    assessment = judge.judge_model(req.model_type, req.metrics)
    response = JudgeResponse(**assessment)
    return response
|
|
if __name__ == "__main__":
    # Development entry point: serve this module's app with uvicorn.
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Pass the app object rather than the "app:app" import string. An import
    # string only works when this file is named app.py, and it makes uvicorn
    # import the module a second time. An import string is only needed for
    # reload=True or workers>1, neither of which is used here.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|