Hammad712 commited on
Commit
3e1a2e1
·
verified ·
1 Parent(s): 77de2d6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +447 -61
main.py CHANGED
@@ -1,16 +1,26 @@
1
- from fastapi import FastAPI, HTTPException, Body, Query, File, UploadFile, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
 
 
4
  from typing import List, Optional, Dict, Any, Union
5
  import uuid
6
  import os
 
 
 
7
  from dotenv import load_dotenv
 
 
 
 
 
8
 
9
  # Load environment variables
10
  load_dotenv()
11
 
12
- # Import necessary libraries
13
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
14
  from langchain_community.vectorstores import FAISS
15
  from langchain.chains import ConversationalRetrievalChain
16
  from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
@@ -20,8 +30,26 @@ from langchain_groq import ChatGroq
20
  from google import genai
21
  from google.genai import types
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Initialize FastAPI app
24
- app = FastAPI(title="RAG System API", description="An API for question answering based on YouTube video content or uploaded video files")
25
 
26
  # Configure CORS
27
  app.add_middleware(
@@ -38,15 +66,169 @@ class TranscriptionRequest(BaseModel):
38
 
39
  class QueryRequest(BaseModel):
40
  query: str
41
- session_id: Optional[str] = None
42
 
43
  class QueryResponse(BaseModel):
44
  answer: str
45
  session_id: str
46
  source_documents: Optional[List[str]] = None
47
 
48
- # Global variables
49
- sessions = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Initialize Google API client
52
  def init_google_client():
@@ -59,7 +241,7 @@ def init_google_client():
59
  def get_llm():
60
  """
61
  Returns the language model instance (LLM) using ChatGroq API.
62
- The LLM used is Llama 3.1 with a versatile 70 billion parameters model.
63
  """
64
  api_key = os.getenv("GROQ_API_KEY", "")
65
  if not api_key:
@@ -78,7 +260,7 @@ def get_embeddings():
78
  model_name = "BAAI/bge-small-en"
79
  model_kwargs = {"device": "cpu"}
80
  encode_kwargs = {"normalize_embeddings": True}
81
- embeddings = HuggingFaceBgeEmbeddings(
82
  model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
83
  )
84
  return embeddings
@@ -125,7 +307,7 @@ def create_chain(retriever):
125
  return chain
126
 
127
  # Process transcription and prepare RAG system
128
- def process_transcription(transcription):
129
  # Process the transcription
130
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
131
  all_splits = text_splitter.split_text(transcription)
@@ -138,17 +320,77 @@ def process_transcription(transcription):
138
  # Create a session ID
139
  session_id = str(uuid.uuid4())
140
 
141
- # Store session data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  sessions[session_id] = {
143
  "retriever": retriever,
144
- "chat_history": [],
145
- "transcription": transcription
146
  }
147
 
148
  return session_id
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  @app.post("/transcribe", response_model=Dict[str, str])
151
- async def transcribe_video(request: TranscriptionRequest):
 
 
 
152
  """
153
  Transcribe a YouTube video and prepare the RAG system
154
  """
@@ -173,7 +415,14 @@ async def transcribe_video(request: TranscriptionRequest):
173
  transcription = response.candidates[0].content.parts[0].text
174
 
175
  # Process transcription and get session ID
176
- session_id = process_transcription(transcription)
 
 
 
 
 
 
 
177
 
178
  return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}
179
 
@@ -181,14 +430,21 @@ async def transcribe_video(request: TranscriptionRequest):
181
  raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}")
182
 
183
  @app.post("/upload", response_model=Dict[str, str])
184
- async def upload_video(file: UploadFile = File(...), prompt: str = Form("Transcribe the Video. Write all the things described in the video")):
 
 
 
 
 
 
185
  """
186
  Upload a video file (max 20MB), transcribe it and prepare the RAG system
187
  """
188
  try:
189
  # Check file size (20MB limit)
190
  contents = await file.read()
191
- if len(contents) > 20 * 1024 * 1024: # 20MB in bytes
 
192
  raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")
193
 
194
  # Check file type
