# cortex/app.py — FastAPI entry point for the LangGraph RAG chatbot (Hugging Face Space).
import asyncio
import json
import os
import shutil
import uuid

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from utils import STT, TTS
from data_ingestion import Ingest_Data
from RAG import app as rag_app, Ragbot_State, reload_vector_store
# Initialize FastAPI
# Main ASGI application instance; all endpoints below are registered on it.
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
# --- Pydantic Models ---
class ChatRequest(BaseModel):
    """Request body for the /chat streaming endpoint."""

    query: str  # The user's question / prompt text.
    thread_id: str = "default_user"  # Conversation id, used as the LangGraph thread key.
    use_rag: bool = False  # When True, retrieve context from the local vector store.
    use_web: bool = False  # When True, augment with a web search.
    model_name: str = "gpt"  # LLM backend selector — interpreted by the RAG graph (TODO confirm accepted values).
class TTSRequest(BaseModel):
    """Request body for the /tts endpoint."""

    text: str  # Text to synthesize into speech.
    voice: str = "en-US-AriaNeural"  # Voice name, presumably an edge-tts voice id — verify against TTS helper.
# --- Endpoints ---
@app.get("/")
def health_check():
    """Liveness probe: report that the service is up and ready."""
    payload = {"status": "running", "message": "Bot is ready"}
    return payload
@app.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
):
    """Accept a document upload and ingest it into the vector store.

    The upload is copied to a local temp file, then ingestion and the
    vector-store reload run as a background task so the request returns
    immediately.

    Raises:
        HTTPException: 500 if saving the uploaded file fails.
    """
    try:
        # SECURITY: file.filename is client-controlled. Strip any directory
        # components to prevent path traversal, and add a uuid so concurrent
        # uploads of the same filename cannot collide.
        safe_name = os.path.basename(file.filename or "upload")
        temp_filename = f"temp_{uuid.uuid4().hex}_{safe_name}"
        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        def process_and_reload(path):
            # Runs after the response is sent. Errors are logged (best-effort),
            # and the temp file is always removed.
            try:
                result = Ingest_Data(path)
                print(f"Ingestion Result: {result}")
                reload_vector_store()
            except Exception as e:
                print(f"Error processing background task: {e}")
            finally:
                if os.path.exists(path):
                    os.remove(path)

        # Canonical FastAPI pattern: BackgroundTasks is injected per-request
        # (the old `= BackgroundTasks()` default instance is the documented
        # antipattern — tasks must go on the request-scoped object).
        background_tasks.add_task(process_and_reload, temp_filename)
        return {
            "message": "File received. Processing started in background.",
            "filename": file.filename,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ---------------- Chat ---------------- #
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Robust Streaming Endpoint that logs events to console.
    """
    config = {"configurable": {"thread_id": request.thread_id}}
    inputs = {
        "query": request.query,
        "RAG": request.use_rag,
        "web_search": request.use_web,
        "model_name": request.model_name,
        "context": [],
        "metadata": [],
        "web_context": "",
    }

    async def event_generator():
        print(f"--- 🚀 Starting stream for {request.thread_id} ---")
        # version='v1' keeps the event schema compatible across LangGraph releases.
        async for ev in rag_app.astream_events(inputs, config=config, version="v1"):
            # We forward any event whose payload carries a text chunk —
            # the event name itself is irrelevant here.
            payload = ev.get("data", {}).get("chunk")
            if not payload or not hasattr(payload, "content"):
                continue
            text = payload.content
            # Skip empty chunks and whitespace-only artifacts.
            if not text or text.strip() == "":
                continue
            # SSE frame: JSON-encoded content per data line.
            yield f"data: {json.dumps({'content': text})}\n\n"
        # Sentinel frame so the client knows the stream has finished.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
            "X-Accel-Buffering": "no",
        },
    )
# ---------------- STT ---------------- #
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file using the STT helper."""
    try:
        transcription = await STT(file)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return transcription
# ---------------- TTS ---------------- #
@app.post("/tts")
async def text_to_speech(req: TTSRequest):
    """Synthesize speech for the given text and return it as an MP3 file."""
    try:
        audio_path = await TTS(req.text, req.voice)
        response = FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return response