Spaces:
Sleeping
Sleeping
File size: 3,163 Bytes
fbdfc24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import logging
import uuid
from datetime import datetime
from typing import Dict, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from core.chat_manager import LegalChatManager
# Pydantic models for API
class ChatRequest(BaseModel):
    # Request body accepted by POST /chat.
    # (Comments are used instead of a docstring because pydantic copies a
    # model docstring into the generated JSON/OpenAPI schema.)
    query: str  # the user's message, forwarded to chat_manager.chat
    # Conversation to continue. When missing or empty, the /chat handler
    # generates a new session id.
    session_id: Optional[str] = None
    # Optional extra mapping forwarded as-is to chat_manager.chat; its expected
    # keys are defined by the chat manager (not visible here) — TODO confirm.
    context: Optional[Dict] = None
class ChatResponse(BaseModel):
    # Response body returned by POST /chat.
    response: str  # the value returned by chat_manager.chat
    session_id: str  # the caller's session id, or the newly generated one
    session_stats: Dict  # result of chat_manager.get_session_stats(session_id)
    # NOTE(review): the /chat handler always sets this to None; failures are
    # reported as HTTP 500 errors instead of through this field.
    error: Optional[str] = None
class HealthResponse(BaseModel):
    # Response body returned by GET /health.
    status: str  # the /health handler always sets "healthy"
    stats: Dict  # result of chat_manager.get_global_stats()
    # datetime.now().isoformat(): naive local server time, no UTC offset.
    timestamp: str
class LegalRAGAPI:
    """FastAPI application that exposes a ``LegalChatManager`` over HTTP.

    Routes registered on ``self.app``:
        GET  /                               liveness message
        POST /chat                           run one chat turn (ChatRequest -> ChatResponse)
        GET  /health                         global stats plus a timestamp
        GET  /sessions/{session_id}/history  conversation history of a session
    """

    # Records unexpected errors (with traceback) before they become HTTP 500s.
    _logger = logging.getLogger(__name__)

    def __init__(self, chat_manager: LegalChatManager):
        """Create the FastAPI app, then register middleware and routes.

        Args:
            chat_manager: Backend that all route handlers delegate to.
        """
        self.app = FastAPI(title="Legal RAG API", version="1.0.0")
        self.chat_manager = chat_manager
        self._setup_middleware()
        self._setup_routes()

    @staticmethod
    def _generate_session_id() -> str:
        """Return a fresh session id for /chat requests that do not supply one."""
        # A random uuid4 replaces the previous wall-clock timestamp. Timestamps
        # can collide for concurrent requests and are easy to guess, which
        # would let one client read another's history through
        # /sessions/{session_id}/history.
        return f"web_{uuid.uuid4().hex}"

    def _setup_middleware(self):
        """Install CORS middleware on the app."""
        # NOTE(review): with a wildcard origin and allow_credentials=True,
        # Starlette's CORSMiddleware echoes back the caller's Origin, so any
        # site can make credentialed cross-origin requests. Restrict
        # allow_origins to known front-ends before relying on cookies or auth.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    def _setup_routes(self):
        """Register the HTTP route handlers on ``self.app``.

        Handlers are closures over ``self``. Route functions use ``#`` comments
        rather than docstrings because FastAPI publishes route docstrings as
        OpenAPI descriptions.
        """

        # Liveness message.
        @self.app.get("/")
        async def root():
            return {"message": "Legal RAG API is running"}

        # Run one chat turn. A missing or empty session_id starts a new session.
        # Unexpected backend errors are logged and returned as HTTP 500 with
        # str(error) as the detail.
        @self.app.post("/chat", response_model=ChatResponse)
        async def chat_endpoint(request: ChatRequest):
            session_id = request.session_id or self._generate_session_id()
            try:
                response = await self.chat_manager.chat(
                    request.query,
                    session_id,
                    request.context,
                )
                session_stats = self.chat_manager.get_session_stats(session_id)
                return ChatResponse(
                    response=response,
                    session_id=session_id,
                    session_stats=session_stats,
                    error=None,
                )
            except HTTPException:
                # Already carries a deliberate status code; don't mask it as 500.
                raise
            except Exception as e:
                self._logger.exception(
                    "Chat request failed (session_id=%r)", session_id
                )
                raise HTTPException(status_code=500, detail=str(e)) from e

        # Report health with the manager's global stats and the local server time.
        @self.app.get("/health", response_model=HealthResponse)
        async def health_check():
            return HealthResponse(
                status="healthy",
                stats=self.chat_manager.get_global_stats(),
                timestamp=datetime.now().isoformat(),
            )

        # Return the stored conversation history for one session.
        @self.app.get("/sessions/{session_id}/history")
        async def get_session_history(session_id: str):
            try:
                history = await self.chat_manager.get_conversation_history(
                    session_id
                )
                return {
                    "session_id": session_id,
                    "message_count": len(history),
                    "messages": history,
                }
            except HTTPException:
                # Preserve statuses chosen by the backend (e.g. a 404).
                raise
            except Exception as e:
                self._logger.exception(
                    "Fetching history failed (session_id=%r)", session_id
                )
                raise HTTPException(status_code=500, detail=str(e)) from e

    def run(self, host: str = "0.0.0.0", port: int = 8000):
        """Serve ``self.app`` with uvicorn; blocks until the server stops.

        Args:
            host: Interface to bind (default: all interfaces).
            port: TCP port to listen on.
        """
        uvicorn.run(self.app, host=host, port=port)