junaid17 commited on
Commit
23413f9
·
verified ·
1 Parent(s): e52ad85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -119
app.py CHANGED
@@ -1,119 +1,138 @@
1
- import os
2
- import shutil
3
- from fastapi.responses import FileResponse
4
- import asyncio
5
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
6
- from fastapi.responses import StreamingResponse
7
- from pydantic import BaseModel
8
- from utils import STT, TTS
9
- from data_ingestion import Ingest_Data
10
- from RAG import app as rag_app, Ragbot_State, reload_vector_store
11
-
12
- # Initialize FastAPI
13
- app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
14
-
15
- # --- Pydantic Models ---
16
- class ChatRequest(BaseModel):
17
- query: str
18
- thread_id: str = "default_user"
19
- use_rag: bool = False
20
- use_web: bool = False
21
- model_name: str = "gpt"
22
-
23
- class TTSRequest(BaseModel):
24
- text: str
25
- voice: str = "en-US-AriaNeural"
26
-
27
-
28
- # --- Endpoints ---
29
-
30
- @app.get("/")
31
- def health_check():
32
- return {"status": "running", "message": "Bot is ready"}
33
-
34
- @app.post("/upload")
35
- async def upload_document(
36
- file: UploadFile = File(...),
37
- background_tasks: BackgroundTasks = BackgroundTasks()
38
- ):
39
- try:
40
- temp_filename = f"temp_{file.filename}"
41
-
42
- with open(temp_filename, "wb") as buffer:
43
- shutil.copyfileobj(file.file, buffer)
44
-
45
- def process_and_reload(path):
46
- try:
47
- result = Ingest_Data(path)
48
- print(f"Ingestion Result: {result}")
49
- reload_vector_store()
50
-
51
- except Exception as e:
52
- print(f"Error processing background task: {e}")
53
- finally:
54
- if os.path.exists(path):
55
- os.remove(path)
56
-
57
- background_tasks.add_task(process_and_reload, temp_filename)
58
-
59
- return {
60
- "message": "File received. Processing started in background.",
61
- "filename": file.filename
62
- }
63
-
64
- except Exception as e:
65
- raise HTTPException(status_code=500, detail=str(e))
66
-
67
-
68
- @app.post("/chat")
69
- async def chat_endpoint(request: ChatRequest):
70
- config = {"configurable": {"thread_id": request.thread_id}}
71
-
72
- inputs = {
73
- "query": request.query,
74
- "RAG": request.use_rag,
75
- "web_search": request.use_web,
76
- "model_name": request.model_name,
77
- "context": [],
78
- "metadata": [],
79
- "web_context": "",
80
- }
81
-
82
- async def event_generator():
83
- async for event in rag_app.astream_events(inputs, config=config, version="v1"):
84
- kind = event["event"]
85
- if kind == "on_chat_model_stream":
86
- content = event["data"]["chunk"].content
87
-
88
- if content:
89
- data = content.replace("\n", "\\n")
90
- yield f"data: {data}\n\n"
91
-
92
- return StreamingResponse(
93
- event_generator(),
94
- media_type="text/event-stream",
95
- headers={
96
- "Cache-Control": "no-cache",
97
- "Connection": "keep-alive",
98
- "X-Accel-Buffering": "no",
99
- },
100
- )
101
-
102
-
103
- # ---------------- STT ---------------- #
104
- @app.post("/stt")
105
- async def transcribe_audio(file: UploadFile = File(...)):
106
- try:
107
- return await STT(file)
108
- except Exception as e:
109
- raise HTTPException(status_code=500, detail=str(e))
110
-
111
- # ---------------- TTS ---------------- #
112
- @app.post("/tts")
113
- async def text_to_speech(req: TTSRequest):
114
- try:
115
- audio_path = await TTS(req.text, req.voice)
116
- return FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
117
-
118
- except Exception as e:
119
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import json
import os
import shutil

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

from data_ingestion import Ingest_Data
from RAG import app as rag_app, Ragbot_State, reload_vector_store
from utils import STT, TTS
11
+
12
# FastAPI application instance; all endpoints below are registered on it.
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
14
+
15
# --- Pydantic Models ---
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    query: str                       # user's question / prompt text
    thread_id: str = "default_user"  # LangGraph thread id (keys the conversation memory)
    use_rag: bool = False            # if True, retrieve context from the vector store
    use_web: bool = False            # if True, augment with web search
    model_name: str = "gpt"          # backend model selector passed through to the graph
