junaid17 commited on
Commit
7a65abf
·
verified ·
1 Parent(s): 9233a61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -119
app.py CHANGED
@@ -1,119 +1,135 @@
1
- import os
2
- import shutil
3
- from fastapi.responses import FileResponse
4
- import asyncio
5
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
6
- from fastapi.responses import StreamingResponse
7
- from pydantic import BaseModel
8
- from utils import STT, TTS
9
- from data_ingestion import Ingest_Data
10
- from RAG import app as rag_app, Ragbot_State, reload_vector_store
11
-
12
- # Initialize FastAPI
13
- app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
14
-
15
- # --- Pydantic Models ---
16
- class ChatRequest(BaseModel):
17
- query: str
18
- thread_id: str = "default_user"
19
- use_rag: bool = False
20
- use_web: bool = False
21
- model_name: str = "gpt"
22
-
23
- class TTSRequest(BaseModel):
24
- text: str
25
- voice: str = "en-US-AriaNeural"
26
-
27
-
28
- # --- Endpoints ---
29
-
30
- @app.get("/")
31
- def health_check():
32
- return {"status": "running", "message": "Bot is ready"}
33
-
34
- @app.post("/upload")
35
- async def upload_document(
36
- file: UploadFile = File(...),
37
- background_tasks: BackgroundTasks = BackgroundTasks()
38
- ):
39
- try:
40
- temp_filename = f"temp_{file.filename}"
41
-
42
- with open(temp_filename, "wb") as buffer:
43
- shutil.copyfileobj(file.file, buffer)
44
-
45
- def process_and_reload(path):
46
- try:
47
- result = Ingest_Data(path)
48
- print(f"Ingestion Result: {result}")
49
- reload_vector_store()
50
-
51
- except Exception as e:
52
- print(f"Error processing background task: {e}")
53
- finally:
54
- if os.path.exists(path):
55
- os.remove(path)
56
-
57
- background_tasks.add_task(process_and_reload, temp_filename)
58
-
59
- return {
60
- "message": "File received. Processing started in background.",
61
- "filename": file.filename
62
- }
63
-
64
- except Exception as e:
65
- raise HTTPException(status_code=500, detail=str(e))
66
-
67
-
68
- @app.post("/chat")
69
- async def chat_endpoint(request: ChatRequest):
70
- config = {"configurable": {"thread_id": request.thread_id}}
71
-
72
- inputs = {
73
- "query": request.query,
74
- "RAG": request.use_rag,
75
- "web_search": request.use_web,
76
- "model_name": request.model_name,
77
- "context": [],
78
- "metadata": [],
79
- "web_context": "",
80
- }
81
-
82
- async def event_generator():
83
- async for event in rag_app.astream_events(inputs, config=config, version="v1"):
84
- kind = event["event"]
85
- if kind == "on_chat_model_stream":
86
- content = event["data"]["chunk"].content
87
-
88
- if content:
89
- data = content.replace("\n", "\\n")
90
- yield f"data: {data}\n\n"
91
-
92
- return StreamingResponse(
93
- event_generator(),
94
- media_type="text/event-stream",
95
- headers={
96
- "Cache-Control": "no-cache",
97
- "Connection": "keep-alive",
98
- "X-Accel-Buffering": "no",
99
- },
100
- )
101
-
102
-
103
- # ---------------- STT ---------------- #
104
- @app.post("/stt")
105
- async def transcribe_audio(file: UploadFile = File(...)):
106
- try:
107
- return await STT(file)
108
- except Exception as e:
109
- raise HTTPException(status_code=500, detail=str(e))
110
-
111
- # ---------------- TTS ---------------- #
112
- @app.post("/tts")
113
- async def text_to_speech(req: TTSRequest):
114
- try:
115
- audio_path = await TTS(req.text, req.voice)
116
- return FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
117
-
118
- except Exception as e:
119
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from fastapi.responses import FileResponse
4
+ import asyncio
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
6
+ from fastapi.responses import StreamingResponse
7
+ from pydantic import BaseModel
8
+ from utils import STT, TTS
9
+ from data_ingestion import Ingest_Data
10
+ from RAG import app as rag_app, Ragbot_State, reload_vector_store
11
+
12
# Application instance — every endpoint decorator below registers against it.
app = FastAPI(title="LangGraph RAG Chatbot", version="1.0")
14
+
15
+ # --- Pydantic Models ---
16
class ChatRequest(BaseModel):
    """Request payload for the /chat endpoint."""

    query: str                       # user's question / prompt text
    thread_id: str = "default_user"  # used as the LangGraph thread_id in /chat config
    use_rag: bool = False            # forwarded to the graph as the "RAG" flag
    use_web: bool = False            # forwarded to the graph as the "web_search" flag
    model_name: str = "gpt"          # forwarded to the graph as "model_name"
