dnj0 committed on
Commit
555c75a
·
verified ·
1 Parent(s): 44a83f9

Upload 6 files

Browse files
src/app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import os
from pathlib import Path
from rag_pipeline import RAGPipeline
import shutil  # NOTE(review): appears unused in this file — confirm before removing

# Page configuration
st.set_page_config(
    page_title="Local Multimodal RAG",
    page_icon="📚",
    layout="wide",
    initial_sidebar_state="expanded"
)

st.title("📚 Local Multimodal RAG System")
st.markdown("**Analyze PDF documents locally with Mistral + CLIP embeddings**")

# Initialize session state
# uploaded_files: names of PDFs saved via the sidebar uploader this session
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
# rag_pipeline: lazily built RAGPipeline instance; None until first init
if "rag_pipeline" not in st.session_state:
    st.session_state.rag_pipeline = None
# needs_reindex: set True after uploads/deletions so the index is rebuilt
if "needs_reindex" not in st.session_state:
    st.session_state.needs_reindex = False
+
26
# Sidebar configuration
with st.sidebar:
    st.header("⚙️ Configuration")

    # Where PDFs live on disk; also the target directory for uploads.
    pdf_dir = st.text_input(
        "📁 PDF Directory",
        value="./pdfs",
        help="Path to directory containing PDF files"
    )

    # Forwarded to model loading in the pipeline.
    device = st.selectbox(
        "🖥️ Device",
        ["cpu", "cuda"],
        help="Device for model inference"
    )

    # How many retrieved chunks are fed to the LLM per question/search.
    n_context_docs = st.slider(
        "📄 Context Documents",
        min_value=1,
        max_value=10,
        value=3,
        help="Number of documents to retrieve for context"
    )

    st.divider()

    # PDF Upload Section
    st.subheader("📤 Upload PDF Files")

    uploaded_pdfs = st.file_uploader(
        "Choose PDF files to upload",
        type="pdf",
        accept_multiple_files=True,
        help="Select one or more PDF files to add to the system"
    )

    if uploaded_pdfs:
        # Create PDF directory if not exists
        os.makedirs(pdf_dir, exist_ok=True)

        upload_button = st.button("⬆️ Upload PDFs", use_container_width=True)

        if upload_button:
            uploaded_count = 0
            for uploaded_file in uploaded_pdfs:
                file_path = os.path.join(pdf_dir, uploaded_file.name)

                # Save file (overwrites any existing file with the same name)
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                st.session_state.uploaded_files.append(uploaded_file.name)
                uploaded_count += 1

            st.success(f"✅ Uploaded {uploaded_count} PDF(s) successfully!")
            # New files mean the vector index is stale.
            st.session_state.needs_reindex = True

    st.divider()

    # Display files currently on disk (not just this session's uploads)
    pdf_files = list(Path(pdf_dir).glob("*.pdf"))
    if pdf_files:
        st.subheader(f"📚 Documents ({len(pdf_files)})")
        for pdf_file in pdf_files:
            col1, col2 = st.columns([4, 1])
            with col1:
                st.write(f"• {pdf_file.name}")
            with col2:
                # Per-file delete; key ties the button to this filename.
                if st.button("🗑️", key=f"delete_{pdf_file.name}", help="Delete this file"):
                    os.remove(pdf_file)
                    # Deletion invalidates the index; rerun refreshes the list.
                    st.session_state.needs_reindex = True
                    st.rerun()

        st.divider()

    # Reindex button: drop the cached pipeline and force a rebuild on rerun.
    if st.button("🔄 Reload & Index PDFs", use_container_width=True):
        st.session_state.rag_pipeline = None
        st.session_state.needs_reindex = True
        st.rerun()
