File size: 3,645 Bytes
7a65abf
 
 
 
 
 
 
 
 
84ef2a8
7a65abf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648fa9d
 
7a65abf
648fa9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a65abf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import asyncio
import os
import shutil
import uuid

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from utils import STT, TTS
from data_ingestion import Ingest_Data
from RAG import app as rag_app, Ragbot_State, reload_vector_store

# Create the FastAPI application object; the route decorators in this module
# register their handlers on it.
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")

# --- Pydantic Models ---
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint."""

    # User message; forwarded to the graph under the "query" key.
    query: str
    # Conversation identifier; passed to the graph as the configurable
    # "thread_id" and echoed back in the response.
    thread_id: str = "default_user"
    # Forwarded to the graph under the "RAG" key.
    use_rag: bool = False
    # Forwarded to the graph under the "web_search" key.
    use_web: bool = False
    # Forwarded to the graph under the "model_name" key.
    # NOTE(review): the set of accepted model names is defined by the graph,
    # not in this module - confirm valid values there.
    model_name: str = "gpt"

class TTSRequest(BaseModel):
    """JSON body accepted by the ``/tts`` endpoint."""

    # Text handed to ``TTS`` as its first argument.
    text: str
    # Voice identifier handed to ``TTS`` as its second argument.
    voice: str = "en-US-AriaNeural"


# --- Endpoints ---

@app.get("/")
def health_check():
    """Report that the service is up with a fixed status payload."""
    status_payload = dict(status="running", message="Bot is ready")
    return status_payload

@app.post("/upload")
async def upload_document(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Save an uploaded file to disk and schedule its ingestion.

    The upload is copied to a uniquely named temporary file in the current
    working directory. A background task then passes that path to
    ``Ingest_Data``, calls ``reload_vector_store`` and finally deletes the
    temporary file, whether or not ingestion succeeded.

    Args:
        file: The uploaded document (multipart form field ``file``).
        background_tasks: Queue the ingestion task is added to.
            NOTE(review): the default instance is created once at import
            time; FastAPI is expected to inject a per-request instance
            instead - confirm before calling this function directly.

    Returns:
        A dict with a status ``message`` and the client-supplied ``filename``.

    Raises:
        HTTPException: 500 if the upload cannot be written to disk.
    """
    # The filename is client-controlled: keep only its last path component
    # (treating backslashes as separators too) so it cannot point outside the
    # working directory, and fall back to a fixed name when it is missing.
    safe_name = (
        os.path.basename((file.filename or "").replace("\\", "/")) or "upload"
    )
    # A random prefix keeps concurrent uploads of the same filename from
    # overwriting or deleting each other's temporary file. The original name
    # (and therefore its extension) is kept at the end of the path.
    temp_filename = f"temp_{uuid.uuid4().hex}_{safe_name}"

    try:
        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        # Best-effort removal of a partially written file; the original
        # error is what gets reported to the client.
        try:
            os.remove(temp_filename)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=str(e)) from e

    def process_and_reload(path):
        """Ingest the file at ``path``, reload the vector store, then delete it."""
        try:
            result = Ingest_Data(path)
            print(f"Ingestion Result: {result}")
            reload_vector_store()

        except Exception as e:
            # Runs after the response has been sent, so the failure can only
            # be logged, not returned to the client.
            print(f"Error processing background task: {e}")
        finally:
            if os.path.exists(path):
                os.remove(path)

    background_tasks.add_task(process_and_reload, temp_filename)

    return {
        "message": "File received. Processing started in background.",
        "filename": file.filename
    }


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Non-streaming chat endpoint.

    Runs the graph to completion for the request and replies with a JSON
    object holding the content of the last entry of the graph's ``response``
    list together with the request's thread id. Any failure is logged and
    reported as an HTTP 500 whose detail is the exception text.
    """
    try:
        # Initial graph state: the request fields plus empty context slots.
        graph_state = dict(
            query=request.query,
            RAG=request.use_rag,
            web_search=request.use_web,
            model_name=request.model_name,
            context=[],
            metadata=[],
            web_context="",
        )
        run_config = {"configurable": {"thread_id": request.thread_id}}

        # Awaiting ainvoke keeps the event loop free while the graph runs.
        final_state = await rag_app.ainvoke(graph_state, config=run_config)

        # The reply is the newest entry in the "response" list.
        answer = final_state['response'][-1]
        return {"response": answer.content, "thread_id": request.thread_id}

    except Exception as err:
        print(f"Error generation response: {err}")
        raise HTTPException(status_code=500, detail=str(err))


# ---------------- STT ---------------- #
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Pass the uploaded file to ``STT`` and return whatever it produces.

    Any exception raised by ``STT`` is reported as an HTTP 500 whose detail
    is the exception text.
    """
    try:
        transcription = await STT(file)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return transcription

# ---------------- TTS ---------------- #
@app.post("/tts")
async def text_to_speech(req: TTSRequest):
    """Synthesize ``req.text`` with ``req.voice`` and return the audio file.

    The path returned by ``TTS`` is served as ``audio/mpeg`` under the
    download name ``output.mp3``. Any exception is reported as an HTTP 500
    whose detail is the exception text.
    """
    try:
        speech_file = await TTS(req.text, req.voice)
        audio_response = FileResponse(
            speech_file,
            media_type="audio/mpeg",
            filename="output.mp3",
        )
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return audio_response