@@ -215,7 +471,19 @@ async def upload_video(file: UploadFile = File(...), prompt: str = Form("Transcr
215
  transcription = response.candidates[0].content.parts[0].text
216
 
217
  # Process transcription and get session ID
218
- session_id = process_transcription(transcription)
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}
221
 
@@ -225,18 +493,73 @@ async def upload_video(file: UploadFile = File(...), prompt: str = Form("Transcr
225
  # Reset file pointer
226
  await file.seek(0)
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  @app.post("/query", response_model=QueryResponse)
229
- async def query_system(request: QueryRequest):
 
 
 
230
  """
231
  Query the RAG system with a question
232
  """
233
  try:
234
  session_id = request.session_id
235
 
236
- # Create a new session if none provided
237
  if not session_id or session_id not in sessions:
238
  raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")
239
 
 
 
 
 
 
240
  # Get session data
241
  session = sessions[session_id]
242
  retriever = session["retriever"]
@@ -245,11 +568,17 @@ async def query_system(request: QueryRequest):
245
  # Create chain
246
  chain = create_chain(retriever)
247
 
 
 
 
 
 
248
  # Query the chain
249
- result = chain({"question": request.query, "chat_history": chat_history})
250
 
251
  # Update chat history
252
- chat_history.append((request.query, result["answer"]))
 
253
 
254
  # Prepare source documents
255
  source_docs = [doc.page_content[:100] + "..." for doc in result.get("source_documents", [])]
@@ -263,31 +592,107 @@ async def query_system(request: QueryRequest):
263
  except Exception as e:
264
  raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}")
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  @app.get("/sessions/{session_id}", response_model=Dict[str, Any])
267
- async def get_session_info(session_id: str):
 
 
 
268
  """
269
  Get information about a specific session
270
  """
271
- if session_id not in sessions:
 
 
 
272
  raise HTTPException(status_code=404, detail="Session not found")
273
 
274
- session = sessions[session_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  return {
277
  "session_id": session_id,
278
- "chat_history_length": len(session["chat_history"]),
279
- "transcription_preview": session["transcription"][:200] + "..."
 
 
 
 
 
280
  }
281
 
282
  @app.delete("/sessions/{session_id}")
283
- async def delete_session(session_id: str):
 
 
 
284
  """
285
  Delete a session
286
  """
287
- if session_id not in sessions:
 
 
 
288
  raise HTTPException(status_code=404, detail="Session not found")
289
 
290
- del sessions[session_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  return {"message": f"Session {session_id} deleted successfully"}
292
 
293
  @app.get("/")
@@ -298,43 +703,24 @@ async def root():
298
  return {
299
  "message": "Video Transcription and QA API",
300
  "endpoints": {
 
 
301
  "/transcribe": "Transcribe YouTube videos",
302
  "/upload": "Upload and transcribe video files (max 20MB)",
 
303
  "/query": "Query the RAG system",
 
304
  "/sessions/{session_id}": "Get session information",
305
  }
306
  }
307
- @app.route('/transcribe-audio', methods=['POST'])
308
- def transcribe_audio():
309
- if 'audio' not in request.files:
310
- return jsonify({"error": "No audio file provided"}), 400
311
-
312
- audio_file = request.files['audio']
313
-
314
- # Save the uploaded file temporarily
315
- temp_path = os.path.join(os.path.dirname(__file__), "temp_audio.m4a")
316
- audio_file.save(temp_path)
317
-
318
- try:
319
- # Use Groq client to transcribe the audio
320
- with open(temp_path, "rb") as file:
321
- transcription = client.audio.transcriptions.create(
322
- file=(temp_path, file.read()),
323
- model="whisper-large-v3",
324
- response_format="verbose_json",
325
- )
326
-
327
- # Return the transcription result
328
- return jsonify({"transcription": transcription.text})
329
-
330
- except Exception as e:
331
- return jsonify({"error": str(e)}), 500
332
-
333
- finally:
334
- # Clean up the temporary file
335
- if os.path.exists(temp_path):
336
- os.remove(temp_path)
337
 
338
  if __name__ == "__main__":
339
  import uvicorn
 
340
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, HTTPException, Depends, File, UploadFile, Form, Response, BackgroundTasks
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
4
+ from fastapi.responses import StreamingResponse
5
+ from pydantic import BaseModel, Field, EmailStr
6
  from typing import List, Optional, Dict, Any, Union
7
  import uuid
8
  import os
9
+ import io
10
+ import shutil
11
+ from datetime import datetime, timedelta
12
  from dotenv import load_dotenv
13
+ import hashlib
14
+ import jwt
15
+ from passlib.context import CryptContext
16
+ from pymongo import MongoClient
17
+ from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
18
 
19
  # Load environment variables
20
  load_dotenv()
21
 
22
+ # Import necessary libraries - updating deprecated imports
23
+ from langchain_huggingface import HuggingFaceEmbeddings
24
  from langchain_community.vectorstores import FAISS
25
  from langchain.chains import ConversationalRetrievalChain
26
  from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
 
30
  from google import genai
31
  from google.genai import types
32
 
33
+ # MongoDB Configuration
34
+ MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017")
35
+ DATABASE_NAME = os.getenv("MONGO_DB_NAME", "rag_system")
36
+ CHAT_COLLECTION = "chat_history"
37
+ USER_COLLECTION = "users"
38
+ VIDEO_COLLECTION = "videos"
39
+
40
+ # Security
41
+ SECRET_KEY = os.getenv("SECRET_KEY", "your_secret_key_here")
42
+ ALGORITHM = "HS256"
43
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
44
+
45
+ # Password hashing
46
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
47
+
48
+ # OAuth2 scheme
49
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
50
+
51
  # Initialize FastAPI app
52
+ app = FastAPI(title="RAG System API", description="An API for question answering based on video content with user authentication")
53
 
54
  # Configure CORS
55
  app.add_middleware(
 
66
 
67
  class QueryRequest(BaseModel):
68
  query: str
69
+ session_id: str
70
 
71
  class QueryResponse(BaseModel):
72
  answer: str
73
  session_id: str
74
  source_documents: Optional[List[str]] = None
75
 
76
+ class User(BaseModel):
77
+ username: str
78
+ email: EmailStr
79
+ full_name: Optional[str] = None
80
+
81
+ class UserInDB(User):
82
+ hashed_password: str
83
+
84
+ class UserCreate(User):
85
+ password: str
86
+
87
+ class Token(BaseModel):
88
+ access_token: str
89
+ token_type: str
90
+
91
+ class TokenData(BaseModel):
92
+ username: Optional[str] = None
93
+
94
+ class VideoData(BaseModel):
95
+ video_id: str
96
+ user_id: str
97
+ title: str
98
+ source_type: str # "youtube" or "upload"
99
+ source_url: Optional[str] = None
100
+ created_at: datetime = Field(default_factory=datetime.utcnow)
101
+ transcription: str
102
+ size: Optional[int] = None
103
+
104
+ # MongoDB connection and chat management
105
+ class MongoDB:
106
+ def __init__(self):
107
+ self.client = MongoClient(MONGO_URI)
108
+ self.db = self.client[DATABASE_NAME]
109
+ self.users = self.db[USER_COLLECTION]
110
+ self.videos = self.db[VIDEO_COLLECTION]
111
+
112
+ # Ensure indexes
113
+ self.users.create_index("username", unique=True)
114
+ self.users.create_index("email", unique=True)
115
+ self.videos.create_index("video_id", unique=True)
116
+ self.videos.create_index("user_id")
117
+
118
+ def close(self):
119
+ self.client.close()
120
+
121
+ # Chat Management Class
122
+ class ChatManagement:
123
+ def __init__(self, cluster_url, database_name, collection_name):
124
+ self.connection_string = cluster_url
125
+ self.database_name = database_name
126
+ self.collection_name = collection_name
127
+ self.chat_sessions = {} # Dictionary to store chat history objects for each session
128
+
129
+ def create_new_chat(self):
130
+ # Generate a unique chat ID
131
+ chat_id = str(uuid.uuid4())
132
+ # Initialize MongoDBChatMessageHistory for the chat session
133
+ chat_message_history = MongoDBChatMessageHistory(
134
+ session_id=chat_id,
135
+ connection_string=self.connection_string,
136
+ database_name=self.database_name,
137
+ collection_name=self.collection_name
138
+ )
139
+ # Store the chat_message_history object in the session dictionary
140
+ self.chat_sessions[chat_id] = chat_message_history
141
+ return chat_id
142
+
143
+ def get_chat_history(self, chat_id):
144
+ # Check if the chat session is already in memory
145
+ if chat_id in self.chat_sessions:
146
+ return self.chat_sessions[chat_id]
147
+ # If not in memory, try to fetch from the database
148
+ chat_message_history = MongoDBChatMessageHistory(
149
+ session_id=chat_id,
150
+ connection_string=self.connection_string,
151
+ database_name=self.database_name,
152
+ collection_name=self.collection_name
153
+ )
154
+ if chat_message_history.messages: # Check if the session exists in the database
155
+ self.chat_sessions[chat_id] = chat_message_history
156
+ return chat_message_history
157
+ return None # Chat session not found
158
+
159
+ def initialize_chat_history(self, chat_id):
160
+ # If the chat history already exists, return it
161
+ if chat_id in self.chat_sessions:
162
+ return self.chat_sessions[chat_id]
163
+ # Otherwise, create a new chat history
164
+ chat_message_history = MongoDBChatMessageHistory(
165
+ session_id=chat_id,
166
+ connection_string=self.connection_string,
167
+ database_name=self.database_name,
168
+ collection_name=self.collection_name
169
+ )
170
+ # Save the new chat session to the session dictionary
171
+ self.chat_sessions[chat_id] = chat_message_history
172
+ return chat_message_history
173
+
174
+ # Global variables and instances
175
+ mongodb = MongoDB()
176
+ chat_manager = ChatManagement(MONGO_URI, DATABASE_NAME, CHAT_COLLECTION)
177
+ sessions = {} # In-memory session storage for retrievers
178
+
179
+ # Video directory for temporary storage
180
+ VIDEOS_DIR = "temp_videos"
181
+ os.makedirs(VIDEOS_DIR, exist_ok=True)
182
+
183
+ # Security functions
184
+ def verify_password(plain_password, hashed_password):
185
+ return pwd_context.verify(plain_password, hashed_password)
186
+
187
+ def get_password_hash(password):
188
+ return pwd_context.hash(password)
189
+
190
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
191
+ to_encode = data.copy()
192
+ if expires_delta:
193
+ expire = datetime.utcnow() + expires_delta
194
+ else:
195
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
196
+ to_encode.update({"exp": expire})
197
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
198
+ return encoded_jwt
199
+
200
+ def get_user(username: str):
201
+ user_data = mongodb.users.find_one({"username": username})
202
+ if user_data:
203
+ return UserInDB(**user_data)
204
+ return None
205
+
206
+ def authenticate_user(username: str, password: str):
207
+ user = get_user(username)
208
+ if not user:
209
+ return False
210
+ if not verify_password(password, user.hashed_password):
211
+ return False
212
+ return user
213
+
214
+ async def get_current_user(token: str = Depends(oauth2_scheme)):
215
+ credentials_exception = HTTPException(
216
+ status_code=401,
217
+ detail="Could not validate credentials",
218
+ headers={"WWW-Authenticate": "Bearer"},
219
+ )
220
+ try:
221
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
222
+ username: str = payload.get("sub")
223
+ if username is None:
224
+ raise credentials_exception
225
+ token_data = TokenData(username=username)
226
+ except jwt.PyJWTError:
227
+ raise credentials_exception
228
+ user = get_user(username=token_data.username)
229
+ if user is None:
230
+ raise credentials_exception
231
+ return user
232
 
233
  # Initialize Google API client
234
  def init_google_client():
 
241
  def get_llm():
242
  """
243
  Returns the language model instance (LLM) using ChatGroq API.
244
+ The LLM used is Llama 3.3 with a versatile 70 billion parameters model.
245
  """
246
  api_key = os.getenv("GROQ_API_KEY", "")
247
  if not api_key:
 
260
  model_name = "BAAI/bge-small-en"
261
  model_kwargs = {"device": "cpu"}
262
  encode_kwargs = {"normalize_embeddings": True}
263
+ embeddings = HuggingFaceEmbeddings(
264
  model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
265
  )
266
  return embeddings
 
307
  return chain
308
 
309
  # Process transcription and prepare RAG system
310
+ def process_transcription(transcription, user_id, title, source_type, source_url=None, file_size=None):
311
  # Process the transcription
312
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
313
  all_splits = text_splitter.split_text(transcription)
 
320
  # Create a session ID
321
  session_id = str(uuid.uuid4())
322
 
323
+ # Store video data in MongoDB
324
+ video_data = {
325
+ "video_id": session_id,
326
+ "user_id": user_id,
327
+ "title": title,
328
+ "source_type": source_type,
329
+ "source_url": source_url,
330
+ "created_at": datetime.utcnow(),
331
+ "transcription": transcription,
332
+ "size": file_size
333
+ }
334
+
335
+ mongodb.videos.insert_one(video_data)
336
+
337
+ # Store session data in memory
338
  sessions[session_id] = {
339
  "retriever": retriever,
340
+ "chat_history": chat_manager.initialize_chat_history(session_id)
 
341
  }
342
 
343
  return session_id
344
 
345
+ # Save video to disk (background task)
346
+ def save_video_file(video_id, file_path, contents):
347
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
348
+ with open(file_path, "wb") as f:
349
+ f.write(contents)
350
+
351
+ # Auth endpoints
352
+ @app.post("/register", response_model=User)
353
+ async def register_user(user: UserCreate):
354
+ # Check if username already exists
355
+ if mongodb.users.find_one({"username": user.username}):
356
+ raise HTTPException(status_code=400, detail="Username already registered")
357
+
358
+ # Check if email already exists
359
+ if mongodb.users.find_one({"email": user.email}):
360
+ raise HTTPException(status_code=400, detail="Email already registered")
361
+
362
+ # Create user
363
+ hashed_password = get_password_hash(user.password)
364
+ user_dict = user.dict()
365
+ del user_dict["password"]
366
+ user_dict["hashed_password"] = hashed_password
367
+
368
+ # Insert user
369
+ mongodb.users.insert_one(user_dict)
370
+
371
+ return User(**user_dict)
372
+
373
+ @app.post("/token", response_model=Token)
374
+ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
375
+ user = authenticate_user(form_data.username, form_data.password)
376
+ if not user:
377
+ raise HTTPException(
378
+ status_code=401,
379
+ detail="Incorrect username or password",
380
+ headers={"WWW-Authenticate": "Bearer"},
381
+ )
382
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
383
+ access_token = create_access_token(
384
+ data={"sub": user.username}, expires_delta=access_token_expires
385
+ )
386
+ return {"access_token": access_token, "token_type": "bearer"}
387
+
388
+ # Video processing endpoints
389
  @app.post("/transcribe", response_model=Dict[str, str])
390
+ async def transcribe_video(
391
+ request: TranscriptionRequest,
392
+ current_user: User = Depends(get_current_user)
393
+ ):
394
  """
395
  Transcribe a YouTube video and prepare the RAG system
396
  """
 
415
  transcription = response.candidates[0].content.parts[0].text
416
 
417
  # Process transcription and get session ID
418
+ video_title = f"YouTube Video - {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}"
419
+ session_id = process_transcription(
420
+ transcription,
421
+ current_user.username,
422
+ video_title,
423
+ "youtube",
424
+ request.youtube_url
425
+ )
426
 
427
  return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}
428
 
 
430
  raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}")
431
 
432
  @app.post("/upload", response_model=Dict[str, str])
433
+ async def upload_video(
434
+ background_tasks: BackgroundTasks,
435
+ title: str = Form(...),
436
+ file: UploadFile = File(...),
437
+ prompt: str = Form("Transcribe the Video. Write all the things described in the video"),
438
+ current_user: User = Depends(get_current_user)
439
+ ):
440
  """
441
  Upload a video file (max 20MB), transcribe it and prepare the RAG system
442
  """
443
  try:
444
  # Check file size (20MB limit)
445
  contents = await file.read()
446
+ file_size = len(contents)
447
+ if file_size > 20 * 1024 * 1024: # 20MB in bytes
448
  raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")
449
 
450
  # Check file type
 
471
  transcription = response.candidates[0].content.parts[0].text
472
 
473
  # Process transcription and get session ID
474
+ session_id = process_transcription(
475
+ transcription,
476
+ current_user.username,
477
+ title,
478
+ "upload",
479
+ None,
480
+ file_size
481
+ )
482
+
483
+ # Save video file to disk
484
+ file_extension = os.path.splitext(file.filename)[1]
485
+ file_path = os.path.join(VIDEOS_DIR, f"{session_id}{file_extension}")
486
+ background_tasks.add_task(save_video_file, session_id, file_path, contents)
487
 
488
  return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}
489
 
 
493
  # Reset file pointer
494
  await file.seek(0)
495
 
496
+ @app.get("/download/{video_id}")
497
+ async def download_video(
498
+ video_id: str,
499
+ current_user: User = Depends(get_current_user)
500
+ ):
501
+ """
502
+ Download a previously uploaded video
503
+ """
504
+ # Check if video exists in database
505
+ video_data = mongodb.videos.find_one({"video_id": video_id})
506
+
507
+ if not video_data:
508
+ raise HTTPException(status_code=404, detail="Video not found")
509
+
510
+ # Check if user has access to this video
511
+ if video_data["user_id"] != current_user.username:
512
+ raise HTTPException(status_code=403, detail="Not authorized to access this video")
513
+
514
+ # For YouTube videos, we don't have the actual file
515
+ if video_data["source_type"] == "youtube":
516
+ return {"message": "This is a YouTube video. Please use the original URL to access the video.", "url": video_data["source_url"]}
517
+
518
+ # For uploaded videos, check if file exists
519
+ # Look for any file with the video_id as the base name
520
+ video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(video_id)]
521
+
522
+ if not video_files:
523
+ raise HTTPException(status_code=404, detail="Video file not found")
524
+
525
+ file_path = os.path.join(VIDEOS_DIR, video_files[0])
526
+
527
+ # Determine file extension and MIME type
528
+ file_extension = os.path.splitext(video_files[0])[1]
529
+ mime_type = f"video/{file_extension[1:]}" if file_extension else "video/mp4"
530
+
531
+ # Stream the file
532
+ def iterfile():
533
+ with open(file_path, "rb") as f:
534
+ while chunk := f.read(8192):
535
+ yield chunk
536
+
537
+ return StreamingResponse(
538
+ iterfile(),
539
+ media_type=mime_type,
540
+ headers={"Content-Disposition": f"attachment; filename={video_data['title']}{file_extension}"}
541
+ )
542
+
543
  @app.post("/query", response_model=QueryResponse)
544
+ async def query_system(
545
+ request: QueryRequest,
546
+ current_user: User = Depends(get_current_user)
547
+ ):
548
  """
549
  Query the RAG system with a question
550
  """
551
  try:
552
  session_id = request.session_id
553
 
554
+ # Check if session exists
555
  if not session_id or session_id not in sessions:
556
  raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")
557
 
558
+ # Check if user has access to this session
559
+ video_data = mongodb.videos.find_one({"video_id": session_id})
560
+ if not video_data or video_data["user_id"] != current_user.username:
561
+ raise HTTPException(status_code=403, detail="Not authorized to access this session")
562
+
563
  # Get session data
564
  session = sessions[session_id]
565
  retriever = session["retriever"]
 
568
  # Create chain
569
  chain = create_chain(retriever)
570
 
571
+ # Get chat history from MongoDB in LangChain format
572
+ messages = chat_history.messages
573
+ langchain_chat_history = [(messages[i].content, messages[i+1].content)
574
+ for i in range(0, len(messages)-1, 2) if i+1 < len(messages)]
575
+
576
  # Query the chain
577
+ result = chain.invoke({"question": request.query, "chat_history": langchain_chat_history})
578
 
579
  # Update chat history
580
+ chat_history.add_user_message(request.query)
581
+ chat_history.add_ai_message(result["answer"])
582
 
583
  # Prepare source documents
584
  source_docs = [doc.page_content[:100] + "..." for doc in result.get("source_documents", [])]
 
592
  except Exception as e:
593
  raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}")