106
+
107
+
108
# Initialize pipeline in session state
@st.cache_resource
def init_rag_pipeline(device, pdf_dir):
    """Build (or fetch the cached) RAG pipeline for the given settings.

    Args:
        device: inference device ("cpu" or "cuda").
        pdf_dir: directory containing the PDF corpus.

    Returns:
        (pipeline, error) tuple — exactly one of the two is None.

    Fix: the parameters were previously named ``_device``/``_pdf_dir``; the
    leading underscore tells st.cache_resource NOT to hash them, so the cache
    key never changed and switching the device or PDF directory in the UI had
    no effect until a process restart. Without the underscores, each distinct
    (device, pdf_dir) pair gets its own cached pipeline.
    """
    # Create PDF directory if not exists
    os.makedirs(pdf_dir, exist_ok=True)

    # Check if PDFs exist before paying the model-loading cost
    pdf_files = list(Path(pdf_dir).glob("*.pdf"))
    if not pdf_files:
        return None, f"No PDF files found in {pdf_dir}. Upload PDFs using the sidebar."

    try:
        with st.spinner("⏳ Initializing RAG pipeline..."):
            pipeline = RAGPipeline(pdf_dir=pdf_dir, device=device)
        with st.spinner("⏳ Indexing PDFs..."):
            pipeline.index_pdfs()
        return pipeline, None
    except Exception as e:
        # Surface the failure to the caller instead of crashing the script run
        return None, str(e)
128
+
129
+
130
# Get or initialize pipeline
if st.session_state.rag_pipeline is None or st.session_state.needs_reindex:
    if st.session_state.needs_reindex:
        # Bug fix: init_rag_pipeline is wrapped in @st.cache_resource, so
        # calling it again would hand back the stale cached pipeline and the
        # upload/delete/reload actions would never actually reindex. Clearing
        # the cache forces a fresh parse + index.
        init_rag_pipeline.clear()
    pipeline, error = init_rag_pipeline(device, pdf_dir)
    if error:
        st.error(f"❌ Error: {error}")
        st.info("💡 **How to get started:**\n1. Upload PDF files using the sidebar\n2. Click 'Upload PDFs' to save them\n3. Click 'Reload & Index PDFs' to process them")
        st.stop()  # nothing below can run without a pipeline
    st.session_state.rag_pipeline = pipeline
    st.session_state.needs_reindex = False
else:
    # Reuse the pipeline built on a previous script run
    pipeline = st.session_state.rag_pipeline
141
+
142
# Main content
if pipeline:
    # Three workflows over the indexed corpus: ask, summarize, raw search.
    tab1, tab2, tab3 = st.tabs(["❓ Q&A", "📊 Summary", "📖 Retrieval"])

    # Tab 1: Question Answering
    with tab1:
        st.subheader("Ask Questions about Your Documents")

        question = st.text_area(
            "Your question (in Russian or English):",
            height=100,
            placeholder="What is this document about? What are the main points? Etc.",
            key="qa_question"
        )

        col1, col2 = st.columns(2)
        with col1:
            get_answer_btn = st.button("🔍 Get Answer", use_container_width=True)
        with col2:
            clear_btn = st.button("🗑️ Clear", use_container_width=True)

        if clear_btn:
            # Rerun resets this tab's transient output
            st.rerun()

        if get_answer_btn:
            if question.strip():
                with st.spinner("⏳ Retrieving documents and generating answer..."):
                    try:
                        result = pipeline.answer_question(question, n_context_docs=n_context_docs)
                    except Exception as e:
                        st.error(f"Error generating answer: {str(e)}")
                        result = None  # keep the name bound for the check below

                if result and result.get("answer"):
                    st.success("✓ Answer generated!")

                    # Display answer
                    st.subheader("📝 Answer")
                    st.write(result["answer"])

                    # Display sources
                    with st.expander("📚 Sources Used"):
                        for i, source in enumerate(result["sources"], 1):
                            st.write(f"{i}. {source}")

                    # Display stats
                    col1, col2 = st.columns(2)
                    with col1:
                        st.metric("Documents Used", result.get("context_used", 0))
                    with col2:
                        st.metric("Answer Length", len(result["answer"]))
            else:
                st.warning("Please enter a question")

    # Tab 2: Document Summary
    with tab2:
        st.subheader("Summary of Indexed Documents")

        if st.button("📊 Generate Summary", use_container_width=True):
            with st.spinner("⏳ Generating summary..."):
                try:
                    summary = pipeline.summarize_documents()
                    st.success("✓ Summary generated!")
                    st.subheader("📄 Document Summary")
                    st.write(summary)
                except Exception as e:
                    st.error(f"Error generating summary: {str(e)}")

    # Tab 3: Document Retrieval
    with tab3:
        st.subheader("Search and Retrieve Documents")

        search_query = st.text_input(
            "Search query:",
            placeholder="Enter search terms...",
            key="retrieval_search"
        )

        col1, col2 = st.columns(2)
        with col1:
            search_btn = st.button("🔎 Search", use_container_width=True)
        with col2:
            clear_search_btn = st.button("Clear Search", use_container_width=True)

        if clear_search_btn:
            st.rerun()

        if search_btn:
            if search_query.strip():
                with st.spinner("⏳ Searching..."):
                    try:
                        results = pipeline.retrieve_documents(search_query, n_results=n_context_docs)
                    except Exception as e:
                        st.error(f"Search error: {str(e)}")
                        results = []  # keep the name bound for the check below

                if results:
                    st.success(f"✓ Found {len(results)} documents")

                    # First hit expanded by default; the rest collapsed
                    for i, doc in enumerate(results, 1):
                        with st.expander(f"📄 Document {i} - {doc['source']}", expanded=(i==1)):
                            st.write(doc["content"])
                else:
                    st.warning("No documents found matching your query")
            else:
                st.warning("Please enter a search query")

    # Footer
    st.divider()
    with st.expander("ℹ️ System Information"):
        info = pipeline.vector_store.get_collection_info()
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            st.metric("📚 Documents", info.get("document_count", 0))
        with col2:
            st.metric("🖥️ Device", device.upper())
        with col3:
            st.metric("🔍 Context Docs", n_context_docs)
        with col4:
            pdf_count = len(list(Path(pdf_dir).glob("*.pdf")))
            st.metric("📁 PDF Files", pdf_count)

