File size: 2,325 Bytes
745c08b
 
5c5b559
745c08b
5c5b559
 
745c08b
5c5b559
 
 
 
745c08b
 
 
 
 
 
 
 
 
 
 
5c5b559
 
 
 
745c08b
 
 
 
 
 
5c5b559
 
745c08b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c5b559
 
 
 
9e9c87f
745c08b
 
 
 
 
 
 
 
 
 
5c5b559
745c08b
 
 
5c5b559
745c08b
 
 
 
5c5b559
745c08b
5c5b559
 
745c08b
5c5b559
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
import os
import asyncio

from chatbot import app as app_graph
from langchain_core.messages import HumanMessage
from tools import update_retriever
from utils import TTS, STT

# Application instance that every route in this module registers on.
app = FastAPI()

# CORS policy: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests depending on how the middleware handles
# "*"; confirm this is intended outside local development.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Directory that /upload writes incoming files into; created at import time so
# the first upload does not fail on a missing directory.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)


class TTSRequest(BaseModel):
    """JSON request body accepted by the POST /tts endpoint."""

    # Text that is handed to TTS() to be turned into audio.
    text: str


@app.get("/")
def health():
    """Report that the API process is up and serving requests."""
    status_payload = {"status": "API is running"}
    return status_payload


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded file into ``UPLOAD_DIR`` and register it with the retriever.

    Args:
        file: The multipart upload sent by the client.

    Returns:
        ``{"status": "success", "filename": <stored name>}``.

    Raises:
        HTTPException: 400 if the client-supplied filename is missing or does
            not reduce to a usable file name.
    """
    # ``file.filename`` is client-controlled. Keep only its final path component
    # so names such as "../../etc/passwd" cannot escape UPLOAD_DIR, and treat a
    # missing name as an empty string rather than crashing in os.path.join.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    file_path = os.path.join(UPLOAD_DIR, safe_name)

    # The whole upload is read into memory before being written out.
    with open(file_path, "wb") as f:
        f.write(await file.read())

    # NOTE(review): update_retriever is called without await, so it runs on the
    # event-loop thread; if indexing is slow, consider offloading it.
    update_retriever(file_path)

    return {
        "status": "success",
        "filename": safe_name,
    }


def _sse_data(text: str) -> str:
    """Frame *text* as a single Server-Sent Events message.

    In the SSE wire format a line break ends a field and a blank line ends the
    event. Each line of a multi-line payload therefore needs its own ``data:``
    prefix; compliant clients rejoin those lines with ``\\n``. For text
    without line breaks this yields exactly ``"data: <text>\\n\\n"``.
    """
    # SSE treats \r\n, \r and \n all as line terminators; normalise first.
    normalized = text.replace("\r\n", "\n").replace("\r", "\n")
    return "".join(f"data: {line}\n" for line in normalized.split("\n")) + "\n"


@app.post("/chat")
async def chat(message: str, session_id: str = "default"):
    """Stream the chatbot graph's reply to *message* as Server-Sent Events.

    Args:
        message: The user's message, wrapped in a ``HumanMessage`` for the graph.
        session_id: Passed to the graph as ``thread_id`` in its configurable
            config, so separate ids are handled as separate threads.

    Returns:
        A ``text/event-stream`` StreamingResponse with one SSE message per
        non-empty content chunk produced by the graph.
    """

    async def event_generator():
        async for chunk in app_graph.astream(
            {"messages": [HumanMessage(content=message)]},
            config={"configurable": {"thread_id": session_id}},
            stream_mode="messages"
        ):
            if chunk:
                # Tuples are unpacked to their first element; anything else is
                # treated as the message object itself.
                msg = chunk[0] if isinstance(chunk, tuple) else chunk
                if hasattr(msg, "content") and msg.content:
                    # str() mirrors the original f-string conversion of content.
                    yield _sse_data(str(msg.content))
                    # Brief yield to the event loop between chunks.
                    await asyncio.sleep(0.01)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Asks nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )


@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    The upload is handed unchanged to ``STT``, and whatever it resolves to is
    returned as the response.
    """
    transcription = await STT(file)
    return transcription


@app.post("/tts")
async def generate_tts(request: TTSRequest):
    """Synthesize ``request.text`` with ``TTS`` and return the audio file.

    Args:
        request: Body containing the text to synthesize.

    Returns:
        A FileResponse for the generated file, served as ``audio/mpeg`` with
        download name ``speech.mp3``.

    Raises:
        HTTPException: 500 if ``TTS`` returns no path or a path that is not an
            existing regular file.
    """
    audio_path = await TTS(text=request.text)

    # Check for a falsy path before touching the filesystem: os.path.exists(None)
    # is not guaranteed to return False (on POSIX it raises TypeError), which
    # would bypass this explicit 500. isfile() also rejects directories, which
    # FileResponse cannot serve.
    if not audio_path or not os.path.isfile(audio_path):
        raise HTTPException(status_code=500, detail="Audio not generated")

    # NOTE(review): media type assumes TTS writes MP3 data — confirm in utils.TTS.
    return FileResponse(
        path=audio_path,
        media_type="audio/mpeg",
        filename="speech.mp3"
    )