# RAG_APP/src/rag/documents_rag_pipeline.py
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import logging
import os
from src.utils.helpers import load_chunks_from_disk, load_metadata
from src.configs.config import LOG_DIR, METADATA_FILE, CHUNKS_FILE, EMBEDDINGS_FILE, FAISS_INDEX_FILE, EMBEDDING_MODEL, TOP_K
from src.models.llm_wrapper import GeminiWrapper
from src.utils.helpers import load_prompt_template
import json
import torch
from sentence_transformers import util
# Module-wide logging: all pipeline activity goes to <LOG_DIR>/Agents.log.
LOG_FILE = os.path.join(LOG_DIR, "Agents.log")
# basicConfig raises FileNotFoundError if the log directory is missing,
# so make sure it exists before wiring up the file handler.
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
filename=LOG_FILE,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a FAISS index of PDF chunks.

    Loads the embedding model, the FAISS index, the per-document chunk data
    and the document metadata once at construction, then serves queries by
    retrieving the most similar chunks and prompting a Gemini LLM with them.
    """

    def __init__(self):
        # Heavy resources are loaded once here and reused for every query.
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        self.index = faiss.read_index(str(FAISS_INDEX_FILE))
        self.chunks_data = load_chunks_from_disk(CHUNKS_FILE)
        self.metadata = load_metadata(METADATA_FILE)
        self.llm = GeminiWrapper()

    def retrieve_from_pdf(self, query, k=TOP_K):
        """Return the top-*k* chunks most similar to *query*.

        Each result dict carries the chunk text, the document id (exposed as
        "indicator"), the document title and download link (taken from
        ``self.metadata`` when available, otherwise from the chunk data
        itself), and the FAISS distance. Returns [] when no index is loaded.
        """
        if self.index is None:
            logging.error("No FAISS index loaded.")
            return []
        query_embedding = self.embedding_model.encode([query], convert_to_tensor=False)[0]
        distances, indices = self.index.search(np.array([query_embedding]), k)
        # Build a faiss_index -> chunk lookup once.  The previous version did
        # a linear `next(...)` scan over every chunk for each retrieved id,
        # i.e. O(k * total_chunks) per query; a dict makes each lookup O(1).
        chunk_by_index = {}
        for doc in self.chunks_data:
            for chunk in doc["chunks"]:
                chunk_by_index[chunk["faiss_index"]] = {
                    "text": chunk["text"],
                    "Id": doc["Id"],
                    "pdf_title": doc["pdf_title"],
                    "download_link": doc["download_link"],
                }
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS pads `indices` with -1 when the index holds fewer than k
            # vectors; .get() simply skips those (same behavior as before).
            chunk_info = chunk_by_index.get(int(idx))
            if chunk_info is None:
                continue
            meta = self.metadata.get(chunk_info["pdf_title"], {
                "Id": chunk_info["Id"],
                "Nom du document": chunk_info["pdf_title"],
                "Lien": chunk_info["download_link"]
            })
            results.append({
                "text": chunk_info["text"],
                "indicator": chunk_info["Id"],
                "pdf_title": meta["Nom du document"],
                "pdf_link": meta["Lien"],
                "distance": float(dist)
            })
        logging.info("Retrieved %d results from FAISS index.", len(results))
        return results

    def generate(self, query, retrieved_chunks):
        """Generate an LLM answer for *query* grounded in *retrieved_chunks*.

        The chunk texts are joined into a single context string and rendered
        into the documents-RAG prompt template before being sent to Gemini.
        """
        raw_context = "\n".join(chunk["text"] for chunk in retrieved_chunks)
        prompt_path = "src/prompts/documents_rag_prompt.txt"
        prompt = load_prompt_template(prompt_path, {
            "context": raw_context,
            "query": query
        })
        logging.info("Prompt sent to LLM:\n%s", prompt)
        return self.llm.generate(prompt)

    def get_top_docs_chunks_for_query(self, query, relevant_docs, top_k=5, chunks_file=CHUNKS_FILE, embeddings_file=EMBEDDINGS_FILE):
        """Re-rank the chunks of *relevant_docs* by cosine similarity to *query*.

        *relevant_docs* entries are matched against the chunk store first by
        title ("Nom du document"/"pdf_title"), then by link.  The matched
        chunks are scored with precomputed embeddings and the *top_k* highest
        scoring chunk dicts (annotated with "pdf_title"/"pdf_link") are
        returned; [] when nothing matches.
        """
        with open(str(chunks_file), encoding="utf-8") as f:
            all_chunks_data = json.load(f)
        # Duplicate titles/links keep the last document seen (original behavior).
        title_map = {doc.get("pdf_title"): doc.get("chunks", []) for doc in all_chunks_data}
        link_map = {doc.get("download_link"): doc.get("chunks", []) for doc in all_chunks_data}
        all_chunks = []
        for doc in relevant_docs:
            pdf_title = doc.get("Nom du document") or doc.get("pdf_title")
            pdf_link = doc.get("Lien") or doc.get("pdf_link") or doc.get("download_link")
            chunks = title_map.get(pdf_title) or link_map.get(pdf_link)
            if chunks:
                for chunk in chunks:
                    # Copy before annotating.  The previous code mutated the
                    # shared chunk dicts in place, so when two relevant docs
                    # resolved to the same chunk list the second doc's
                    # title/link retroactively overwrote chunks already
                    # collected for the first.
                    annotated = dict(chunk)
                    annotated["pdf_title"] = pdf_title
                    annotated["pdf_link"] = pdf_link
                    all_chunks.append(annotated)
        logging.info("Total matched chunks before re-ranking: %d", len(all_chunks))
        if not all_chunks:
            logging.warning("No chunks found for the given relevant documents.")
            return []
        embeddings = np.load(str(embeddings_file))
        chunk_indices = [chunk["faiss_index"] for chunk in all_chunks]
        chunk_embeddings = embeddings[chunk_indices]
        # Function-scope import kept from the original (presumably to avoid a
        # circular import at module load time — TODO confirm).
        from src.utils.search_docs_utils import get_model
        model = get_model()
        query_emb = model.encode([query])[0]
        cos_sim = util.cos_sim(torch.tensor(query_emb), torch.tensor(chunk_embeddings))[0]
        # Sort on the score only; ties never fall through to comparing dicts.
        scored = sorted(zip(cos_sim.tolist(), all_chunks), key=lambda pair: pair[0], reverse=True)
        logging.info("Returning top-%d most relevant chunks.", top_k)
        return [chunk for _, chunk in scored[:top_k]]