# cortex/app.py
import os
import shutil

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

from utils import STT, TTS
from data_ingestion import Ingest_Data
from RAG import app as rag_app, Ragbot_State, reload_vector_store
# Initialize FastAPI
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
# --- Pydantic Models ---
class ChatRequest(BaseModel):
query: str
thread_id: str = "default_user"
use_rag: bool = False
use_web: bool = False
model_name: str = "gpt"
class TTSRequest(BaseModel):
text: str
voice: str = "en-US-AriaNeural"
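# Illustrative request bodies for the JSON endpoints below (values are examples only,
# not defaults enforced by the models):
#   POST /chat  {"query": "What does the uploaded report say?", "thread_id": "user_42", "use_rag": true, "model_name": "gpt"}
#   POST /tts   {"text": "Hello there", "voice": "en-US-AriaNeural"}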
# --- Endpoints ---
@app.get("/")
def health_check():
return {"status": "running", "message": "Bot is ready"}
@app.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
):
try:
temp_filename = f"temp_{file.filename}"
with open(temp_filename, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
        def process_and_reload(path):
            # Ingest the saved file, refresh the RAG vector store, then remove the temp file.
            try:
result = Ingest_Data(path)
print(f"Ingestion Result: {result}")
reload_vector_store()
except Exception as e:
print(f"Error processing background task: {e}")
finally:
if os.path.exists(path):
os.remove(path)
background_tasks.add_task(process_and_reload, temp_filename)
return {
"message": "File received. Processing started in background.",
"filename": file.filename
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
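# Sketch of a client call for /upload (assumes the app is served locally on port 7860,
# the usual port for Hugging Face Spaces; not part of the original file):
#   curl -X POST http://localhost:7860/upload -F "file=@report.pdf"
# The response returns immediately; ingestion and the vector-store reload run in the
# background task registered above.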
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Stream chat responses from the LangGraph RAG pipeline as Server-Sent Events (SSE).
    """
    # 1. Set up the graph inputs; thread_id keys the conversation in LangGraph's checkpointer
    config = {"configurable": {"thread_id": request.thread_id}}
inputs = {
"query": request.query,
"RAG": request.use_rag,
"web_search": request.use_web,
"model_name": request.model_name,
"context": [],
"metadata": [],
"web_context": "",
}
    # 2. Define the SSE generator over the graph's event stream
    async def event_generator():
        # Iterate over LangGraph events and forward only the LLM token chunks
        async for event in rag_app.astream_events(inputs, config=config, version="v1"):
            if event["event"] == "on_chat_model_stream":
                chunk = event["data"]["chunk"]
                if chunk and hasattr(chunk, "content"):
                    content = chunk.content
                    if content:
                        # Escape raw newlines so each SSE "data:" line stays intact
                        data = str(content).replace("\n", "\\n")
                        yield f"data: {data}\n\n"
    # 3. Return the stream with headers that prevent proxy buffering
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",   # disable nginx buffering (needed on Hugging Face Spaces)
            "Connection": "keep-alive",  # keep the SSE connection open
},
)
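# Minimal client sketch for consuming the SSE stream above (not part of the original app;
# host/port and the `requests` dependency are assumptions):
#   import requests
#   payload = {"query": "Summarize the document", "use_rag": True}
#   with requests.post("http://localhost:7860/chat", json=payload, stream=True) as r:
#       for line in r.iter_lines(decode_unicode=True):
#           if line.startswith("data: "):
#               print(line[len("data: "):].replace("\\n", "\n"), end="", flush=True)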
# ---------------- STT ---------------- #
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
try:
return await STT(file)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
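# Sketch of an STT call (assumed local host/port and a sample file name, for illustration only):
#   curl -X POST http://localhost:7860/stt -F "file=@recording.wav"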
# ---------------- TTS ---------------- #
@app.post("/tts")
async def text_to_speech(req: TTSRequest):
try:
audio_path = await TTS(req.text, req.voice)
return FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
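# Optional local entrypoint (a minimal sketch, not part of the original file; port 7860 is
# assumed because it is the port Hugging Face Spaces expose by default):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)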