|
|
import os
import uuid
from functools import lru_cache
from typing import List, Optional, Dict, Any, Union

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Body, Query, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
|
|
|
|
|
|
|
|
# Pull variables from a local .env file (if any) into the process environment
# before the os.getenv() lookups for GOOGLE_API_KEY / GROQ_API_KEY run.
# override=False (the library default, spelled out here) keeps variables that
# are already set in the real environment.
load_dotenv(override=False)
|
|
|
|
|
|
|
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain.chains import ConversationalRetrievalChain |
|
|
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_core.documents import Document |
|
|
from langchain_groq import ChatGroq |
|
|
from google import genai |
|
|
from google.genai import types |
|
|
|
|
|
|
|
|
app = FastAPI(title="RAG System API", description="An API for question answering based on YouTube video content or uploaded video files")

# Allowed browser origins, as a comma-separated list in CORS_ALLOW_ORIGINS.
# Defaults to "*" (any origin), matching the previous behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin (browsers
    # reject it, and FastAPI's docs forbid it), so credentials are only
    # enabled when an explicit list of origins has been configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class TranscriptionRequest(BaseModel):
    """JSON body accepted by ``POST /transcribe``."""

    # Video URL; handed to Gemini unchanged as the file URI to transcribe.
    youtube_url: str
|
|
|
|
|
class QueryRequest(BaseModel):
    """JSON body accepted by ``POST /query``."""

    # The question to answer from the session's indexed transcript.
    query: str
    # Id returned by /transcribe or /upload. Optional in the schema, but /query
    # rejects requests whose id is missing or not a known session.
    session_id: Union[str, None] = None
|
|
|
|
|
class QueryResponse(BaseModel):
    """JSON body returned by ``POST /query``."""

    answer: str  # text produced by the retrieval chain
    session_id: str  # the session the question was asked against
    # Short text previews of the retrieved chunks behind the answer, if any.
    source_documents: Union[List[str], None] = None
|
|
|
|
|
|
|
|
# In-memory session store: session_id -> {"retriever", "chat_history", "transcription"}.
# NOTE(review): entries persist until DELETE /sessions/{id}; nothing expires
# them, and a module-level dict is per process (not shared between workers).
sessions: Dict[str, Dict[str, Any]] = dict()
|
|
|
|
|
|
|
|
def init_google_client():
    """
    Create a google-genai client authenticated with ``GOOGLE_API_KEY``.

    Raises:
        ValueError: if ``GOOGLE_API_KEY`` is unset or empty.
    """
    key = os.getenv("GOOGLE_API_KEY", "")
    if key:
        return genai.Client(api_key=key)
    raise ValueError("GOOGLE_API_KEY environment variable not set")