594
 
595
+ @app.get("/sessions", response_model=List[Dict[str, Any]])
596
+ async def get_user_sessions(current_user: User = Depends(get_current_user)):
597
+ """
598
+ Get all video sessions for the current user
599
+ """
600
+ user_videos = list(mongodb.videos.find({"user_id": current_user.username}))
601
+
602
+ # Format response
603
+ sessions_list = []
604
+ for video in user_videos:
605
+ sessions_list.append({
606
+ "session_id": video["video_id"],
607
+ "title": video["title"],
608
+ "source_type": video["source_type"],
609
+ "created_at": video["created_at"],
610
+ "transcription_preview": video["transcription"][:200] + "..." if len(video["transcription"]) > 200 else video["transcription"]
611
+ })
612
+
613
+ return sessions_list
614
+
615
  @app.get("/sessions/{session_id}", response_model=Dict[str, Any])
616
+ async def get_session_info(
617
+ session_id: str,
618
+ current_user: User = Depends(get_current_user)
619
+ ):
620
  """
621
  Get information about a specific session
622
  """
623
+ # Check if session exists in database
624
+ video_data = mongodb.videos.find_one({"video_id": session_id})
625
+
626
+ if not video_data:
627
  raise HTTPException(status_code=404, detail="Session not found")
628
 
