# Spaces:
# Sleeping
# Sleeping
import asyncio
import os

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain.chat_models import init_chat_model
from orator import Session, SQLDatabase, DocumentDatabase
from pydantic import BaseModel
# Initialize FastAPI app
app = FastAPI(title="Orator Chat API")

# Allowed CORS origins, comma-separated, read from ORATOR_CORS_ORIGINS.
# Defaults to "*" (any origin), which matches the previous hard-coded value.
# NOTE(review): FastAPI's CORS docs advise listing explicit origins when
# allow_credentials=True; set ORATOR_CORS_ORIGINS in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("ORATOR_CORS_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)

# Initialize LLM and databases.
# The model name and data locations can be overridden with environment
# variables. The defaults are the original hard-coded values, which point
# into a single developer's home directory.
llm = init_chat_model(
    os.environ.get("ORATOR_LLM_MODEL", "o3-mini"),
    model_provider="openai",
)
chinook_db = SQLDatabase.from_uri(
    os.environ.get(
        "ORATOR_CHINOOK_DB_URI",
        "sqlite:////home/geetu/work/orator/data/chinook/Chinook.db",
    )
)
pricegram_db = DocumentDatabase(
    os.environ.get(
        "ORATOR_PRICEGRAM_DATA",
        "/home/geetu/work/orator/data/pricegram/data.json",
    ),
    top_k=10,
)

# Initialize session shared by all request handlers.
# NOTE(review): the order of `datasources` presumably determines which
# integer `QueryRequest.source` selects; confirm against orator.Session.
session = Session(llm=llm, datasources=[chinook_db, pricegram_db])
| # Pydantic model for request | |
| class QueryRequest(BaseModel): | |
| query: str | |
| source: int | |
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is visible for this
# coroutine. As shown, it is never registered on `app`; confirm the decorator
# exists in the deployed file.
async def get_response(request: QueryRequest):
    """Process a query with the shared session and return its answer.

    Args:
        request: The query text, plus the `source` value that is forwarded
            to `session.invoke` as `datasource`.

    Returns:
        A dict of the form `{"response": <answer>}`.

    Raises:
        HTTPException: Status 500, with the underlying error message as
            `detail`, if `session.invoke` raises.
    """
    print("Got Request:", request)
    try:
        # `invoke` yields a 2-tuple. The second item (logs) is not returned
        # to the client.
        # NOTE(review): this call is not awaited, so any work it does
        # synchronously runs on the event-loop thread. That stalls other
        # requests until it returns.
        answer, _logs = session.invoke(request.query, datasource=request.source)
    except Exception as e:
        # Chain the cause so the server-side traceback keeps the real error.
        # NOTE(review): str(e) is sent to the client and may expose internal
        # details (SQL, paths); consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
    payload = {"response": answer}
    print("Sending Response:", payload)
    return payload
# NOTE(review): no route decorator is visible for this coroutine. As shown, it
# is never registered on `app`; confirm the decorator exists in the deployed
# file.
async def stream_response(request: QueryRequest):
    """Stream the session's intermediate messages for a query as plain text.

    Every event from `session.stream` is treated as a mapping. For each
    entry, one line `"<key>: <text of last message>\n"` is emitted,
    followed by a 0.1 s pause.

    If anything fails while streaming, the failure is reported in-band as a
    final `"Error: <message>"` chunk rather than as an HTTP error status.

    NOTE(review): unlike `get_response`, `request.source` is not forwarded
    here. Check whether `session.stream` accepts a datasource argument.
    """
    async def _lines():
        try:
            for event in session.stream(request.query):
                for node, update in event.items():
                    latest = update['messages'][-1]
                    yield f"{node}: {latest.text}\n"
                    # Deliberate pause between chunks ("simulated" streaming delay).
                    await asyncio.sleep(0.1)
        except Exception as exc:
            yield f"Error: {str(exc)}"

    return StreamingResponse(_lines(), media_type="text/plain")
# NOTE(review): no route decorator is visible for this coroutine. As shown, it
# is never registered on `app`.
async def root():
    """Return the static welcome payload for the API root."""
    greeting = "Welcome to the Orator Chat API"
    return dict(message=greeting)