else:
    # Unreachable in practice (error path calls st.stop()), kept as a guard
    st.error("❌ Failed to initialize RAG pipeline")
    st.info("💡 **How to get started:**\n1. Upload PDF files using the sidebar\n2. Click 'Upload PDFs' to save them\n3. Click 'Reload & Index PDFs' to process them")
src/embeddings.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from typing import List
4
+ from transformers import CLIPModel, CLIPProcessor
5
+
6
class CLIPEmbedder:
    """Text-embedding wrapper around a Hugging Face CLIP checkpoint.

    Produces L2-normalized text feature vectors suitable for cosine-similarity
    retrieval. Instances are callable: ``embedder(["a", "b"])``.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", device: str = "cpu"):
        self.device = device
        self.model_name = model_name

        print(f"→ Loading CLIP model: {model_name}")

        # Model weights go to the requested device; the processor is CPU-side.
        self.model = CLIPModel.from_pretrained(model_name).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()  # inference only — disable dropout etc.

        print(f"✓ CLIP model loaded on {device}")

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of strings into a (batch, dim) unit-norm array."""
        # Tokenize up front; CLIP's text tower accepts at most 77 tokens.
        batch = self.processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        ).to(self.device)

        with torch.no_grad():
            features = self.model.get_text_features(**batch)
            # Unit-normalize so dot products equal cosine similarity.
            features = features / features.norm(dim=-1, keepdim=True)

        return features.cpu().numpy()

    def encode_single_text(self, text: str) -> np.ndarray:
        """Embed one string; returns a 1-D feature vector."""
        return self.encode_text([text])[0]

    def __call__(self, texts: List[str]) -> np.ndarray:
        """Alias for :meth:`encode_text` so the embedder is callable."""
        return self.encode_text(texts)