629
+ # Check if user has access to this session
630
+ if video_data["user_id"] != current_user.username:
631
+ raise HTTPException(status_code=403, detail="Not authorized to access this session")
632
+
633
+ # Get chat history
634
+ chat_history_obj = chat_manager.get_chat_history(session_id)
635
+ chat_messages = []
636
+
637
+ if chat_history_obj:
638
+ messages = chat_history_obj.messages
639
+ for i in range(0, len(messages), 2):
640
+ if i+1 < len(messages):
641
+ chat_messages.append({
642
+ "question": messages[i].content,
643
+ "answer": messages[i+1].content
644
+ })
645
 
646
  return {
647
  "session_id": session_id,
648
+ "title": video_data["title"],
649
+ "source_type": video_data["source_type"],
650
+ "source_url": video_data.get("source_url"),
651
+ "created_at": video_data["created_at"],
652
+ "transcription_preview": video_data["transcription"][:200] + "..." if len(video_data["transcription"]) > 200 else video_data["transcription"],
653
+ "full_transcription": video_data["transcription"],
654
+ "chat_history": chat_messages
655
  }
656
 
657
  @app.delete("/sessions/{session_id}")
658
+ async def delete_session(
659
+ session_id: str,
660
+ current_user: User = Depends(get_current_user)
661
+ ):
662
  """
663
  Delete a session
664
  """
