dnj0 commited on
Commit
bb6a42c
·
verified ·
1 Parent(s): d3aa2b9

Update src/rag_pipeline.py

Browse files
Files changed (1) hide show
  1. src/rag_pipeline.py +157 -93
src/rag_pipeline.py CHANGED
@@ -1,93 +1,157 @@
1
- from typing import List, Dict, Optional
2
- from pdf_parser import extract_text_from_pdfs
3
- from vector_store import VectorStore
4
- from embeddings import CLIPEmbedder
5
- from multimodal_model import GemmaVisionModel
6
-
7
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a directory of PDFs.

    Wires together a CLIP embedder, a Chroma-backed vector store, and a
    Gemma vision/LLM wrapper: index PDFs, retrieve relevant chunks, and
    answer questions grounded in the retrieved context.
    """

    def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
        """Initialize all pipeline components.

        Args:
            pdf_dir: Directory containing the PDF files to index.
            chroma_dir: Persistence directory for the Chroma database.
            device: Torch device string for the embedder and LLM ("cpu"/"cuda").
        """
        self.pdf_dir = pdf_dir
        self.device = device

        print("→ Initializing RAG Pipeline...")

        # Embedder (kept on self so callers can reuse it for ad-hoc encoding).
        self.embedder = CLIPEmbedder(model_name="openai/clip-vit-base-patch32", device=device)

        # Vector store; collection is created if it does not exist yet.
        self.vector_store = VectorStore(persist_dir=chroma_dir)
        self.vector_store.get_or_create_collection()

        # Generative model used for answering and summarization.
        self.llm = GemmaVisionModel(model_name="unsloth/gemma-3-1b-pt", device=device)

        print("✓ RAG Pipeline initialized")

    def index_pdfs(self):
        """Extract text from every PDF in ``pdf_dir`` and add it to the vector store."""
        print("→ Indexing PDF documents...")

        documents, metadatas = extract_text_from_pdfs(self.pdf_dir)

        if not documents:
            print("No documents to index")
            return

        # Stable, positional ids; embeddings are generated by the store itself.
        ids = [f"doc_{i}" for i in range(len(documents))]
        self.vector_store.add_documents(documents, metadatas, ids)

        print(f"✓ Indexed {len(documents)} document chunks")

    def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return the top ``n_results`` chunks for ``query``.

        Each item is ``{"content": str, "source": "<filename> (p<page>)"}``.
        """
        results = self.vector_store.search(query, n_results=n_results)

        # Chroma returns nested lists (one inner list per query); we issue
        # a single query, hence the [0] indexing.
        return [
            {
                "content": doc,
                "source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})",
            }
            for doc, metadata in zip(results["documents"][0], results["metadatas"][0])
        ]

    def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
        """Answer ``question`` using RAG.

        Returns a dict with keys ``answer``, ``sources`` and ``context_used``.
        """
        # Retrieve relevant documents and tag each chunk with its provenance.
        retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)
        context = "\n\n".join(
            f"[Source: {doc['source']}]\n{doc['content']}" for doc in retrieved_docs
        )

        answer = self.llm.answer_question(question, context)

        # The model may echo the prompt; keep only the text after "Answer:".
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        return {
            "answer": answer,
            "sources": [doc["source"] for doc in retrieved_docs],
            "context_used": len(retrieved_docs),
        }

    def summarize_documents(self) -> str:
        """Summarize the indexed corpus from a small sample of chunks."""
        collection_info = self.vector_store.get_collection_info()
        doc_count = collection_info.get("document_count", 0)

        if doc_count == 0:
            return "No documents to summarize"

        # Sample a few chunks and truncate each chunk's TEXT to 200 chars.
        # (Bug fix: the original sliced the chunk LIST with ``docs[:200]``,
        # joining full-length documents instead of 200-char excerpts.)
        results = self.vector_store.search("main topic summary", n_results=5)
        sampled_content = " ".join(
            doc[:200] for docs in results["documents"] for doc in docs
        )

        return self.llm.summarize_text(sampled_content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from pdf_parser import extract_text_from_pdfs
3
+ from vector_store import VectorStore
4
+ from embeddings import CLIPEmbedder
5
+ from multimodal_model import Gemma3Model # ← Changed from GemmaVisionModel
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a directory of PDFs.

    Wires together a CLIP embedder, a Chroma-backed vector store, and a
    Gemma 3 model: index PDFs, retrieve relevant chunks, and answer
    questions grounded in the retrieved context. All public methods log
    progress and errors through the module-level logger.
    """

    def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
        """Initialize all pipeline components.

        Args:
            pdf_dir: Directory containing the PDF files to index.
            chroma_dir: Persistence directory for the Chroma database.
            device: Torch device string ("cpu"/"cuda").

        Raises:
            Exception: re-raised (after logging) from any component that
                fails to initialize.
        """
        self.pdf_dir = pdf_dir
        self.device = device

        logger.info("→ Initializing RAG Pipeline...")

        try:
            # Embedder (kept on self so callers can reuse it directly).
            logger.debug("Loading embedder...")
            self.embedder = CLIPEmbedder(
                model_name="openai/clip-vit-base-patch32",
                device=device,
            )

            # Vector store; collection is created if it does not exist yet.
            logger.debug("Initializing vector store...")
            self.vector_store = VectorStore(persist_dir=chroma_dir)
            self.vector_store.get_or_create_collection()

            # Generative model used for answering and summarization.
            logger.debug("Loading Gemma 3 1B model...")
            self.llm = Gemma3Model(model_name="google/gemma-3-1b-it", device=device)

            logger.info("✓ RAG Pipeline initialized successfully")

        except Exception as e:
            logger.error("Failed to initialize pipeline: %s", e, exc_info=True)
            raise

    def index_pdfs(self):
        """Extract text from every PDF in ``pdf_dir`` and add it to the store.

        Raises:
            Exception: re-raised after logging if extraction or insertion fails.
        """
        logger.info("→ Starting PDF indexing...")

        try:
            documents, metadatas = extract_text_from_pdfs(self.pdf_dir)

            if not documents:
                logger.warning("No documents extracted")
                return

            logger.debug("Extracted %d document chunks", len(documents))

            # Stable, positional ids; embeddings are generated by the store.
            ids = [f"doc_{i}" for i in range(len(documents))]
            self.vector_store.add_documents(documents, metadatas, ids)

            logger.info("✓ Indexed %d document chunks", len(documents))

        except Exception as e:
            logger.error("Error during indexing: %s", e, exc_info=True)
            raise

    def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return the top ``n_results`` chunks for ``query``.

        Each item is ``{"content": str, "source": "<filename> (p<page>)"}``.
        Failures are logged and reported as an empty list, never raised.
        """
        try:
            logger.debug("Searching for: %s...", query[:50])
            results = self.vector_store.search(query, n_results=n_results)

            # Chroma returns nested lists (one inner list per query); we
            # issue a single query, hence the [0] indexing.
            retrieved_docs = [
                {
                    "content": doc,
                    "source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})",
                }
                for doc, metadata in zip(results["documents"][0], results["metadatas"][0])
            ]

            logger.debug("Retrieved %d documents", len(retrieved_docs))
            return retrieved_docs

        except Exception as e:
            logger.error("Error retrieving documents: %s", e, exc_info=True)
            return []

    def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
        """Answer ``question`` using RAG with comprehensive error handling.

        Returns a dict with keys ``answer``, ``sources`` and ``context_used``.
        Never raises: failures are logged and surfaced in ``answer``.
        """
        logger.info("Processing question: %s...", question[:50])

        try:
            logger.debug("Retrieving %d documents...", n_context_docs)
            retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)

            if not retrieved_docs:
                logger.warning("No documents retrieved")
                return {
                    "answer": "No relevant documents found.",
                    "sources": [],
                    "context_used": 0,
                }

            logger.debug("Retrieved %d documents", len(retrieved_docs))

            # Cap each chunk at 500 chars and the whole context at 2000
            # chars to bound prompt size and memory use.
            context = "\n\n".join(
                f"[{doc['source']}]\n{doc['content'][:500]}" for doc in retrieved_docs
            )[:2000]

            logger.debug("Generating answer with Gemma 3...")

            try:
                answer = self.llm.answer_question(question, context)
            except Exception as e:
                # Fall back to plain greedy decoding if the primary path fails.
                logger.warning("Answer generation failed (%s), using greedy fallback...", e)
                answer = self.llm.generate_response_greedy(
                    f"Q: {question}\nA:"
                )

            # The model may echo the prompt; keep only text after "Answer:".
            if "Answer:" in answer:
                answer = answer.split("Answer:")[-1].strip()

            logger.info("✓ Answer generated successfully")

            return {
                "answer": answer[:1000],  # limit output length
                "sources": [doc["source"] for doc in retrieved_docs],
                "context_used": len(retrieved_docs),
            }

        except Exception as e:
            logger.error("Error in answer_question: %s", e, exc_info=True)
            return {
                "answer": f"Error generating answer: {str(e)}",
                "sources": [],
                "context_used": 0,
            }

    def summarize_documents(self) -> str:
        """Summarize the indexed corpus from a small sample of chunks."""
        collection_info = self.vector_store.get_collection_info()
        doc_count = collection_info.get("document_count", 0)

        if doc_count == 0:
            return "No documents to summarize"

        # Sample a few chunks and truncate each chunk's TEXT to 200 chars.
        # (Bug fix: the original sliced the chunk LIST with ``docs[:200]``,
        # joining full-length documents instead of 200-char excerpts.)
        results = self.vector_store.search("main topic summary", n_results=5)
        sampled_content = " ".join(
            doc[:200] for docs in results["documents"] for doc in docs
        )

        return self.llm.summarize_text(sampled_content)

    def get_collection_info(self) -> Dict:
        """Return collection statistics straight from the vector store."""
        return self.vector_store.get_collection_info()