src/multimodal_model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoImageProcessor
3
+ from typing import Optional, Tuple
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
class GemmaVisionModel:
    """Local causal-LM wrapper (Gemma) used for RAG answers and summaries."""

    def __init__(self, model_name: str = "unsloth/gemma-3-1b-pt", device: str = "cpu"):
        """Load the model, preferring 4-bit quantization for memory efficiency.

        Args:
            model_name: Hugging Face model identifier.
            device: target device for tokenized inputs ("cpu"/"cuda").
        """
        self.device = device
        self.model_name = model_name

        print(f"→ Loading {model_name}...")

        # Try 4-bit (bitsandbytes) first; fall back to full-precision CPU
        # weights when quantization is unavailable on this machine.
        try:
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4"
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception:  # narrowed from bare except: don't mask KeyboardInterrupt/SystemExit
            # Fallback without quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                device_map="cpu",
                trust_remote_code=True
            )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model.eval()

        print(f"✓ Model loaded successfully")

    def generate_response(self, prompt: str, max_length: int = 512, temperature: float = 0.7) -> str:
        """Generate a completion for *prompt*.

        Bug fixes: the previous version ignored both parameters — no token
        budget was passed to ``generate`` (so the library default applied) and
        temperature was hardcoded to 0.8 — and the decoded output echoed the
        prompt back to the caller.

        Args:
            prompt: full text prompt.
            max_length: maximum number of NEW tokens to generate.
            temperature: sampling temperature (keep roughly in 0.5–1.5).

        Returns:
            Only the newly generated text (prompt stripped), whitespace-trimmed.
        """
        with torch.no_grad():
            # NOTE(review): with device_map="auto" the model may not live on
            # self.device; confirm input placement on multi-device setups.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,     # honor the caller's budget
                temperature=temperature,       # honor the caller's setting
                do_sample=True,                # sampling for variety
                top_p=0.95,                    # nucleus sampling
                top_k=50,                      # top-K sampling
                remove_invalid_values=True,    # drop NaN/Inf logits
                repetition_penalty=1.2,        # discourage repetition
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the continuation, not the echoed prompt tokens.
        prompt_len = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        return response.strip()

    def summarize_text(self, text: str, max_length: int = 256) -> str:
        """Summarize *text* (the prompt asks for a Russian summary)."""
        prompt = f"Summarize the following text in Russian:\n\n{text}\n\nSummary:"
        return self.generate_response(prompt, max_length=max_length)

    def answer_question(self, question: str, context: str) -> str:
        """Answer *question* grounded in *context* (prompt asks for Russian)."""
        prompt = f"""Based on the following context, answer the question in Russian.

Context:
{context}

Question: {question}

Answer:"""
        return self.generate_response(prompt, max_length=512)
