# RAG_APP/src/rag/youtube_rag_pipeline.py
import faiss
import pickle
import numpy as np
import logging
from sentence_transformers import SentenceTransformer
from src.models.llm_wrapper import GeminiWrapper
from src.configs.config import YT_EMBEDDING_MODEL, FAISS_INDEX_FILE_YT, FAISS_METADATA_FILE_YT, TOP_K, LOG_DIR
from pathlib import Path
import os
# Route this module's logs into the shared Agents.log under the configured log dir.
# Create the directory first: basicConfig(filename=...) opens the file eagerly and
# would raise FileNotFoundError at import time if LOG_DIR does not yet exist.
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
class YouTubeRAGPipeline:
    """Retrieval-augmented generation over a FAISS index of YouTube-transcript chunks.

    Loads a sentence-transformer embedder, a prebuilt FAISS index plus its pickled
    metadata table, and a Gemini LLM wrapper; `generate` retrieves the top chunks
    for a query and feeds them to the LLM through a file-based prompt template.
    """

    def __init__(self):
        self.embedding_model = SentenceTransformer(YT_EMBEDDING_MODEL)
        self.index = faiss.read_index(str(FAISS_INDEX_FILE_YT))
        # NOTE(review): pickle.load on a local artifact; do not point this at
        # untrusted files. Metadata is indexed with .iloc below — presumably a
        # pandas DataFrame of transcript chunks; verify against the index builder.
        with open(FAISS_METADATA_FILE_YT, "rb") as f:
            self.metadata = pickle.load(f)
        self.llm = GeminiWrapper()
        self.prompt_template_path = Path("src/prompts/youtube_rag_prompt.txt")

    def load_prompt(self, query: str, context: str) -> str:
        """Fill the prompt template, substituting {{query}} and {{context}} placeholders."""
        template = self.prompt_template_path.read_text(encoding="utf-8")
        return template.replace("{{query}}", query).replace("{{context}}", context)

    def search(self, query, top_k=None):
        """Return up to `top_k` metadata rows nearest to `query` (defaults to TOP_K).

        The default is resolved at call time (None sentinel) rather than baked in
        at definition time, so later changes to TOP_K take effect.
        """
        if top_k is None:
            top_k = TOP_K
        embedding = self.embedding_model.encode([query], normalize_embeddings=True)[0].astype("float32")
        # encode() returns a 1-D vector; FAISS expects a (1, dim) batch.
        D, I = self.index.search(embedding.reshape(1, -1), top_k)
        logging.info(f"Top {top_k} search indices: {I[0]}")
        # FAISS pads missing results with -1; the old `i < len(...)` check let -1
        # through and iloc[-1] silently returned the LAST row as a bogus hit.
        # Require 0 <= i to exclude the padding sentinel.
        return [self.metadata.iloc[i] for i in I[0] if 0 <= i < len(self.metadata)]

    def generate(self, query, history=None, top_k=5):
        """Answer `query` with RAG: retrieve `top_k` chunks, prompt the LLM.

        Returns (llm_response, context_chunks). `history` is passed through to
        the LLM wrapper unchanged. `top_k` was previously hard-coded to 5; it is
        now a parameter with the same default, so existing callers are unaffected.
        """
        logging.info(f"Generating response for query: {query}")
        context_chunks = self.search(query, top_k=top_k)
        # "texte" is the chunk-text column name in the metadata pickle —
        # presumably French-labelled source data; confirm against the index builder.
        context_text = "\n".join(chunk["texte"] for chunk in context_chunks)
        prompt = self.load_prompt(query, context_text)
        logging.info("Prompt passed to LLM:\n%s", prompt)
        return self.llm.generate(prompt, history), context_chunks