665
+ # Check if session exists in database
666
+ video_data = mongodb.videos.find_one({"video_id": session_id})
667
+
668
+ if not video_data:
669
  raise HTTPException(status_code=404, detail="Session not found")
670
 
671
+ # Check if user has access to this session
672
+ if video_data["user_id"] != current_user.username:
673
+ raise HTTPException(status_code=403, detail="Not authorized to access this session")
674
+
675
+ # Delete from MongoDB
676
+ mongodb.videos.delete_one({"video_id": session_id})
677
+
678
+ # Delete chat history
679
+ chat_history = chat_manager.get_chat_history(session_id)
680
+ if chat_history:
681
+ # This will delete all messages with this session_id from MongoDB
682
+ mongodb.db[CHAT_COLLECTION].delete_many({"session_id": session_id})
683
+
684
+ # Remove from in-memory sessions
685
+ if session_id in sessions:
686
+ del sessions[session_id]
687
+
688
+ # Delete video file if it exists
689
+ video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(session_id)]
690
+ for file in video_files:
691
+ try:
692
+ os.remove(os.path.join(VIDEOS_DIR, file))
693
+ except:
694
+ pass
695
+
696
  return {"message": f"Session {session_id} deleted successfully"}
697
 
698
  @app.get("/")
 
703
  return {
704
  "message": "Video Transcription and QA API",
705
  "endpoints": {
706
+ "/register": "Register a new user",
707
+ "/token": "Login and get access token",
708
  "/transcribe": "Transcribe YouTube videos",
709
  "/upload": "Upload and transcribe video files (max 20MB)",
710
+ "/download/{video_id}": "Download an uploaded video",
711
  "/query": "Query the RAG system",
712
+ "/sessions": "List all user sessions",
713
  "/sessions/{session_id}": "Get session information",
714
  }
715
  }
716
+
717
+ @app.on_event("shutdown")
718
+ def shutdown_event():
719
+ mongodb.close()
720
+ # Clean up temporary files
721
+ shutil.rmtree(VIDEOS_DIR, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
 
723
  if __name__ == "__main__":
724
  import uvicorn
725
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Fix for the tokenizers warning
726
  uvicorn.run(app, host="0.0.0.0", port=8000)