junaid17 commited on
Commit
5c5b559
·
verified ·
1 Parent(s): bd740be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -38
app.py CHANGED
@@ -1,15 +1,14 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from tools import create_rag_tool, update_retriever
3
- from chatbot import app as app_graph
4
- from langchain_core.messages import HumanMessage
5
- import os
6
- from fastapi.responses import StreamingResponse, FileResponse
7
- from langchain_core.messages import AIMessage
8
  from fastapi.middleware.cors import CORSMiddleware
9
- import asyncio
10
  from pydantic import BaseModel
11
- from utils import TTS, STT
 
12
 
 
 
 
 
13
 
14
  app = FastAPI()
15
 
@@ -21,20 +20,21 @@ app.add_middleware(
21
  allow_headers=["*"],
22
  )
23
 
 
 
 
 
24
  class TTSRequest(BaseModel):
25
  text: str
26
 
27
 
28
- UPLOAD_DIR = "uploads"
29
-
30
  @app.get("/")
31
  def health():
32
- return {'Status' : 'The api is live and running'}
 
33
 
34
  @app.post("/upload")
35
  async def upload_file(file: UploadFile = File(...)):
36
- os.makedirs(UPLOAD_DIR, exist_ok=True)
37
-
38
  file_path = os.path.join(UPLOAD_DIR, file.filename)
39
 
40
  with open(file_path, "wb") as f:
@@ -57,11 +57,10 @@ async def chat(message: str, session_id: str = "default"):
57
  config={"configurable": {"thread_id": session_id}},
58
  stream_mode="messages"
59
  ):
60
- if len(chunk) >= 1:
61
- message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
62
- if hasattr(message_chunk, 'content') and message_chunk.content:
63
- data = str(message_chunk.content).replace("\n", "\\n")
64
- yield f"data: {data}\n\n"
65
  await asyncio.sleep(0.01)
66
 
67
  return StreamingResponse(
@@ -73,32 +72,22 @@ async def chat(message: str, session_id: str = "default"):
73
  "X-Accel-Buffering": "no",
74
  },
75
  )
76
- # ---------------- STT ---------------- #
77
 
78
  @app.post("/stt")
79
  async def transcribe_audio(file: UploadFile = File(...)):
80
- try:
81
- return await STT(file)
82
- except Exception as e:
83
- raise HTTPException(status_code=500, detail=str(e))
84
 
85
 
86
  @app.post("/tts")
87
  async def generate_tts(request: TTSRequest):
88
- try:
89
- if not request.text.strip():
90
- raise HTTPException(status_code=400, detail="Text is empty")
91
-
92
- audio_path = await TTS(text=request.text)
93
 
94
- if not os.path.exists(audio_path):
95
- raise HTTPException(status_code=500, detail="Audio file not created")
96
 
97
- return FileResponse(
98
- path=audio_path,
99
- media_type="audio/mpeg",
100
- filename="speech.mp3"
101
- )
102
-
103
- except Exception as e:
104
- raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
 
 
 
 
 
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse, FileResponse
4
  from pydantic import BaseModel
5
+ import os
6
+ import asyncio
7
 
8
+ from chatbot import app as app_graph
9
+ from langchain_core.messages import HumanMessage
10
+ from tools import update_retriever
11
+ from utils import TTS, STT
12
 
13
  app = FastAPI()
14
 
 
20
  allow_headers=["*"],
21
  )
22
 
23
+ UPLOAD_DIR = "uploads"
24
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
25
+
26
+
27
  class TTSRequest(BaseModel):
28
  text: str
29
 
30
 
 
 
31
  @app.get("/")
32
  def health():
33
+ return {"status": "API is running"}
34
+
35
 
36
  @app.post("/upload")
37
  async def upload_file(file: UploadFile = File(...)):
 
 
38
  file_path = os.path.join(UPLOAD_DIR, file.filename)
39
 
40
  with open(file_path, "wb") as f:
 
57
  config={"configurable": {"thread_id": session_id}},
58
  stream_mode="messages"
59
  ):
60
+ if chunk:
61
+ msg = chunk[0] if isinstance(chunk, tuple) else chunk
62
+ if hasattr(msg, "content") and msg.content:
63
+ yield f"data: {msg.content}\n\n"
 
64
  await asyncio.sleep(0.01)
65
 
66
  return StreamingResponse(
 
72
  "X-Accel-Buffering": "no",
73
  },
74
  )
75
+
76
 
77
  @app.post("/stt")
78
  async def transcribe_audio(file: UploadFile = File(...)):
79
+ return await STT(file)
 
 
 
80
 
81
 
82
  @app.post("/tts")
83
  async def generate_tts(request: TTSRequest):
84
+ audio_path = await TTS(text=request.text)
 
 
 
 
85
 
86
+ if not os.path.exists(audio_path):
87
+ raise HTTPException(status_code=500, detail="Audio not generated")
88
 
89
+ return FileResponse(
90
+ path=audio_path,
91
+ media_type="audio/mpeg",
92
+ filename="speech.mp3"
93
+ )