src/pdf_parser.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pdfplumber
4
+ import hashlib
5
+ from pathlib import Path
6
+ from typing import Dict, List, Tuple
7
+ from PIL import Image
8
+ import io
9
+
10
class PDFParser:
    """Parse PDFs with pdfplumber, caching results keyed by file-content hash."""

    def __init__(self, pdf_dir: str, cache_dir: str = ".pdf_cache"):
        """Args:
            pdf_dir: directory scanned for ``*.pdf`` files.
            cache_dir: where the JSON cache (and image paths) live.
        """
        self.pdf_dir = pdf_dir
        self.cache_dir = cache_dir
        self.cache_file = os.path.join(cache_dir, "processed_files.json")

        # Create cache directory
        os.makedirs(cache_dir, exist_ok=True)

        # Cache shape: {pdf_name: {"hash": md5, "data": parsed_content}}
        self.processed_files = self._load_cache()

    def _load_cache(self) -> Dict:
        """Load the processed-files cache; {} when none exists yet."""
        if os.path.exists(self.cache_file):
            # Explicit UTF-8 so non-ASCII PDF text round-trips on any platform
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def _save_cache(self):
        """Persist the processed-files cache as pretty-printed JSON."""
        with open(self.cache_file, 'w', encoding='utf-8') as f:
            json.dump(self.processed_files, f, indent=2)

    def _get_file_hash(self, filepath: str) -> str:
        """MD5 of the file contents, streamed in 4 KiB chunks.

        Used only for change detection, not security.
        """
        hash_md5 = hashlib.md5()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def _extract_tables(self, page) -> List[Dict]:
        """Extract tables from a page as pipe-joined text rows (best effort)."""
        tables = []
        try:
            page_tables = page.extract_tables()
            for i, table in enumerate(page_tables):
                # Render each row as "cell | cell | ..."; None cells become ""
                table_text = "\n".join([" | ".join([str(cell) if cell else "" for cell in row]) for row in table])
                tables.append({
                    "type": "table",
                    "index": i,
                    "content": table_text
                })
        except Exception:
            # Table detection can fail on unusual layouts; return what we have
            pass
        return tables

    def _extract_images(self, page, page_num: int, pdf_filename: str) -> List[Dict]:
        """Collect metadata records for images on a page (best effort).

        NOTE(review): no image bytes are actually written to ``path`` — this
        only records where an extracted image *would* live; confirm whether
        real byte extraction was intended.
        """
        images = []
        try:
            for i, img_dict in enumerate(page.images):
                try:
                    img_name = f"{pdf_filename}_p{page_num}_img{i}.png"
                    img_path = os.path.join(self.cache_dir, img_name)

                    # pdfplumber exposes source dimensions under "srcsize";
                    # a missing/empty value means "no usable image" — skip it.
                    if img_dict.get("srcsize"):
                        images.append({
                            "type": "image",
                            "index": i,
                            "path": img_path,
                            "description": f"Image from page {page_num}"
                        })
                except Exception:
                    continue  # skip unreadable image entries
        except Exception:
            pass  # page without an accessible image list
        return images

    def parse_pdf(self, pdf_path: str) -> Dict:
        """Parse one PDF into {"filename", "pages", "total_pages"}.

        Returns the cached result when the file's content hash is unchanged.
        On a parse error the partially filled dict is returned and NOT cached.
        """
        pdf_name = os.path.basename(pdf_path)
        file_hash = self._get_file_hash(pdf_path)

        # Serve from cache when the file has not changed
        if pdf_name in self.processed_files:
            if self.processed_files[pdf_name]["hash"] == file_hash:
                print(f"✓ Skipping {pdf_name} (already processed)")
                return self.processed_files[pdf_name]["data"]

        print(f"→ Processing {pdf_name}...")
        content = {
            "filename": pdf_name,
            "pages": [],
            "total_pages": 0
        }

        try:
            with pdfplumber.open(pdf_path) as pdf:
                content["total_pages"] = len(pdf.pages)

                for page_num, page in enumerate(pdf.pages):
                    page_content = {
                        "page_num": page_num,
                        # extract_text() returns None on image-only pages
                        "text": page.extract_text() or "",
                        "tables": self._extract_tables(page),
                        "images": self._extract_images(page, page_num, pdf_name.replace('.pdf', ''))
                    }
                    content["pages"].append(page_content)

            # Update cache (success path only)
            self.processed_files[pdf_name] = {
                "hash": file_hash,
                "data": content
            }
            self._save_cache()
            print(f"✓ Successfully processed {pdf_name}")

        except Exception as e:
            print(f"✗ Error processing {pdf_name}: {str(e)}")

        return content

    def parse_all_pdfs(self) -> List[Dict]:
        """Parse every ``*.pdf`` in ``self.pdf_dir``; [] when none exist."""
        pdf_files = list(Path(self.pdf_dir).glob("*.pdf"))

        if not pdf_files:
            print(f"No PDF files found in {self.pdf_dir}")
            return []

        all_content = []
        for pdf_path in pdf_files:
            content = self.parse_pdf(str(pdf_path))
            all_content.append(content)

        return all_content
143
+
144
+
145
def extract_text_from_pdfs(pdf_dir: str) -> Tuple[List[str], List[str]]:
    """Parse every PDF in *pdf_dir* and return (chunks, metadatas).

    Page text (with tables folded in between ``[TABLE]`` markers) is split on
    '.' and greedily packed into chunks of just under 1000 characters; each
    chunk's metadata records the source filename and page number.
    """
    parsed = PDFParser(pdf_dir).parse_all_pdfs()

    documents: List[str] = []
    metadatas: List[Dict] = []

    for pdf in parsed:
        source = pdf["filename"]
        for page in pdf["pages"]:
            meta = {"filename": source, "page": page["page_num"]}

            # Fold table rows into the page text, wrapped in [TABLE] markers.
            pieces = [page["text"]]
            pieces.extend(
                "\n\n[TABLE]\n" + tbl["content"] + "\n[/TABLE]\n"
                for tbl in page["tables"]
            )
            combined = "".join(pieces)

            if not combined.strip():
                continue  # nothing extractable on this page

            # Greedy sentence packing: flush when the next piece would push
            # the chunk to 1000+ characters.
            chunk = ""
            for sentence in combined.split('.'):
                if len(chunk) + len(sentence) < 1000:
                    chunk = chunk + sentence + "."
                else:
                    if chunk.strip():
                        documents.append(chunk)
                        metadatas.append(dict(meta))
                    chunk = sentence + "."

            # Flush the final partial chunk for this page.
            if chunk.strip():
                documents.append(chunk)
                metadatas.append(dict(meta))

    return documents, metadatas
