junaid17 commited on
Commit
ab7994b
·
verified ·
1 Parent(s): ecd8fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -99
app.py CHANGED
@@ -1,20 +1,21 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from tools import update_retriever
3
  from chatbot import app as app_graph
4
  from langchain_core.messages import HumanMessage
5
  import os
6
  from fastapi.responses import StreamingResponse, FileResponse
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  import asyncio
9
  from pydantic import BaseModel
10
  from utils import TTS, STT
11
 
 
12
  app = FastAPI()
13
 
14
- # Configure CORS - restrict in production
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # TODO: Restrict to specific origins in production
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
@@ -23,82 +24,46 @@ app.add_middleware(
23
  class TTSRequest(BaseModel):
24
  text: str
25
 
 
26
  UPLOAD_DIR = "uploads"
27
- ALLOWED_EXTENSIONS = {".pdf", ".txt", ".docx", ".md"} # Add allowed file types
28
- MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit
29
 
30
  @app.get("/")
31
  def health():
32
- return {'status': 'The api is live and running'}
33
 
34
  @app.post("/upload")
35
  async def upload_file(file: UploadFile = File(...)):
36
- try:
37
- # Validate file extension
38
- file_ext = os.path.splitext(file.filename)[1].lower()
39
- if file_ext not in ALLOWED_EXTENSIONS:
40
- raise HTTPException(
41
- status_code=400,
42
- detail=f"File type not allowed. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
43
- )
44
-
45
- # Create upload directory
46
- os.makedirs(UPLOAD_DIR, exist_ok=True)
47
-
48
- # Read file content
49
- content = await file.read()
50
-
51
- # Validate file size
52
- if len(content) > MAX_FILE_SIZE:
53
- raise HTTPException(
54
- status_code=400,
55
- detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024)}MB"
56
- )
57
-
58
- # Save file
59
- file_path = os.path.join(UPLOAD_DIR, file.filename)
60
- with open(file_path, "wb") as f:
61
- f.write(content)
62
-
63
- # Update retriever
64
- update_retriever(file_path)
65
-
66
- return {
67
- "status": "success",
68
- "filename": file.filename,
69
- "size": len(content)
70
- }
71
-
72
- except HTTPException:
73
- raise
74
- except Exception as e:
75
- raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
76
 
77
  @app.post("/chat")
78
  async def chat(message: str, session_id: str = "default"):
79
- if not message.strip():
80
- raise HTTPException(status_code=400, detail="Message cannot be empty")
81
-
82
- # Validate session_id (basic validation)
83
- if not session_id or len(session_id) > 100:
84
- raise HTTPException(status_code=400, detail="Invalid session_id")
85
-
86
  async def event_generator():
87
- try:
88
- async for chunk in app_graph.astream(
89
- {"messages": [HumanMessage(content=message)]},
90
- config={"configurable": {"thread_id": session_id}},
91
- stream_mode="messages"
92
- ):
93
- if len(chunk) >= 1:
94
- message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
95
- if hasattr(message_chunk, 'content') and message_chunk.content:
96
- data = str(message_chunk.content).replace("\n", "\\n")
97
- yield f"data: {data}\n\n"
98
- await asyncio.sleep(0.01)
99
- except Exception as e:
100
- yield f"data: Error: {str(e)}\n\n"
101
-
102
  return StreamingResponse(
103
  event_generator(),
104
  media_type="text/event-stream",
@@ -108,54 +73,32 @@ async def chat(message: str, session_id: str = "default"):
108
  "X-Accel-Buffering": "no",
109
  },
110
  )
 
111
 
112
  @app.post("/stt")
113
  async def transcribe_audio(file: UploadFile = File(...)):
114
  try:
115
- # Validate audio file type
116
- if not file.content_type or not file.content_type.startswith("audio/"):
117
- raise HTTPException(status_code=400, detail="Invalid audio file")
118
-
119
  return await STT(file)
120
- except HTTPException:
121
- raise
122
  except Exception as e:
123
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
 
124
 
125
  @app.post("/tts")
126
  async def generate_tts(request: TTSRequest):
127
  try:
128
  if not request.text.strip():
129
  raise HTTPException(status_code=400, detail="Text is empty")