22
+
23
class TTSRequest(BaseModel):
    """Request payload for the /tts endpoint."""

    text: str                        # text to synthesize
    voice: str = "en-US-AriaNeural"  # voice identifier, passed straight through to TTS()
26
+
27
+
28
+ # --- Endpoints ---
29
+
30
@app.get("/")
def health_check():
    """Liveness probe: report that the service is up and ready."""
    status_payload = {"status": "running", "message": "Bot is ready"}
    return status_payload
33
+
34
@app.post("/upload")
async def upload_document(
    file: UploadFile = File(...),
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Accept a document upload and ingest it into the vector store.

    The upload is written to a local temp file, then ingestion and the
    vector-store reload run as a background task so the request returns
    immediately.

    Returns:
        dict with a status message and the original filename.

    Raises:
        HTTPException(500): if saving the upload fails.
    """
    try:
        # basename() strips any client-supplied directory components,
        # preventing path traversal via a crafted filename (e.g. "../../x").
        safe_name = os.path.basename(file.filename or "upload")
        temp_filename = f"temp_{safe_name}"

        with open(temp_filename, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        def process_and_reload(path: str) -> None:
            """Ingest the saved file, refresh the vector store, then clean up."""
            try:
                result = Ingest_Data(path)
                print(f"Ingestion Result: {result}")
                reload_vector_store()
            except Exception as e:
                # Best-effort background work: log and continue, never crash the app.
                print(f"Error processing background task: {e}")
            finally:
                # Always remove the temp file, even if ingestion failed.
                if os.path.exists(path):
                    os.remove(path)

        background_tasks.add_task(process_and_reload, temp_filename)

        return {
            "message": "File received. Processing started in background.",
            "filename": file.filename
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
66
+
67
+
68
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Stream the chatbot's answer as Server-Sent Events.

    Runs the LangGraph pipeline for the requested thread and forwards each
    LLM token chunk to the client as one SSE `data:` frame.
    """
    graph_config = {"configurable": {"thread_id": request.thread_id}}

    graph_inputs = {
        "query": request.query,
        "RAG": request.use_rag,
        "web_search": request.use_web,
        "model_name": request.model_name,
        "context": [],
        "metadata": [],
        "web_context": "",
    }

    async def sse_stream():
        # Only LLM token events are forwarded; every other graph event is skipped.
        async for ev in rag_app.astream_events(graph_inputs, config=graph_config, version="v1"):
            if ev["event"] != "on_chat_model_stream":
                continue
            chunk = ev["data"]["chunk"]
            if not (chunk and hasattr(chunk, "content")):
                continue
            token_text = chunk.content
            if not token_text:
                continue
            # Escape raw newlines so each SSE frame stays a single data line.
            escaped = str(token_text).replace("\n", "\\n")
            yield f"data: {escaped}\n\n"

    return StreamingResponse(
        sse_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",   # disable proxy buffering so tokens flush immediately
            "Connection": "keep-alive",
        },
    )
117
+
118
+
119
# ---------------- STT ---------------- #
@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file via the STT helper.

    Raises:
        HTTPException(500): if transcription fails for any reason.
    """
    try:
        transcription = await STT(file)
        return transcription
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
126
+
127
# ---------------- TTS ---------------- #
@app.post("/tts")
async def text_to_speech(req: TTSRequest):
    """Synthesize speech for the given text and return it as an MP3 file.

    Raises:
        HTTPException(500): if synthesis fails.
    """
    try:
        audio_path = await TTS(req.text, req.voice)
        response = FileResponse(audio_path, media_type="audio/mpeg", filename="output.mp3")
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))