Spaces:
Sleeping
Sleeping
File size: 2,325 Bytes
745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 9e9c87f 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 745c08b 5c5b559 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
import os
import asyncio
from chatbot import app as app_graph
from langchain_core.messages import HumanMessage
from tools import update_retriever
from utils import TTS, STT
app = FastAPI()

# Permissive CORS: every origin, method and header is accepted, and
# credentialed cross-origin requests are allowed.
# NOTE(review): the FastAPI/Starlette CORS docs warn against pairing
# allow_credentials=True with wildcard origins — confirm this is intended
# beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Directory (relative to the process's working directory) that receives
# uploaded documents; created at import time when it does not exist yet.
UPLOAD_DIR = "uploads"
os.makedirs(name=UPLOAD_DIR, exist_ok=True)
class TTSRequest(BaseModel):
    """Request body for ``POST /tts``."""

    text: str  # text handed to the TTS helper for synthesis
@app.get("/")
def health():
    """Report that the API process is up; served at the root path."""
    payload = {"status": "API is running"}
    return payload
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded document under ``UPLOAD_DIR`` and pass it to the retriever.

    Only the final path component of the client-supplied filename is kept,
    so names such as ``../../etc/passwd`` or ``/tmp/x`` cannot write outside
    ``UPLOAD_DIR``. After the file is written, ``update_retriever`` is called
    with its path.

    Returns:
        ``{"status": "success", "filename": <stored file name>}``.

    Raises:
        HTTPException: 400 if the upload carries no usable filename.
    """
    # The filename is client-controlled. Treat "\" as a separator too, so
    # Windows-style names are reduced to their base name on every platform.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as out:
        # Copy in 1 MiB chunks instead of buffering the whole upload in memory.
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            out.write(chunk)

    update_retriever(file_path)
    return {
        "status": "success",
        "filename": safe_name,
    }
def _format_sse(text: str) -> str:
    """Encode *text* as a single Server-Sent Events message.

    In the SSE wire format, ``\\r\\n``, ``\\r`` and ``\\n`` each end a field
    and a blank line ends the event. Every line of *text* therefore gets its
    own ``data:`` field, and a conforming client re-joins them with ``\\n``.
    Single-line text produces exactly ``"data: <text>\\n\\n"``.
    """
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    fields = "".join(f"data: {line}\n" for line in normalized.split("\n"))
    return fields + "\n"


@app.post("/chat")
async def chat(message: str, session_id: str = "default"):
    """Stream the chatbot's reply to *message* as Server-Sent Events.

    Args:
        message: The user's message; sent to the graph as a ``HumanMessage``.
        session_id: Passed to the graph as the ``thread_id`` in its config
            (presumably keys per-conversation state in ``chatbot`` — not
            visible here).

    Returns:
        A ``text/event-stream`` response. It carries one SSE event per chunk
        from ``app_graph.astream`` whose ``content`` is non-empty.
    """

    async def event_generator():
        async for chunk in app_graph.astream(
            {"messages": [HumanMessage(content=message)]},
            config={"configurable": {"thread_id": session_id}},
            stream_mode="messages",
        ):
            if not chunk:
                continue
            # Chunks may arrive as tuples whose first element is the message.
            msg = chunk[0] if isinstance(chunk, tuple) else chunk
            content = getattr(msg, "content", None)
            if content:
                # str() matches the f-string conversion used previously for
                # non-string content. _format_sse keeps embedded newlines
                # from breaking the SSE framing.
                yield _format_sse(str(content))
                # Brief pause between events, kept from the original code.
                await asyncio.sleep(0.01)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Asks nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    The upload is handed directly to the project's ``STT`` coroutine, and its
    result is returned unchanged as the response.
    """
    transcription = await STT(file)
    return transcription
@app.post("/tts")
async def generate_tts(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as ``speech.mp3``.

    Raises:
        HTTPException: 500 when ``TTS`` does not produce a path to an
            existing regular file.
    """
    audio_path = await TTS(text=request.text)
    # Check for a falsy result first: os.path.exists(None) raises TypeError,
    # which would surface as an unhandled error instead of this response.
    # isfile() also rejects directories, which FileResponse cannot send.
    if not audio_path or not os.path.isfile(audio_path):
        raise HTTPException(status_code=500, detail="Audio not generated")
    # NOTE(review): the generated file is never deleted here; confirm whether
    # TTS reuses a fixed path or whether temp files accumulate.
    return FileResponse(
        path=audio_path,
        media_type="audio/mpeg",
        filename="speech.mp3",
    )
|