130
-
131
- # Limit text length to prevent abuse
132
- if len(request.text) > 5000:
133
- raise HTTPException(status_code=400, detail="Text too long (max 5000 characters)")
134
-
135
  audio_path = await TTS(text=request.text)
136
-
137
  if not os.path.exists(audio_path):
138
  raise HTTPException(status_code=500, detail="Audio file not created")
139
-
140
  return FileResponse(
141
  path=audio_path,
142
  media_type="audio/mpeg",
143
- filename="speech.mp3",
144
- background=None # Consider adding cleanup after response
145
  )
146
-
147
- except HTTPException:
148
- raise
149
- except Exception as e:
150
- raise HTTPException(status_code=500, detail=f"TTS failed: {str(e)}")
151
 
152
- # Optional: Add cleanup endpoint for old files
153
- @app.on_event("startup")
154
- async def startup_event():
155
- os.makedirs(UPLOAD_DIR, exist_ok=True)
156
-
157
- # Optional: Graceful shutdown
158
- @app.on_event("shutdown")
159
- async def shutdown_event():
160
- # Clean up temporary files if needed
161
- pass
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from tools import create_rag_tool, update_retriever
3
  from chatbot import app as app_graph
4
  from langchain_core.messages import HumanMessage
5
  import os
6
  from fastapi.responses import StreamingResponse, FileResponse
7
+ from langchain_core.messages import AIMessage
8
  from fastapi.middleware.cors import CORSMiddleware
9
  import asyncio
10
  from pydantic import BaseModel
11
  from utils import TTS, STT
12
 
13
+
14
  app = FastAPI()
15
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
 
24
  class TTSRequest(BaseModel):
25
  text: str
26
 
27
+
28
  UPLOAD_DIR = "uploads"
 
 
29
 
30
  @app.get("/")
31
  def health():
32
+ return {'Status' : 'The api is live and running'}
33
 
34
  @app.post("/upload")
35
  async def upload_file(file: UploadFile = File(...)):
36
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
37
+
38
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
39
+
40
+ with open(file_path, "wb") as f:
41
+ f.write(await file.read())
42
+
43
+ update_retriever(file_path)
44
+
45
+ return {
46
+ "status": "success",
47
+ "filename": file.filename
48
+ }
49
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  @app.post("/chat")
52
  async def chat(message: str, session_id: str = "default"):
53
+
 
 
 
 
 
 
54
  async def event_generator():
55
+ async for chunk in app_graph.astream(
56
+ {"messages": [HumanMessage(content=message)]},
57
+ config={"configurable": {"thread_id": session_id}},
58
+ stream_mode="messages"
59
+ ):
60
+ if len(chunk) >= 1:
61
+ message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
62
+ if hasattr(message_chunk, 'content') and message_chunk.content:
63
+ data = str(message_chunk.content).replace("\n", "\\n")
64
+ yield f"data: {data}\n\n"
65
+ await asyncio.sleep(0.01)
66
+
 
 
 
67
  return StreamingResponse(
68
  event_generator(),
69
  media_type="text/event-stream",
 
73
  "X-Accel-Buffering": "no",
74
  },
75
  )
76
+ # ---------------- STT ---------------- #
77
 
78
  @app.post("/stt")
79
  async def transcribe_audio(file: UploadFile = File(...)):
80
  try:
 
 
 
 
81
  return await STT(file)
 
 
82
  except Exception as e:
83
+ raise HTTPException(status_code=500, detail=str(e))
84
+
85
 
86
  @app.post("/tts")
87
  async def generate_tts(request: TTSRequest):
88
  try:
89
  if not request.text.strip():
90
  raise HTTPException(status_code=400, detail="Text is empty")
91
+
 
 
 
 
92
  audio_path = await TTS(text=request.text)
93
+
94
  if not os.path.exists(audio_path):
95
  raise HTTPException(status_code=500, detail="Audio file not created")
96
+
97
  return FileResponse(
98
  path=audio_path,
99
  media_type="audio/mpeg",
100
+ filename="speech.mp3"
 
101
  )
 
 
 
 
 
102
 
103
+ except Exception as e:
104
+ raise HTTPException(status_code=500, detail=str(e))