junaid17 commited on
Commit
ecd8fcb
·
verified ·
1 Parent(s): 8b5e73d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -104
app.py CHANGED
@@ -1,104 +1,161 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from tools import create_rag_tool, update_retriever
3
- from chatbot import app as app_graph
4
- from langchain_core.messages import HumanMessage
5
- import os
6
- from fastapi.responses import StreamingResponse, FileResponse
7
- from langchain_core.messages import AIMessage
8
- from fastapi.middleware.cors import CORSMiddleware
9
- import asyncio
10
- from pydantic import BaseModel
11
- from utils import TTS, STT
12
-
13
-
14
- app = FastAPI()
15
-
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"],
19
- allow_credentials=True,
20
- allow_methods=["*"],
21
- allow_headers=["*"],
22
- )
23
-
24
- class TTSRequest(BaseModel):
25
- text: str
26
-
27
-
28
- UPLOAD_DIR = "uploads"
29
-
30
- @app.get("/")
31
- def health():
32
- return {'Status' : 'The api is live and running'}
33
-
34
- @app.post("/upload")
35
- async def upload_file(file: UploadFile = File(...)):
36
- os.makedirs(UPLOAD_DIR, exist_ok=True)
37
-
38
- file_path = os.path.join(UPLOAD_DIR, file.filename)
39
-
40
- with open(file_path, "wb") as f:
41
- f.write(await file.read())
42
-
43
- update_retriever(file_path)
44
-
45
- return {
46
- "status": "success",
47
- "filename": file.filename
48
- }
49
-
50
-
51
- @app.post("/chat")
52
- async def chat(message: str, session_id: str = "default"):
53
-
54
- async def event_generator():
55
- async for chunk in app_graph.astream(
56
- {"messages": [HumanMessage(content=message)]},
57
- config={"configurable": {"thread_id": session_id}},
58
- stream_mode="messages"
59
- ):
60
- if len(chunk) >= 1:
61
- message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
62
- if hasattr(message_chunk, 'content') and message_chunk.content:
63
- data = str(message_chunk.content).replace("\n", "\\n")
64
- yield f"data: {data}\n\n"
65
- await asyncio.sleep(0.01)
66
-
67
- return StreamingResponse(
68
- event_generator(),
69
- media_type="text/event-stream",
70
- headers={
71
- "Cache-Control": "no-cache",
72
- "Connection": "keep-alive",
73
- "X-Accel-Buffering": "no",
74
- },
75
- )
76
- # ---------------- STT ---------------- #
77
-
78
- @app.post("/stt")
79
- async def transcribe_audio(file: UploadFile = File(...)):
80
- try:
81
- return await STT(file)
82
- except Exception as e:
83
- raise HTTPException(status_code=500, detail=str(e))
84
-
85
-
86
- @app.post("/tts")
87
- async def generate_tts(request: TTSRequest):
88
- try:
89
- if not request.text.strip():
90
- raise HTTPException(status_code=400, detail="Text is empty")
91
-
92
- audio_path = await TTS(text=request.text)
93
-
94
- if not os.path.exists(audio_path):
95
- raise HTTPException(status_code=500, detail="Audio file not created")
96
-
97
- return FileResponse(
98
- path=audio_path,
99
- media_type="audio/mpeg",
100
- filename="speech.mp3"
101
- )
102
-
103
- except Exception as e:
104
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from tools import update_retriever
3
+ from chatbot import app as app_graph
4
+ from langchain_core.messages import HumanMessage
5
+ import os
6
+ from fastapi.responses import StreamingResponse, FileResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ import asyncio
9
+ from pydantic import BaseModel
10
+ from utils import TTS, STT
11
+
12
+ app = FastAPI()
13
+
14
+ # Configure CORS - restrict in production
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"], # TODO: Restrict to specific origins in production
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ class TTSRequest(BaseModel):
24
+ text: str
25
+
26
+ UPLOAD_DIR = "uploads"
27
+ ALLOWED_EXTENSIONS = {".pdf", ".txt", ".docx", ".md"} # Add allowed file types
28
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB limit
29
+
30
+ @app.get("/")
31
+ def health():
32
+ return {'status': 'The api is live and running'}
33
+
34
+ @app.post("/upload")
35
+ async def upload_file(file: UploadFile = File(...)):
36
+ try:
37
+ # Validate file extension
38
+ file_ext = os.path.splitext(file.filename)[1].lower()
39
+ if file_ext not in ALLOWED_EXTENSIONS:
40
+ raise HTTPException(
41
+ status_code=400,
42
+ detail=f"File type not allowed. Allowed types: {', '.join(ALLOWED_EXTENSIONS)}"
43
+ )
44
+
45
+ # Create upload directory
46
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
47
+
48
+ # Read file content
49
+ content = await file.read()
50
+
51
+ # Validate file size
52
+ if len(content) > MAX_FILE_SIZE:
53
+ raise HTTPException(
54
+ status_code=400,
55
+ detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024)}MB"
56
+ )
57
+
58
+ # Save file
59
+ file_path = os.path.join(UPLOAD_DIR, file.filename)
60
+ with open(file_path, "wb") as f:
61
+ f.write(content)
62
+
63
+ # Update retriever
64
+ update_retriever(file_path)
65
+
66
+ return {
67
+ "status": "success",
68
+ "filename": file.filename,
69
+ "size": len(content)
70
+ }
71
+
72
+ except HTTPException:
73
+ raise
74
+ except Exception as e:
75
+ raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
76
+
77
+ @app.post("/chat")
78
+ async def chat(message: str, session_id: str = "default"):
79
+ if not message.strip():
80
+ raise HTTPException(status_code=400, detail="Message cannot be empty")
81
+
82
+ # Validate session_id (basic validation)
83
+ if not session_id or len(session_id) > 100:
84
+ raise HTTPException(status_code=400, detail="Invalid session_id")
85
+
86
+ async def event_generator():
87
+ try:
88
+ async for chunk in app_graph.astream(
89
+ {"messages": [HumanMessage(content=message)]},
90
+ config={"configurable": {"thread_id": session_id}},
91
+ stream_mode="messages"
92
+ ):
93
+ if len(chunk) >= 1:
94
+ message_chunk = chunk[0] if isinstance(chunk, tuple) else chunk
95
+ if hasattr(message_chunk, 'content') and message_chunk.content:
96
+ data = str(message_chunk.content).replace("\n", "\\n")
97
+ yield f"data: {data}\n\n"
98
+ await asyncio.sleep(0.01)
99
+ except Exception as e:
100
+ yield f"data: Error: {str(e)}\n\n"
101
+
102
+ return StreamingResponse(
103
+ event_generator(),
104
+ media_type="text/event-stream",
105
+ headers={
106
+ "Cache-Control": "no-cache",
107
+ "Connection": "keep-alive",
108
+ "X-Accel-Buffering": "no",
109
+ },
110
+ )
111
+
112
+ @app.post("/stt")
113
+ async def transcribe_audio(file: UploadFile = File(...)):
114
+ try:
115
+ # Validate audio file type
116
+ if not file.content_type or not file.content_type.startswith("audio/"):
117
+ raise HTTPException(status_code=400, detail="Invalid audio file")
118
+
119
+ return await STT(file)
120
+ except HTTPException:
121
+ raise
122
+ except Exception as e:
123
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
124
+
125
+ @app.post("/tts")
126
+ async def generate_tts(request: TTSRequest):
127
+ try:
128
+ if not request.text.strip():
129
+ raise HTTPException(status_code=400, detail="Text is empty")
130
+
131
+ # Limit text length to prevent abuse
132
+ if len(request.text) > 5000:
133
+ raise HTTPException(status_code=400, detail="Text too long (max 5000 characters)")
134
+
135
+ audio_path = await TTS(text=request.text)
136
+
137
+ if not os.path.exists(audio_path):
138
+ raise HTTPException(status_code=500, detail="Audio file not created")
139
+
140
+ return FileResponse(
141
+ path=audio_path,
142
+ media_type="audio/mpeg",
143
+ filename="speech.mp3",
144
+ background=None # Consider adding cleanup after response
145
+ )
146
+
147
+ except HTTPException:
148
+ raise
149
+ except Exception as e:
150
+ raise HTTPException(status_code=500, detail=f"TTS failed: {str(e)}")
151
+
152
+ # Optional: Add cleanup endpoint for old files
153
+ @app.on_event("startup")
154
+ async def startup_event():
155
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
156
+
157
+ # Optional: Graceful shutdown
158
+ @app.on_event("shutdown")
159
+ async def shutdown_event():
160
+ # Clean up temporary files if needed
161
+ pass