abhinav0231 committed on
Commit
d1ebc00
·
verified ·
1 Parent(s): 7367624

Update rag_agent.py

Browse files
Files changed (1) hide show
  1. rag_agent.py +75 -66
rag_agent.py CHANGED
@@ -2,92 +2,101 @@ import os
2
  from sklearn.feature_extraction.text import TfidfVectorizer
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import numpy as np
5
- import PyPDF2
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
  import streamlit as st
8
 
9
def load_document(file_path: str) -> str:
    """Load the full text content of a PDF or TXT file.

    Args:
        file_path: Path to the document; only ``.pdf`` and ``.txt``
            extensions are supported.

    Returns:
        The document text, or a string starting with ``"Error"`` when the
        file cannot be read — callers check this prefix rather than
        catching exceptions.
    """
    try:
        if file_path.endswith(".pdf"):
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                text = ""
                for page in pdf_reader.pages:
                    # extract_text() can return None for image-only pages
                    # in some PyPDF2 versions; guard so we never do
                    # None + "\n".
                    text += (page.extract_text() or "") + "\n"
                return text
        elif file_path.endswith(".txt"):
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        else:
            return "Error: Unsupported file format."
    except Exception as e:
        return f"Error reading file: {str(e)}"
26
-
27
def simple_text_search(query: str, document_text: str, max_chunks: int = 3) -> str:
    """Return the document chunks most similar to *query* using TF-IDF.

    The document is split into ~200-word chunks; each chunk and the query
    are vectorized with TF-IDF and ranked by cosine similarity. The top
    ``max_chunks`` chunks scoring above 0.1 are joined with blank lines.

    Args:
        query: Free-text search query.
        document_text: Full text of the document to search.
        max_chunks: Maximum number of chunks to return.

    Returns:
        The matching chunks, or a diagnostic string when the document is
        empty, nothing is relevant, or vectorization fails.
    """
    # Split document into fixed-size word chunks.
    chunks = []
    words = document_text.split()
    chunk_size = 200  # words per chunk

    for i in range(0, len(words), chunk_size):
        chunk = " ".join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)

    if not chunks:
        return "No content found in document."

    # Create TF-IDF vectors
    vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)

    try:
        # Vectorize chunks and query
        chunk_vectors = vectorizer.fit_transform(chunks)
        query_vector = vectorizer.transform([query])

        # Calculate similarity
        similarities = cosine_similarity(query_vector, chunk_vectors).flatten()

        # Get top matching chunks, best first
        top_indices = similarities.argsort()[-max_chunks:][::-1]

        relevant_chunks = [chunks[i] for i in top_indices if similarities[i] > 0.1]

        if not relevant_chunks:
            # Previously this fell through to "".join(...) == "" — an empty
            # string is indistinguishable from a silent failure downstream.
            return "No relevant context found in the document."

        return "\n\n".join(relevant_chunks[:max_chunks])

    except Exception as e:
        # fit_transform can raise e.g. ValueError("empty vocabulary") when
        # every token is a stop word; report instead of crashing.
        return f"Search error: {str(e)}"
 
63
 
64
def run_rag_agent(user_prompt: str, file_path: str) -> str:
    """Run the lightweight RAG pipeline: load document, derive a search
    query with the LLM, then retrieve the most relevant chunks.

    Args:
        user_prompt: The user's story idea, used to derive search keywords.
        file_path: Path to the ``.pdf`` or ``.txt`` document to search.

    Returns:
        The retrieved context, or an ``"Error..."`` string when the
        document cannot be loaded.
    """
    print("--- RAG Agent Activated (Lightweight Version) ---")

    # Load document
    document_text = load_document(file_path)
    if document_text.startswith("Error"):
        return document_text

    # Generate search query using LLM
    api_key = st.secrets.get("GEMINI_API_KEY", os.getenv("GEMINI_API_KEY"))
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=api_key)

    search_prompt = f"""Based on this story idea: "{user_prompt}"

What are the 2-3 most important keywords to search for in a document to find relevant context?
Respond with just the keywords separated by spaces."""

    try:
        response = llm.invoke(search_prompt)
        search_query = response.content.strip()
        print(f"Generated Search Query: {search_query}")
    except Exception as e:
        # Was a bare `except:` — that also swallowed KeyboardInterrupt and
        # hid the real failure. Log and fall back to the raw prompt.
        print(f"Query generation failed, falling back to user prompt: {e}")
        search_query = user_prompt

    # Retrieve relevant content
    context = simple_text_search(search_query, document_text)

    print("--- RAG Agent Finished ---")
    return context
 
