| import asyncio
|
| import json
|
| import logging
|
| from typing import AsyncGenerator, List
|
|
|
| from fastapi import FastAPI, HTTPException
|
| from fastapi.responses import StreamingResponse
|
| from pydantic import BaseModel, Field
|
|
|
| from agent import agent_router
|
| from model import get_model_manager
|
|
|
|
|
# Configure the root logger once at import time; per-module logger follows
# the stdlib `getLogger(__name__)` convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
# Single ASGI application instance; all routes below register against it.
app = FastAPI(
    title="General AI Assistant Backend",
    description="Production-ready FastAPI backend with tools, memory, and CPU-friendly LLM inference.",
    version="1.0.0",
)
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Caller-supplied identity; presumably keys per-user memory in the agent
    # layer — confirm against agent_router.
    user_id: str = Field(..., min_length=1, max_length=128)
    # The user's message. Length is bounded here; whitespace-only input is
    # rejected later in the endpoint after stripping.
    message: str = Field(..., min_length=1, max_length=4000)
    # When True, /chat returns a text/event-stream (SSE) response instead of
    # a single JSON body.
    stream: bool = Field(default=False)
|
|
|
|
|
class ChatResponse(BaseModel):
    """JSON body returned by ``POST /chat`` in non-streaming mode."""

    # Final assistant text (never empty; endpoint 500s on empty model output).
    response: str
    # Which routing path produced the answer (defaults to "llm" upstream).
    route_used: str
    # Names of tools the agent invoked, if any.
    tools_used: List[str]
    # Always False here; streaming replies are delivered as SSE frames instead.
    stream_enabled: bool
|
|
|
|
|
| def _next_event_or_none(iterator):
|
| try:
|
| return next(iterator)
|
| except StopIteration:
|
| return None
|
|
|
|
|
async def _sse_stream_from_agent(user_id: str, message: str) -> AsyncGenerator[str, None]:
    """Bridge the agent's blocking event iterator into an SSE text stream.

    Each event (assumed to be a mapping — confirm against
    ``agent_router.stream_respond``) is augmented with ``stream_enabled: True``
    and emitted as one ``data: ...`` Server-Sent-Events frame.
    """
    events = agent_router.stream_respond(user_id, message)

    # Pull each event on a worker thread so the blocking iterator never
    # stalls the event loop; None marks exhaustion.
    while (event := await asyncio.to_thread(_next_event_or_none, events)) is not None:
        frame = dict(event, stream_enabled=True)
        yield f"data: {json.dumps(frame, ensure_ascii=True)}\n\n"
        # Explicitly yield control so other tasks can run between frames.
        await asyncio.sleep(0)
|
|
|
|
|
# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of a
# lifespan handler; kept as-is since switching requires changing the FastAPI()
# construction too.
@app.on_event("startup")
async def startup_event() -> None:
    """Warm the language model at process start so the first request is fast."""
    logger.info("Loading language model...")
    manager = get_model_manager()
    # load() is blocking (CPU/disk-bound); run it off the event loop.
    await asyncio.to_thread(manager.load)
    logger.info("Model loaded.")
|
|
|
|
|
| @app.get("/health")
|
| async def health() -> dict:
|
| return {"status": "ok"}
|
|
|
|
|
| @app.post("/chat")
|
| async def chat(payload: ChatRequest):
|
| try:
|
| message = payload.message.strip()
|
| if not message:
|
| raise HTTPException(status_code=400, detail="Message cannot be empty.")
|
|
|
| if payload.stream:
|
| return StreamingResponse(
|
| _sse_stream_from_agent(payload.user_id, message),
|
| media_type="text/event-stream",
|
| headers={
|
| "Cache-Control": "no-cache",
|
| "Connection": "keep-alive",
|
| },
|
| )
|
|
|
| response = await asyncio.to_thread(
|
| agent_router.respond,
|
| payload.user_id,
|
| message,
|
| )
|
|
|
| if isinstance(response, dict):
|
| text = str(response.get("response", "")).strip()
|
| route_used = str(response.get("route_used", "llm"))
|
| tools_used = [str(t) for t in response.get("tools_used", [])]
|
| else:
|
| text = str(response).strip()
|
| route_used = "llm"
|
| tools_used = []
|
|
|
| if not text:
|
| raise HTTPException(status_code=500, detail="Model returned an empty response.")
|
|
|
| return ChatResponse(
|
| response=text,
|
| route_used=route_used,
|
| tools_used=tools_used,
|
| stream_enabled=False,
|
| )
|
| except HTTPException:
|
| raise
|
| except Exception as exc:
|
| logger.exception("Chat endpoint failed")
|
| raise HTTPException(status_code=500, detail=f"Internal server error: {exc}") from exc
|
|
|
|
|
|
|
|
|
|
|