src/rag_pipeline.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from pdf_parser import extract_text_from_pdfs
3
+ from vector_store import VectorStore
4
+ from embeddings import CLIPEmbedder
5
+ from multimodal_model import GemmaVisionModel
6
+
7
class RAGPipeline:
    """End-to-end local RAG: parse PDFs, index chunks in Chroma, generate with a local LLM."""

    def __init__(self, pdf_dir: str, chroma_dir: str = "./chroma_db", device: str = "cpu"):
        """Load the embedder, vector store and LLM.

        Args:
            pdf_dir: directory containing the source PDF files.
            chroma_dir: persistence directory for the Chroma database.
            device: "cpu" or "cuda", forwarded to the models.
        """
        self.pdf_dir = pdf_dir
        self.device = device

        # Initialize components
        print("→ Initializing RAG Pipeline...")

        # Initialize embedder
        # NOTE(review): this embedder is never passed to VectorStore below, so
        # Chroma indexes/searches with its *default* embedding function —
        # confirm whether CLIP embeddings were actually intended here.
        self.embedder = CLIPEmbedder(model_name="openai/clip-vit-base-patch32", device=device)

        # Initialize vector store
        self.vector_store = VectorStore(persist_dir=chroma_dir)
        self.vector_store.get_or_create_collection()

        # Initialize LLM
        self.llm = GemmaVisionModel(model_name="unsloth/gemma-3-1b-pt", device=device)

        print("✓ RAG Pipeline initialized")

    def index_pdfs(self):
        """Index all PDFs from the directory into the vector store."""
        import hashlib  # local import: only needed to derive chunk ids

        print("→ Indexing PDF documents...")

        # Extract text from PDFs
        documents, metadatas = extract_text_from_pdfs(self.pdf_dir)

        if documents:
            # Content-addressed ids (filename + page + content digest).
            # Bug fix: the old positional ids ("doc_0", "doc_1", ...) collided
            # with ids from earlier indexing runs, so chunks of newly added
            # PDFs were silently dropped by the store's duplicate filter.
            ids = []
            for doc, meta in zip(documents, metadatas):
                digest = hashlib.md5(doc.encode("utf-8")).hexdigest()[:12]
                ids.append(f"{meta['filename']}_p{meta['page']}_{digest}")

            # Add to vector store (embeddings generated automatically)
            self.vector_store.add_documents(documents, metadatas, ids)

            print(f"✓ Indexed {len(documents)} document chunks")
        else:
            print("No documents to index")

    def retrieve_documents(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to *n_results* chunks as {"content", "source"} dicts."""
        results = self.vector_store.search(query, n_results=n_results)

        retrieved_docs = []
        # Chroma returns one inner list per query text; we send exactly one.
        for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
            retrieved_docs.append({
                "content": doc,
                "source": f"{metadata.get('filename', 'Unknown')} (p{metadata.get('page', '?')})"
            })

        return retrieved_docs

    def answer_question(self, question: str, n_context_docs: int = 3) -> Dict:
        """Answer *question* using retrieved context.

        Returns:
            dict with "answer" (str), "sources" (list of labels) and
            "context_used" (number of chunks fed to the LLM).
        """
        # Retrieve relevant documents
        retrieved_docs = self.retrieve_documents(question, n_results=n_context_docs)

        # Combine context, labelling each chunk with its origin
        context = "\n\n".join([f"[Source: {doc['source']}]\n{doc['content']}" for doc in retrieved_docs])

        # Generate answer
        answer = self.llm.answer_question(question, context)

        # The LLM may echo the prompt; keep only the text after "Answer:"
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        return {
            "answer": answer,
            "sources": [doc["source"] for doc in retrieved_docs],
            "context_used": len(retrieved_docs)
        }

    def summarize_documents(self) -> str:
        """Summarize the indexed corpus from a small retrieved sample."""
        collection_info = self.vector_store.get_collection_info()
        doc_count = collection_info.get("document_count", 0)

        if doc_count == 0:
            return "No documents to summarize"

        # Sample a few chunks, truncating EACH to 200 characters.
        # Bug fix: the old expression sliced the *list* of documents with
        # ``docs[:200]`` instead of each document's text, so full chunks were
        # joined and could blow past the LLM's context budget.
        results = self.vector_store.search("main topic summary", n_results=5)
        sampled_content = " ".join(doc[:200] for docs in results["documents"] for doc in docs)

        summary = self.llm.summarize_text(sampled_content)
        return summary
src/vector_store.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from chromadb.config import Settings
3
+ import os
4
+ from typing import List, Dict, Optional
5
+
6
class VectorStore:
    """Persistent ChromaDB wrapper for storing and querying document chunks."""

    def __init__(self, persist_dir: str = "./chroma_db", embedding_function=None):
        """Open (creating if needed) a persistent Chroma client.

        Args:
            persist_dir: on-disk location of the Chroma database.
            embedding_function: optional Chroma embedding function; when None,
                Chroma's default embedder is used.
        """
        self.persist_dir = persist_dir
        os.makedirs(persist_dir, exist_ok=True)

        # Persistent client: data survives process restarts; telemetry off.
        self.client = chromadb.PersistentClient(
            path=persist_dir,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        self.embedding_function = embedding_function
        self.collection = None  # bound lazily by get_or_create_collection()

    def get_or_create_collection(self, collection_name: str = "pdf_documents"):
        """Bind ``self.collection`` to *collection_name*, creating it if missing."""
        try:
            # Try to get existing collection
            self.collection = self.client.get_collection(
                name=collection_name,
                embedding_function=self.embedding_function
            )
            print(f"✓ Loaded existing collection: {collection_name}")
        except Exception:  # narrowed from a bare except (was masking SystemExit too)
            # Create new collection with cosine distance
            self.collection = self.client.create_collection(
                name=collection_name,
                embedding_function=self.embedding_function,
                metadata={"hnsw:space": "cosine"}
            )
            print(f"✓ Created new collection: {collection_name}")

        return self.collection

    def add_documents(self, documents: List[str], metadatas: List[Dict], ids: Optional[List[str]] = None):
        """Add documents, silently skipping ids already present in the store."""
        if not self.collection:
            self.get_or_create_collection()

        if ids is None:
            ids = [f"doc_{i}" for i in range(len(documents))]

        # Existing IDs as a set for O(1) membership — the old list made the
        # dedupe loop below quadratic in collection size.
        try:
            existing_ids = set(self.collection.get()["ids"])
        except Exception:
            existing_ids = set()

        # Filter out documents whose id is already stored
        docs_to_add = []
        meta_to_add = []
        ids_to_add = []

        for doc, meta, doc_id in zip(documents, metadatas, ids):
            if doc_id not in existing_ids:
                docs_to_add.append(doc)
                meta_to_add.append(meta)
                ids_to_add.append(doc_id)

        if docs_to_add:
            self.collection.add(
                documents=docs_to_add,
                metadatas=meta_to_add,
                ids=ids_to_add
            )
            print(f"✓ Added {len(docs_to_add)} new documents to vector store")
        else:
            print("✓ All documents already in vector store")

    def search(self, query: str, n_results: int = 5) -> Dict:
        """Query the collection; empty result lists when no collection is bound."""
        if not self.collection:
            return {"documents": [], "metadatas": [], "distances": []}

        results = self.collection.query(
            query_texts=[query],
            n_results=n_results
        )

        return results

    def get_collection_info(self) -> Dict:
        """Return {"collection_name", "document_count"}; {} when unbound."""
        if not self.collection:
            return {}

        count = self.collection.count()
        return {
            "collection_name": self.collection.name,
            "document_count": count
        }
+ }