2
  from sklearn.feature_extraction.text import TfidfVectorizer
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import numpy as np
5
+ from typing import List, Dict
 
6
  import streamlit as st
7
 
8
def get_document_context(file_path: str, query: str) -> str:
    """Lightweight document retrieval using TF-IDF instead of FAISS.

    Loads a ``.pdf`` (one chunk per page) or ``.txt`` (overlapping
    ~1000-character chunks) file, vectorizes the chunks and *query* with
    TF-IDF (unigrams + bigrams), and returns the top 3 chunks whose cosine
    similarity to the query exceeds 0.1.

    Args:
        file_path: Path to the document to search.
        query: Free-text search query.

    Returns:
        The joined relevant chunks, or an ``"Error..."`` / "No relevant
        context..." diagnostic string.
    """
    print("--- Using TF-IDF for document retrieval ---")

    # Load document
    if file_path.endswith(".pdf"):
        try:
            from pypdf import PdfReader
            reader = PdfReader(file_path)
            documents = []
            for page in reader.pages:
                text = page.extract_text()
                if text.strip():
                    documents.append(text)
        except Exception as e:
            # Was a bare `except:`; log the cause instead of hiding it.
            print(f"PDF read failed: {e}")
            return "Error: Could not read PDF file."

    elif file_path.endswith(".txt"):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # Split into chunks of ~1000 characters with 200-char overlap
            # (step of 800) so sentences at chunk borders aren't lost.
            documents = [content[i:i + 1000] for i in range(0, len(content), 800)]
        except Exception as e:
            # Was a bare `except:`; log the cause instead of hiding it.
            print(f"Text read failed: {e}")
            return "Error: Could not read text file."
    else:
        return "Error: Unsupported file format. Please upload a .pdf or .txt file."

    if not documents:
        return "Error: Document is empty or could not be read."

    try:
        # Create TF-IDF vectors - this is our "embedding" replacement
        vectorizer = TfidfVectorizer(
            stop_words='english',
            max_features=5000,
            ngram_range=(1, 2)  # Include bigrams for better context
        )

        # Transform documents and query
        doc_vectors = vectorizer.fit_transform(documents)
        query_vector = vectorizer.transform([query])

        # Calculate similarities
        similarities = cosine_similarity(query_vector, doc_vectors).flatten()

        # Get top 3 most relevant chunks
        top_indices = similarities.argsort()[-3:][::-1]

        context_chunks = []
        for idx in top_indices:
            if similarities[idx] > 0.1:  # Only include if reasonably relevant
                context_chunks.append(documents[idx])

        context = "\n\n".join(context_chunks)
        return context if context else "No relevant context found in the document."

    except Exception as e:
        print(f"An error occurred during document processing: {e}")
        return "Error: Failed to process the provided document."
70
 
71
def run_rag_agent(user_prompt: str, file_path: str) -> str:
    """
    The main agentic function - keep the same interface as before.
    """
    print("--- RAG Agent Activated (Lightweight TF-IDF Mode) ---")

    # Generate optimized search query using LLM (same logic as before)
    from llm_setup import llm

    # Guard clause: without the LLM we cannot build a search query at all.
    if not llm:
        return "Error: LLM not available for query generation."

    search_prompt = f"""You are a research assistant. Based on the user's story idea, what is the single most
important keyword or question to search for within their provided document to find relevant context?

User's Story Idea: '{user_prompt}'

Optimized Search Query for Document:"""

    try:
        llm_reply = llm.invoke(search_prompt)
        search_query = llm_reply.content.strip()
        print(f"Generated Search Query: {search_query}")
    except Exception as e:
        print(f"Query generation failed, using original prompt: {e}")
        search_query = user_prompt

    # Use our lightweight retrieval
    retrieved_context = get_document_context(file_path, search_query)

    print("--- RAG Agent Finished ---")
    return retrieved_context