22
+
23
class TTSRequest(BaseModel):
    """Request body for POST /tts."""
    text: str                        # text to synthesize
    voice: str = "en-US-AriaNeural"  # edge-style voice name — presumably edge-tts; verify in utils.TTS
26
+
27
+
28
+ # --- Endpoints ---
29
+
30
@app.get("/")
def health_check():
    """Liveness probe: reports that the API is up and serving."""
    payload = {"status": "running", "message": "Bot is ready"}
    return payload
33
+
34
@app.post("/upload")
async def upload_document(
    file: UploadFile = File(...),
    # FastAPI injects BackgroundTasks by annotation; the default is only a
    # fallback for direct (non-HTTP) calls.
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Accept a document upload and ingest it into the vector store.

    The upload is copied to a local temp file, then ingestion and a
    vector-store reload run as a background task after the response is sent.
    The temp file is always removed when the task finishes.

    Returns a JSON acknowledgement; raises HTTPException(500) on failure
    to receive/save the file.
    """
    try:
        # basename() strips any directory components from the client-supplied
        # filename, so a crafted name like "../../etc/x" cannot escape the
        # working directory (path-traversal hardening).
        safe_name = os.path.basename(file.filename or "upload")
        temp_filename = f"temp_{safe_name}"

        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        def process_and_reload(path: str) -> None:
            # Background task: failures are logged (not raised — the HTTP
            # response has already been sent), and the temp file is removed
            # no matter what.
            try:
                result = Ingest_Data(path)
                print(f"Ingestion Result: {result}")
                reload_vector_store()
            except Exception as e:
                print(f"Error processing background task: {e}")
            finally:
                if os.path.exists(path):
                    os.remove(path)

        background_tasks.add_task(process_and_reload, temp_filename)

        return {
            "message": "File received. Processing started in background.",
            "filename": file.filename
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
66
+
67
+
68
# NOTE: the `json` import that used to live here has been consolidated into
# the top-of-file import block (imports belong at module top, PEP 8).
71
+
72
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Stream a chat answer as Server-Sent Events (SSE).

    Each LLM token arrives as `data: {"content": "..."}`. The stream ALWAYS
    terminates with a `data: [DONE]` frame — including on error — so the
    frontend never hangs waiting for more frames. Errors raised mid-stream
    are surfaced as a `data: {"error": "..."}` frame instead of silently
    dropping the connection.
    """
    config = {"configurable": {"thread_id": request.thread_id}}

    inputs = {
        "query": request.query,
        "RAG": request.use_rag,
        "web_search": request.use_web,
        "model_name": request.model_name,
        "context": [],
        "metadata": [],
        "web_context": "",
    }

    async def event_generator():
        print(f"--- Starting stream for {request.thread_id} ---")  # visible in server logs
        try:
            async for event in rag_app.astream_events(inputs, config=config, version="v1"):
                # Only LLM token events are forwarded to the client.
                if event["event"] == "on_chat_model_stream":
                    content = event["data"]["chunk"].content
                    if content:
                        # JSON-wrap the token so newlines/quotes survive SSE framing.
                        yield f"data: {json.dumps({'content': content})}\n\n"
        except Exception as e:
            # Report the failure to the client rather than killing the
            # connection with no explanation.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Terminator frame: clients key off [DONE] to stop reading.
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",  # sets Content-Type; no duplicate header needed
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # CRITICAL: disables HF/Nginx response buffering
        },
    )
120
+
121
+
122
# ---------------- STT ---------------- #
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file via the STT helper."""
    try:
        transcript = await STT(file)
    except Exception as exc:
        # Any failure in the STT pipeline becomes a 500 with its message.
        raise HTTPException(status_code=500, detail=str(exc))
    return transcript
129
+
130
# ---------------- TTS ---------------- #
@app.post("/tts")
async def text_to_speech(req: TTSRequest):
    """Synthesize speech for the given text and return it as an MP3 file."""
    try:
        audio_path = await TTS(req.text, req.voice)
        response = FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
        return response
    except Exception as exc:
        # Any failure in synthesis or file handling becomes a 500.
        raise HTTPException(status_code=500, detail=str(exc))