junaid17 commited on
Commit
de58f11
·
verified ·
1 Parent(s): d1f2f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -5,11 +5,12 @@ from langchain_core.messages import HumanMessage
5
  from chatbot import app as app_graph
6
  from tools import update_retriever
7
  from utils import STT, TTS
8
- import asyncio
9
  import os
 
10
 
11
  app = FastAPI()
12
 
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
@@ -24,25 +25,34 @@ os.makedirs(UPLOAD_DIR, exist_ok=True)
24
 
25
  @app.get("/")
26
  def health():
27
- return {"status": "running"}
28
 
29
 
 
 
 
30
  @app.post("/upload")
31
  async def upload_file(file: UploadFile = File(...)):
32
- path = os.path.join(UPLOAD_DIR, file.filename)
33
 
34
- with open(path, "wb") as f:
35
  f.write(await file.read())
36
 
37
- update_retriever(path)
38
 
39
- return {"status": "uploaded", "file": file.filename}
 
 
 
40
 
41
 
 
 
 
42
  @app.post("/chat")
43
  async def chat(message: str, session_id: str = "default"):
44
 
45
- async def stream():
46
  async for chunk in app_graph.astream(
47
  {"messages": [HumanMessage(content=message)]},
48
  config={"configurable": {"thread_id": session_id}},
@@ -52,15 +62,21 @@ async def chat(message: str, session_id: str = "default"):
52
  if hasattr(msg, "content") and msg.content:
53
  yield msg.content + "\n"
54
 
55
- return StreamingResponse(stream(), media_type="text/plain")
56
 
57
 
 
 
 
58
  @app.post("/stt")
59
  async def stt(file: UploadFile = File(...)):
60
  return await STT(file)
61
 
62
 
 
 
 
63
  @app.post("/tts")
64
  async def tts(text: str):
65
- path = await TTS(text)
66
- return FileResponse(path)
 
5
  from chatbot import app as app_graph
6
  from tools import update_retriever
7
  from utils import STT, TTS
 
8
  import os
9
+ import asyncio
10
 
11
  app = FastAPI()
12
 
13
+ # ---------------- CORS ---------------- #
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
25
 
26
  @app.get("/")
27
  def health():
28
+ return {"status": "API is running"}
29
 
30
 
31
+ # =============================
32
+ # Upload PDF
33
+ # =============================
34
  @app.post("/upload")
35
  async def upload_file(file: UploadFile = File(...)):
36
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
37
 
38
+ with open(file_path, "wb") as f:
39
  f.write(await file.read())
40
 
41
+ update_retriever(file_path)
42
 
43
+ return {
44
+ "status": "success",
45
+ "filename": file.filename
46
+ }
47
 
48
 
49
+ # =============================
50
+ # Chat Endpoint
51
+ # =============================
52
  @app.post("/chat")
53
  async def chat(message: str, session_id: str = "default"):
54
 
55
+ async def event_stream():
56
  async for chunk in app_graph.astream(
57
  {"messages": [HumanMessage(content=message)]},
58
  config={"configurable": {"thread_id": session_id}},
 
62
  if hasattr(msg, "content") and msg.content:
63
  yield msg.content + "\n"
64
 
65
+ return StreamingResponse(event_stream(), media_type="text/plain")
66
 
67
 
68
+ # =============================
69
+ # STT
70
+ # =============================
71
  @app.post("/stt")
72
  async def stt(file: UploadFile = File(...)):
73
  return await STT(file)
74
 
75
 
76
+ # =============================
77
+ # TTS
78
+ # =============================
79
  @app.post("/tts")
80
  async def tts(text: str):
81
+ audio_path = await TTS(text)
82
+ return FileResponse(audio_path)