|
|
|
|
|
|
|
|
def get_llm():
    """
    Return the Groq-hosted chat model used to answer questions.

    The model is ``llama-3.3-70b-versatile`` (Llama 3.3, 70B parameters),
    called with temperature 0 and at most 1024 output tokens.

    Raises:
        ValueError: if the ``GROQ_API_KEY`` environment variable is unset or empty.
    """
    api_key = os.getenv("GROQ_API_KEY", "")
    if not api_key:
        raise ValueError("GROQ_API_KEY environment variable not set")

    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_embeddings():
    """
    Return the shared BGE embedding model (``BAAI/bge-small-en`` on CPU).

    The instance is built once and cached. Constructing
    HuggingFaceBgeEmbeddings presumably loads the model weights, which is
    expensive, and previously happened on every transcription. Embeddings
    are requested with ``normalize_embeddings=True``.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
|
|
|
|
|
|
|
|
# System-message template used by user_prompt for the "stuff" documents chain.
# {context} and {question} are template variables filled in by the retrieval
# chain. The blank lines inside the literal are part of the prompt text sent
# to the model, so edit them deliberately.
quiz_solving_prompt = '''


You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.


Use the following retrieved context to answer the user's question.


If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.





Guidelines:


1. Extract key information from the context to form a coherent response.


2. Maintain a clear and professional tone.


3. If the question requires clarification, specify it politely.





Retrieved context:


{context}





User's question:


{question}





Your response:


'''
|
|
|
|
|
|
|
|
# Chat prompt for the combine-documents step of create_chain(): the quiz
# template as the system message, then the user's question as a human turn.
# NOTE(review): quiz_solving_prompt already embeds {question}, so the question
# reaches the model twice; confirm whether that duplication is intended.
_quiz_messages = [
    ("system", quiz_solving_prompt),
    ("human", "{question}"),
]
user_prompt = ChatPromptTemplate.from_messages(_quiz_messages)
|
|
|
|
|
|
|
|
def create_chain(retriever):
    """
    Build a conversational retrieval chain over ``retriever``.

    A new LLM from get_llm() is paired with the retriever. Retrieved chunks are
    "stuffed" into ``user_prompt``, and the chain also returns the source
    documents it retrieved.

    Raises:
        ValueError: propagated from get_llm() when ``GROQ_API_KEY`` is missing.
    """
    return ConversationalRetrievalChain.from_llm(
        llm=get_llm(),
        retriever=retriever,
        chain_type="stuff",
        combine_docs_chain_kwargs={"prompt": user_prompt},
        return_source_documents=True,
        verbose=False,
    )
|
|
|
|
|
|
|
|
def process_transcription(transcription, *, chunk_size=1024, chunk_overlap=20, k=3):
    """
    Index a transcript for retrieval and register a new chat session.

    The text is split with RecursiveCharacterTextSplitter, embedded, and
    stored in an in-memory FAISS index. That index is wrapped in a retriever
    returning the top ``k`` chunks. A new entry holding the retriever, an empty
    chat history and the raw transcript is added to the module-level
    ``sessions`` dict.

    Args:
        transcription: Transcript text to index.
        chunk_size: ``chunk_size`` passed to the text splitter (default 1024).
        chunk_overlap: ``chunk_overlap`` passed to the text splitter (default 20).
        k: Number of chunks the retriever returns per query (default 3).

    Returns:
        The new session id (a UUID4 string).

    Raises:
        ValueError: if ``transcription`` is None, empty or whitespace-only.
            There would be nothing to embed, and FAISS would otherwise fail
            with an obscure error.
    """
    if not transcription or not transcription.strip():
        raise ValueError("Cannot build a RAG index from an empty transcription")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    all_splits = text_splitter.split_text(transcription)

    vectorstore = FAISS.from_texts(all_splits, get_embeddings())
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})

    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "retriever": retriever,
        "chat_history": [],
        "transcription": transcription,
    }
    return session_id
|
|
|
|
|
@app.post("/transcribe", response_model=Dict[str, str])
async def transcribe_video(request: TranscriptionRequest):
    """
    Transcribe a YouTube video with Gemini and prepare the RAG system.

    The blocking Gemini request and the CPU-heavy embedding/indexing step run
    in FastAPI's worker thread pool so they do not stall the event loop.

    Returns:
        ``{"session_id": ..., "message": ...}``; pass the id to ``/query``.

    Raises:
        HTTPException: 502 if Gemini returns no transcript text; 500 for any
            other failure (missing API key, upstream error, indexing error).
    """
    try:
        client = init_google_client()

        response = await run_in_threadpool(
            client.models.generate_content,
            model='models/gemini-2.0-flash',
            contents=types.Content(
                parts=[
                    types.Part(text='Transcribe the Video. Write all the things described in the video'),
                    types.Part(file_data=types.FileData(file_uri=request.youtube_url)),
                ]
            ),
        )

        # response.text concatenates the text parts of the first candidate and
        # is None when there is none (e.g. blocked output), instead of raising
        # IndexError/TypeError like direct candidates[0].content.parts[0] access.
        transcription = response.text
        if not transcription or not transcription.strip():
            raise HTTPException(status_code=502, detail="Gemini returned an empty transcription")

        session_id = await run_in_threadpool(process_transcription, transcription)

        return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}

    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}") from e
|
|
|
|
|
@app.post("/upload", response_model=Dict[str, str])
async def upload_video(
    file: UploadFile = File(...),
    prompt: str = Form("Transcribe the Video. Write all the things described in the video"),
):
    """
    Upload a video file (max 20MB), transcribe it and prepare the RAG system.

    Args:
        file: The uploaded video; its declared content type must start with ``video/``.
        prompt: Instruction sent to Gemini alongside the video bytes.

    Returns:
        ``{"session_id": ..., "message": ...}``; pass the id to ``/query``.

    Raises:
        HTTPException: 400 if the file is not a video or exceeds 20MB; 502 if
            Gemini returns no transcript text; 500 for any other failure.
    """
    max_bytes = 20 * 1024 * 1024
    try:
        # content_type can be None when the client omits it; treat that as "not a video".
        content_type = file.content_type or ""
        if not content_type.startswith('video/'):
            raise HTTPException(status_code=400, detail="File must be a video")

        # Read at most one byte past the limit: enough to detect an oversized
        # upload without pulling the whole file into memory.
        contents = await file.read(max_bytes + 1)
        if len(contents) > max_bytes:
            raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")

        client = init_google_client()

        response = await run_in_threadpool(
            client.models.generate_content,
            model='models/gemini-2.0-flash',
            contents=types.Content(
                parts=[
                    types.Part(text=prompt),
                    types.Part(inline_data=types.Blob(data=contents, mime_type=content_type)),
                ]
            ),
        )

        # response.text is None (rather than raising) when no text came back.
        transcription = response.text
        if not transcription or not transcription.strip():
            raise HTTPException(status_code=502, detail="Gemini returned an empty transcription")

        session_id = await run_in_threadpool(process_transcription, transcription)

        return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}

    except HTTPException:
        # Without this, the 400s above would be caught below and reported as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing uploaded video: {str(e)}") from e
    finally:
        await file.close()
|
|
|
|
|
@app.post("/query", response_model=QueryResponse)
async def query_system(request: QueryRequest):
    """
    Query the RAG system with a question.

    The question and the session's chat history go to a fresh conversational
    retrieval chain. The (question, answer) pair is then appended to the
    session history for follow-up questions. The blocking chain call runs in
    FastAPI's thread pool so the event loop stays free.

    Returns:
        The answer, the session id, and previews (first 100 characters, with
        "..." when truncated) of the retrieved source chunks.

    Raises:
        HTTPException: 404 if ``session_id`` is missing or unknown; 500 if
            building or running the chain fails.
    """
    session_id = request.session_id

    # Validated outside the try block so the generic 500 handler below cannot
    # swallow this 404.
    if not session_id or session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")

    session = sessions[session_id]
    chat_history = session["chat_history"]

    try:
        chain = create_chain(session["retriever"])

        result = await run_in_threadpool(
            chain.invoke,
            {"question": request.query, "chat_history": chat_history},
        )

        answer = result["answer"]
        chat_history.append((request.query, answer))

        source_docs = [
            doc.page_content[:100] + ("..." if len(doc.page_content) > 100 else "")
            for doc in result.get("source_documents", [])
        ]

        return {
            "answer": answer,
            "session_id": session_id,
            "source_documents": source_docs,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}") from e
|
|
|
|
|
@app.get("/sessions/{session_id}", response_model=Dict[str, Any])
async def get_session_info(session_id: str):
    """
    Get information about a specific session.

    Returns:
        The session id, the number of recorded (question, answer) turns, and
        the first 200 characters of the transcript ("..." is appended only
        when the transcript was actually truncated).

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    session = sessions[session_id]
    transcription = session["transcription"]

    preview = transcription[:200]
    if len(transcription) > 200:
        preview += "..."

    return {
        "session_id": session_id,
        "chat_history_length": len(session["chat_history"]),
        "transcription_preview": preview,
    }
|
|
|
|
|
@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
    """
    Remove a session (retriever, chat history and transcript) from memory.

    Raises:
        HTTPException: 404 if no session with ``session_id`` exists.
    """
    try:
        del sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found") from None

    return {"message": f"Session {session_id} deleted successfully"}
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list its main endpoints."""
    endpoints = {
        "/transcribe": "Transcribe YouTube videos",
        "/upload": "Upload and transcribe video files (max 20MB)",
        "/query": "Query the RAG system",
        "/sessions/{session_id}": "Get session information",
    }
    return {"message": "Video Transcription and QA API", "endpoints": endpoints}
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, port=8000, host="0.0.0.0")