"""RAG question-answering over financial PDF reports (TF-IDF + FAISS hybrid retrieval, FLAN-T5 generation)."""
import os
import re
import time
from typing import List

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer  # For embeddings
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm  # For progress bars

# LLM
from transformers import pipeline

# Data Directory
DATA_DIR = "pdfs"  # Relative path (assuming 'pdfs' directory is in the repo)

# Model Names
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # A good balance of speed and quality
LLM_MODEL_NAME = "google/flan-t5-base"  # Or a smaller one for faster inference
MAX_CONTEXT_LENGTH = 512  # Maximum context length (in words) fed to the LLM

# Sample Queries
SAMPLE_QUERIES = [
    "What was Google's total revenue in 2023?",
    "How did Google's research and development expenses change between 2023 and 2024?",
    "What is the capital of France?",
]


def load_pdfs(data_dir: str) -> List[str]:
    """Load and extract text from every ``*.pdf`` file in *data_dir*.

    Args:
        data_dir: Directory to scan (non-recursively) for PDF files.

    Returns:
        One extracted-text string per successfully read PDF. Files that
        fail to parse are reported to stdout and skipped.
    """
    pdf_texts = []
    for filename in os.listdir(data_dir):
        if filename.endswith(".pdf"):
            filepath = os.path.join(data_dir, filename)
            try:
                with open(filepath, 'rb') as f:
                    reader = PdfReader(f)
                    # extract_text() may return None for image-only pages;
                    # coerce to "" so join() does not raise TypeError.
                    text = "".join(p.extract_text() or "" for p in reader.pages)
                pdf_texts.append(text)
            except Exception as e:
                # Fixed: original print interpolated a placeholder instead of
                # the offending file's name, making failures untraceable.
                print(f"Error loading {filename}: {e}")
    return pdf_texts


def clean_text(text: str) -> str:
    """Clean the financial text by removing extra spaces, newlines, and special characters.

    Collapses whitespace runs, strips tabs/newlines and '$' symbols, and
    drops non-ASCII characters.
    """
    text = re.sub(r'\s+', ' ', text)  # Remove multiple spaces
    text = text.replace('\n', ' ')
    text = text.replace('\t', ' ')
    text = text.replace('$', '')  # Remove currency symbols
    text = text.encode('ascii', 'ignore').decode()  # Remove non-ASCII characters
    return text.strip()


def chunk_text(text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
    """Split *text* into overlapping character-bounded chunks on sentence boundaries.

    NOTE(review): the body of this function was truncated in this copy of the
    file right after the sentence split; the implementation below is a
    standard reconstruction (greedy sentence packing with a character-level
    overlap carried between chunks) — confirm against the upstream source.
    """
    sentences = re.split(r'(?<=[.!?]) ', text)
    chunks: List[str] = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) <= chunk_size:
            current += sentence + " "
        else:
            if current.strip():
                chunks.append(current.strip())
            # Carry the tail of the previous chunk forward so context is
            # not lost at chunk boundaries.
            current = current[-overlap:] + sentence + " "
    if current.strip():
        chunks.append(current.strip())
    return chunks
MAX_CONTEXT_LENGTH: prompt_words = prompt.split() prompt = " ".join(prompt_words[:MAX_CONTEXT_LENGTH]) # Keep only the first MAX_CONTEXT_LENGTH words try: result = qa_pipeline(prompt, max_length=512, min_length=50, do_sample=False) return result[0]['generated_text'] except Exception as e: print(f"Error during LLM generation: {e}") return "An error occurred while generating the answer." def input_guardrail(query): """Checks if the query is safe and relevant.""" # Enhanced: Check for financial keywords and relevance query = query.lower() if not any(keyword in query for keyword in ["revenue", "profit", "loss", "cash flow", "balance sheet", "financial", "income", "expense"]): if any(keyword in query for keyword in ["capital of France", "irrelevant", "harmful"]): return "I cannot answer questions of that nature. Please ask a relevant financial question.", 0.0, 0, "Irrelevant/Harmful Query", EMBEDDING_MODEL_NAME, LLM_MODEL_NAME else: return "Please ask a specific question related to financial information.", 0.0, 0, "Non-Financial Query", EMBEDDING_MODEL_NAME, LLM_MODEL_NAME return None, None, None, "Safe", None, None # Query is safe def rag_pipeline(query, tfidf_matrix, tfidf_vectorizer, faiss_index, embedding_model, chunks, qa_pipeline): """Executes the RAG pipeline.""" start_time = time.time() # 1. Retrieval relevant_indices, retrieval_scores = hybrid_retrieval(query, tfidf_matrix, tfidf_vectorizer, faiss_index, embedding_model, chunks, top_k=5) # Adjust top_k as needed context = "\n".join([chunks[i] for i in relevant_indices]) # 2. 
Generation answer = generate_answer(query, context, qa_pipeline) end_time = time.time() elapsed_time = end_time - start_time # Calculate Confidence Score (Example) if retrieval_scores: confidence_score = np.mean(retrieval_scores) # Average retrieval score else: confidence_score = 0.0 # If no context retrieved, low confidence return answer, elapsed_time, confidence_score, len(chunks) # Return number of chunks def answer_question(query, tfidf_matrix, tfidf_vectorizer, faiss_index, embedding_model, chunks, qa_pipeline): """Main function to run the RAG pipeline and return the answer and confidence.""" # Guardrail check guardrail_message, guardrail_confidence, num_chunks, guardrail_reason, emb_model, llm_model = input_guardrail(query) if guardrail_message: output_text = ( f"