# SparrowTale / rag_agent.py
# Author: abhinav0231
# Last change: "Update rag_agent.py" (commit d1ebc00, verified)
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from typing import List, Dict
import streamlit as st
def get_document_context(file_path: str, query: str) -> str:
    """Retrieve the most relevant chunks of a document for a query.

    Lightweight retrieval: TF-IDF vectors plus cosine similarity stand in
    for FAISS-style dense embeddings.

    Args:
        file_path: Path to a ``.pdf`` or ``.txt`` file to search.
        query: Free-text search query.

    Returns:
        Up to three relevant chunks joined by blank lines, or a
        human-readable ``"Error: ..."`` / ``"No relevant context..."``
        string. Callers depend on this string-only contract, so no
        exception is allowed to escape.
    """
    print("--- Using TF-IDF for document retrieval ---")

    # --- Load the document into a list of text chunks -------------------
    if file_path.endswith(".pdf"):
        try:
            from pypdf import PdfReader  # local import: only needed for PDFs
            reader = PdfReader(file_path)
            documents = []
            for page in reader.pages:
                text = page.extract_text()
                # Guard against None (image-only pages) as well as blank
                # pages; the original `text.strip()` raised on None.
                if text and text.strip():
                    documents.append(text)
        # Narrowed from bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; log the cause for debugging.
        except Exception as e:
            print(f"PDF read failed: {e}")
            return "Error: Could not read PDF file."
    elif file_path.endswith(".txt"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # ~1000-char chunks with a 200-char overlap (step 800) so
            # relevant passages aren't lost at chunk boundaries.
            documents = [content[i:i + 1000] for i in range(0, len(content), 800)]
        except Exception as e:
            print(f"Text read failed: {e}")
            return "Error: Could not read text file."
    else:
        return "Error: Unsupported file format. Please upload a .pdf or .txt file."

    if not documents:
        return "Error: Document is empty or could not be read."

    # --- Rank chunks by TF-IDF cosine similarity ------------------------
    try:
        # TF-IDF vectors are our "embedding" replacement.
        vectorizer = TfidfVectorizer(
            stop_words='english',
            max_features=5000,
            ngram_range=(1, 2),  # include bigrams for better context
        )
        doc_vectors = vectorizer.fit_transform(documents)
        query_vector = vectorizer.transform([query])

        similarities = cosine_similarity(query_vector, doc_vectors).flatten()

        # Indices of the top-3 scores, best first.
        top_indices = similarities.argsort()[-3:][::-1]
        context_chunks = [
            documents[idx]
            for idx in top_indices
            if similarities[idx] > 0.1  # only keep reasonably relevant chunks
        ]

        context = "\n\n".join(context_chunks)
        return context if context else "No relevant context found in the document."
    except Exception as e:
        print(f"An error occurred during document processing: {e}")
        return "Error: Failed to process the provided document."
def run_rag_agent(user_prompt: str, file_path: str) -> str:
    """
    Main agentic entry point (same interface as the previous FAISS version).

    Asks the LLM to condense the user's story idea into a focused search
    query, then hands that query to the lightweight TF-IDF retriever.
    Falls back to the raw prompt when query generation fails.
    """
    print("--- RAG Agent Activated (Lightweight TF-IDF Mode) ---")

    # LLM is imported lazily so this module loads even without llm_setup.
    from llm_setup import llm
    if not llm:
        return "Error: LLM not available for query generation."

    query = user_prompt  # fallback in case query generation fails
    try:
        prompt = f"""You are a research assistant. Based on the user's story idea, what is the single most
important keyword or question to search for within their provided document to find relevant context?
User's Story Idea: '{user_prompt}'
Optimized Search Query for Document:"""
        reply = llm.invoke(prompt)
        query = reply.content.strip()
        print(f"Generated Search Query: {query}")
    except Exception as e:
        print(f"Query generation failed, using original prompt: {e}")
        query = user_prompt

    # Delegate retrieval to the lightweight TF-IDF helper.
    context = get_document_context(file_path, query)
    print("--- RAG Agent Finished ---")
    return context