junaid17 commited on
Commit
9e9c87f
·
verified ·
1 Parent(s): 8e579f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -58
app.py CHANGED
@@ -1,16 +1,15 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from tools import update_retriever
3
- from chatbot import app as app_graph, rebuild_graph
4
  from langchain_core.messages import HumanMessage
5
  import os
6
  from fastapi.responses import StreamingResponse, FileResponse
 
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
  from pydantic import BaseModel
9
  from utils import TTS, STT
10
 
11
- # =====================================================
12
- # APP SETUP
13
- # =====================================================
14
 
15
  app = FastAPI()
16
 
@@ -22,80 +21,48 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # =====================================================
26
- # MODELS
27
- # =====================================================
28
-
29
  class TTSRequest(BaseModel):
30
  text: str
31
 
32
 
33
- # ⚠️ HF requires persistent storage under /data
34
- UPLOAD_DIR = "/data/uploads"
35
- os.makedirs(UPLOAD_DIR, exist_ok=True)
36
-
37
- # =====================================================
38
- # HEALTH CHECK
39
- # =====================================================
40
 
41
  @app.get("/")
42
  def health():
43
- return {"Status": "The api is live and running"}
44
-
45
-
46
- # =====================================================
47
- # FILE UPLOAD (RAG)
48
- # =====================================================
49
 
50
  @app.post("/upload")
51
  async def upload_file(file: UploadFile = File(...)):
 
 
52
  file_path = os.path.join(UPLOAD_DIR, file.filename)
53
 
54
  with open(file_path, "wb") as f:
55
  f.write(await file.read())
56
 
57
- # Update vector store
58
  update_retriever(file_path)
59
 
60
- # 🔥 Rebuild LangGraph so RAG becomes active
61
- rebuild_graph()
62
-
63
  return {
64
  "status": "success",
65
  "filename": file.filename
66
  }
67
 
68
 
69
- # =====================================================
70
- # CHAT ENDPOINT (STREAMING)
71
- # =====================================================
72
-
73
  @app.post("/chat")
74
  async def chat(message: str, session_id: str = "default"):
75
 
76
  async def event_generator():
77
- buffer = ""
78
-
79
  async for chunk in app_graph.astream(
80
  {"messages": [HumanMessage(content=message)]},
81
  config={"configurable": {"thread_id": session_id}},
82
  stream_mode="messages"
83
  ):
84
- if not chunk:
85
- continue
86
-
87
- msg = chunk[0] if isinstance(chunk, tuple) else chunk
88
-
89
- if hasattr(msg, "content") and msg.content:
90
- buffer += msg.content
91
-
92
- # Flush every ~150 characters (prevents broken tokens)
93
- if len(buffer) > 150:
94
- yield f"data: {buffer.strip()}\n\n"
95
- buffer = ""
96
-
97
- if buffer:
98
- yield f"data: {buffer.strip()}\n\n"
99
 
100
  return StreamingResponse(
101
  event_generator(),
@@ -104,14 +71,9 @@ async def chat(message: str, session_id: str = "default"):
104
  "Cache-Control": "no-cache",
105
  "Connection": "keep-alive",
106
  "X-Accel-Buffering": "no",
107
- "Access-Control-Allow-Origin": "*",
108
  },
109
  )
110
-
111
-
112
- # =====================================================
113
- # STT
114
- # =====================================================
115
 
116
  @app.post("/stt")
117
  async def transcribe_audio(file: UploadFile = File(...)):
@@ -121,10 +83,6 @@ async def transcribe_audio(file: UploadFile = File(...)):
121
  raise HTTPException(status_code=500, detail=str(e))
122
 
123
 
124
- # =====================================================
125
- # TTS
126
- # =====================================================
127
-
128
  @app.post("/tts")
129
  async def generate_tts(request: TTSRequest):
130
  try:
@@ -143,4 +101,4 @@ async def generate_tts(request: TTSRequest):
143
  )
144
 
145
  except Exception as e:
146
- raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from tools import create_rag_tool, update_retriever
3
+ from chatbot import app as app_graph
4
  from langchain_core.messages import HumanMessage
5
  import os
6
  from fastapi.responses import StreamingResponse, FileResponse
7
+ from langchain_core.messages import AIMessage
8
  from fastapi.middleware.cors import CORSMiddleware
9
+ import asyncio
10
  from pydantic import BaseModel
11
  from utils import TTS, STT
12
 
 
 
 
13
 
14
  app = FastAPI()
15
 
 
21
  allow_headers=["*"],
22
  )
23
 
 
 
 
 
24
  class TTSRequest(BaseModel):
25
  text: str
26
 
27
 
28
+ UPLOAD_DIR = "uploads"
 
 
 
 
 
 
29
 
30
  @app.get("/")
31
  def health():
32
+ return {'Status' : 'The api is live and running'}
 
 
 
 
 
33
 
34
  @app.post("/upload")
35
  async def upload_file(file: UploadFile = File(...)):
36
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
37
+
38
  file_path = os.path.join(UPLOAD_DIR, file.filename)
39
 
40
  with open(file_path, "wb") as f:
41
  f.write(await file.read())
42
 
 
43
  update_retriever(file_path)
44
 
 
 
 
45
  return {
46
  "status": "success",
47
  "filename": file.filename
48
  }
49
 
50
 
 
 
 
 
51
  @app.post("/chat")
52
  async def chat(message: str, session_id: str = "default"):
53
 
54
  async def event_generator():
 
 
55
  async for chunk in app_graph.astream(
56
  {"messages": [HumanMessage(content=message)]},
57
  config={"configurable": {"thread_id": session_id}},
58
  stream_mode="messages"
59
  ):
60
+ if len(chunk) >= 1:
61
+ message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
62
+ if hasattr(message_chunk, 'content') and message_chunk.content:
63
+ data = str(message_chunk.content).replace("\n", "\\n")
64
+ yield f"data: {data}\n\n"
65
+ await asyncio.sleep(0.01)
 
 
 
 
 
 
 
 
 
66
 
67
  return StreamingResponse(
68
  event_generator(),
 
71
  "Cache-Control": "no-cache",
72
  "Connection": "keep-alive",
73
  "X-Accel-Buffering": "no",
 
74
  },
75
  )
76
+ # ---------------- STT ---------------- #
 
 
 
 
77
 
78
  @app.post("/stt")
79
  async def transcribe_audio(file: UploadFile = File(...)):
 
83
  raise HTTPException(status_code=500, detail=str(e))
84
 
85
 
 
 
 
 
86
  @app.post("/tts")
87
  async def generate_tts(request: TTSRequest):
88
  try:
 
101
  )
102
 
103
  except Exception as e:
104
+ raise HTTPException(status_code=500, detail=str(e))