updated rag

Files changed:
- embeddings_qdrant.py  +382 -0
- index_docs.py  +101 -0
- main.py  +196 -0
- rag_with_gemini.py  +201 -0
- requirements.txt  +6 -10
embeddings_qdrant.py
ADDED
@@ -0,0 +1,382 @@
import os
import uuid
import numpy as np
import google.generativeai as genai
from dotenv import load_dotenv
from typing import List, Dict, Optional
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct

# Load environment variables
load_dotenv()


class EmbeddingManager:
    def __init__(self, api_key: Optional[str] = None):
        """Initialize the embedding manager with the Gemini API."""
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment variables")

        genai.configure(api_key=self.api_key)
        self.model_name = "models/text-embedding-004"

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate an embedding for a single document text."""
        try:
            result = genai.embed_content(
                model=self.model_name,
                content=text,
                task_type="retrieval_document"
            )
            return np.array(result['embedding'], dtype=np.float32)
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return np.array([])

    def generate_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings for multiple texts."""
        embeddings = []
        for i, text in enumerate(texts):
            print(f"Generating embedding {i+1}/{len(texts)}")
            embedding = self.generate_embedding(text)
            if embedding.size > 0:
                embeddings.append(embedding)
            else:
                print(f"Failed to generate embedding for text {i+1}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate an embedding for a search query (retrieval_query task type)."""
        try:
            result = genai.embed_content(
                model=self.model_name,
                content=query,
                task_type="retrieval_query"
            )
            return np.array(result['embedding'], dtype=np.float32)
        except Exception as e:
            print(f"Error generating query embedding: {e}")
            return np.array([])


class QdrantVectorStore:
    def __init__(self, collection_name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None):
        """Initialize the Qdrant vector store."""
        self.collection_name = collection_name or os.getenv('QDRANT_COLLECTION_NAME', 'rag_documents')

        # Get Qdrant configuration from the environment
        qdrant_url = url or os.getenv('QDRANT_URL')
        qdrant_api_key = api_key or os.getenv('QDRANT_API_KEY')

        # Initialize the Qdrant client
        if qdrant_url and qdrant_api_key:
            # Qdrant Cloud
            print(f"Connecting to Qdrant Cloud at {qdrant_url}")
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        else:
            # Local Qdrant (default)
            print("Using local Qdrant instance at http://localhost:6333")
            self.client = QdrantClient("localhost", port=6333)

        self.embedding_dim = 768  # text-embedding-004 produces 768-dimensional vectors

    def create_collection(self, force_recreate: bool = False):
        """Create the collection, optionally deleting an existing one first."""
        try:
            # Check whether the collection exists
            collections = self.client.get_collections().collections
            collection_exists = any(col.name == self.collection_name for col in collections)

            if collection_exists and force_recreate:
                print(f"Deleting existing collection: {self.collection_name}")
                self.client.delete_collection(collection_name=self.collection_name)
                collection_exists = False

            if not collection_exists:
                print(f"Creating collection: {self.collection_name}")
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                print(f"✓ Collection '{self.collection_name}' created successfully")
            else:
                print(f"✓ Collection '{self.collection_name}' already exists")

        except Exception as e:
            print(f"Error creating collection: {e}")
            raise

    def add_documents(self, chunks: List[str], embeddings: List[np.ndarray], metadata: Optional[List[Dict]] = None, session_id: Optional[str] = None):
        """Add documents with their embeddings to Qdrant.

        Args:
            chunks: list of text chunks
            embeddings: list of numpy embeddings corresponding to chunks
            metadata: optional list of dicts with metadata per chunk
            session_id: optional session identifier to attach to each point payload
        """
        if metadata is None:
            metadata = [{"index": i} for i in range(len(chunks))]

        if len(chunks) != len(embeddings) or len(chunks) != len(metadata):
            raise ValueError("chunks, embeddings, and metadata must have the same length")

        # Ensure the collection exists
        self.create_collection()

        # Prepare points for Qdrant
        points = []
        for chunk, embedding, meta in zip(chunks, embeddings, metadata):
            # Combine text and metadata into the payload
            payload = {
                "text": chunk,
                "metadata": meta
            }

            # Attach session info if provided
            if session_id is not None:
                payload["session_id"] = session_id

            points.append(PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding.tolist(),
                payload=payload
            ))

        # Upload points to Qdrant
        try:
            print(f"Uploading {len(points)} documents to Qdrant...")
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            print(f"✓ Successfully uploaded {len(points)} documents")
        except Exception as e:
            print(f"Error uploading documents: {e}")
            raise

    def similarity_search(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.0,
                          include_context: bool = False) -> List[Dict]:
        """Search for similar documents in Qdrant.

        Args:
            query_embedding: the query vector
            top_k: number of results to return
            score_threshold: minimum similarity score
            include_context: if True, try to include surrounding chunks for context
        """
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=top_k,
                score_threshold=score_threshold
            )

            results = []
            for hit in search_results:
                metadata = hit.payload['metadata']

                # Basic result structure, with a ready-made citation string
                result = {
                    'id': hit.id,
                    'similarity': hit.score,
                    'chunk': hit.payload['text'],
                    'metadata': metadata,
                    'source': {
                        'file_name': metadata.get('file_name', 'Unknown'),
                        'file_path': metadata.get('file_path', 'Unknown'),
                        'chunk_index': metadata.get('chunk_index', 0)
                    },
                    'citation': f"{metadata.get('file_name', 'Unknown')} (chunk {metadata.get('chunk_index', 0)})"
                }

                # Add surrounding context if requested
                if include_context:
                    result['context'] = self._get_surrounding_context(metadata)

                results.append(result)

            return results

        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def _get_surrounding_context(self, metadata: Dict) -> Dict:
        """Get surrounding chunks for context (if available)."""
        try:
            file_path = metadata.get('file_path')
            chunk_index = metadata.get('chunk_index', 0)

            # Find adjacent chunks from the same file
            context_filter = {
                "must": [
                    {"key": "metadata.file_path", "match": {"value": file_path}}
                ]
            }

            # Search for chunks from the same file. A dummy vector is used because
            # only the filter matters here; note the limit of 10, so neighbours in
            # long documents may be missed.
            context_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=[0.0] * self.embedding_dim,  # dummy vector
                query_filter=context_filter,
                limit=10,
                score_threshold=0.0
            )

            # Sort by chunk index and pick out the neighbours
            file_chunks = []
            for hit in context_results:
                hit_metadata = hit.payload['metadata']
                if hit_metadata.get('chunk_index') is not None:
                    file_chunks.append({
                        'index': hit_metadata['chunk_index'],
                        'text': hit.payload['text']
                    })

            file_chunks.sort(key=lambda x: x['index'])

            # Find the current chunk and its neighbours
            current_idx = None
            for i, chunk in enumerate(file_chunks):
                if chunk['index'] == chunk_index:
                    current_idx = i
                    break

            return {
                'previous_chunk': file_chunks[current_idx - 1]['text'] if current_idx is not None and current_idx > 0 else None,
                'next_chunk': file_chunks[current_idx + 1]['text'] if current_idx is not None and current_idx < len(file_chunks) - 1 else None,
                'total_chunks_in_file': len(file_chunks)
            }

        except Exception as e:
            print(f"Error getting context: {e}")
            return {'error': 'Could not retrieve context'}

    def get_relevant_passages(self, query_embedding: np.ndarray, top_k: int = 5) -> List[str]:
        """Return just the text passages for RAG prompt creation."""
        results = self.similarity_search(query_embedding, top_k)
        return [result['chunk'] for result in results if result['chunk']]

    def enhanced_search(self, query_embedding: np.ndarray, top_k: int = 5) -> str:
        """Return a formatted string with search results ready for RAG."""
        results = self.similarity_search(query_embedding, top_k, include_context=True)

        if not results:
            return "No relevant documents found."

        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_result = f"""
**Result {i}** (Similarity: {result['similarity']:.3f})
**Source**: {result['citation']}
**Content**: {result['chunk']}
"""

            # Add context if available
            if 'context' in result and not result['context'].get('error'):
                context = result['context']
                if context.get('previous_chunk'):
                    formatted_result += f"\n**Previous Context**: ...{context['previous_chunk'][-100:]}"
                if context.get('next_chunk'):
                    formatted_result += f"\n**Following Context**: {context['next_chunk'][:100]}..."

            formatted_results.append(formatted_result)

        # Separate each result with a divider line
        return ("\n" + "=" * 50 + "\n").join(formatted_results)

    def get_collection_info(self) -> Dict:
        """Get information about the collection."""
        try:
            info = self.client.get_collection(collection_name=self.collection_name)
            return {
                'name': self.collection_name,
                'points_count': info.points_count,
                'vectors_count': info.vectors_count,
                'status': info.status
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}

    def delete_collection(self):
        """Delete the collection."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"✓ Collection '{self.collection_name}' deleted")
        except Exception as e:
            print(f"Error deleting collection: {e}")


if __name__ == "__main__":
    # Example usage with Qdrant
    print("Testing Qdrant Vector Store...")

    try:
        embedding_manager = EmbeddingManager()
        qdrant_store = QdrantVectorStore()

        # Test with sample texts
        sample_texts = [
            "This is a sample document about machine learning and artificial intelligence.",
            "Python is a great programming language for data science and AI development.",
            "Qdrant is a vector database that enables similarity search at scale."
        ]

        print("Generating embeddings...")
        embeddings = embedding_manager.generate_embeddings_batch(sample_texts)

        if embeddings:
            # Create metadata
            metadata = [
                {"source": "sample_doc", "topic": "machine_learning", "index": 0},
                {"source": "sample_doc", "topic": "programming", "index": 1},
                {"source": "sample_doc", "topic": "database", "index": 2}
            ]

            # Add to Qdrant
            qdrant_store.add_documents(sample_texts, embeddings, metadata)

            # Basic search
            query = "What is a vector database?"
            query_embedding = embedding_manager.generate_query_embedding(query)

            if query_embedding.size > 0:
                print(f"\n🔍 BASIC SEARCH: {query}")
                results = qdrant_store.similarity_search(query_embedding, top_k=2)
                for result in results:
                    print(f"Similarity: {result['similarity']:.4f}")
                    print(f"Source: {result['citation']}")
                    print(f"Text: {result['chunk']}")
                    print(f"Topic: {result['metadata']['topic']}")
                    print("---")

                # Enhanced search
                print("\n🚀 ENHANCED SEARCH (RAG-ready format):")
                enhanced_results = qdrant_store.enhanced_search(query_embedding, top_k=2)
                print(enhanced_results)

            # Show collection info
            info = qdrant_store.get_collection_info()
            print(f"\nCollection Info: {info}")

    except Exception as e:
        print(f"Error in test: {e}")
        print("Make sure:")
        print("1. Your GEMINI_API_KEY is valid in the .env file")
        print("2. Qdrant is running (docker run -p 6333:6333 qdrant/qdrant) or Qdrant Cloud is configured")
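A note on the retrieval call: recent qdrant-client releases deprecate QdrantClient.search in favour of query_points. A minimal sketch of the equivalent call, assuming qdrant-client >= 1.10 and reusing the store and query_embedding objects from the listing above:

# Minimal sketch, assuming qdrant-client >= 1.10 where query_points is available.
# `store` is a QdrantVectorStore instance; `query_embedding` is a numpy vector.
response = store.client.query_points(
    collection_name=store.collection_name,
    query=query_embedding.tolist(),
    limit=5,
)
for point in response.points:
    print(point.score, point.payload["text"][:80])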
index_docs.py
ADDED
@@ -0,0 +1,101 @@
from typing import Optional
from docx import Document

try:
    import fitz  # PyMuPDF
except Exception:
    # Fall back to the pymupdf module name if present
    import pymupdf as fitz


def load_pdf_text(file_path: str) -> str:
    try:
        doc = fitz.open(file_path)
        text = ""
        # Iterate directly over pages
        for page in doc:
            # Use the standard PyMuPDF API
            try:
                page_text = page.get_text()
            except Exception:
                # Try the alternate name used by older versions
                page_text = page.getText() if hasattr(page, 'getText') else ''
            if page_text:
                text += page_text + "\n"
        try:
            doc.close()
        except Exception:
            pass
        return text.strip()
    except Exception as e:
        print(f"Error reading PDF {file_path}: {e}")
        return ""


def load_docx_text(file_path: str) -> str:
    try:
        doc = Document(file_path)
        paragraphs = [p.text for p in doc.paragraphs if p.text]
        return "\n".join(paragraphs).strip()
    except Exception as e:
        print(f"Error reading DOCX {file_path}: {e}")
        return ""


def load_txt_text(file_path: str) -> str:
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Error reading TXT {file_path}: {e}")
        return ""


def extract_text_from_path(path: str) -> Optional[str]:
    if path.lower().endswith('.pdf'):
        return load_pdf_text(path)
    if path.lower().endswith('.docx'):
        return load_docx_text(path)
    if path.lower().endswith('.txt'):
        return load_txt_text(path)
    return None


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 100) -> list:
    # Slide a fixed-size window over the text; consecutive chunks share
    # `overlap` characters. Guard the step so overlap >= chunk_size cannot
    # loop forever.
    step = max(1, chunk_size - overlap)
    chunks = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = min(start + chunk_size, text_length)
        chunks.append(text[start:end])
        start += step
    return chunks


if __name__ == '__main__':
    import sys

    def usage():
        print('Usage: python src/index_docs.py <path-to-file-or-folder> [chunk_size]')

    if len(sys.argv) < 2:
        usage()
        sys.exit(1)

    path = sys.argv[1]
    chunk_size = int(sys.argv[2]) if len(sys.argv) > 2 else 500

    print(f'Testing extraction for: {path}')
    text = extract_text_from_path(path)
    if not text:
        print('No text extracted or unsupported file type.')
        sys.exit(1)

    print('Characters extracted:', len(text))
    chunks = chunk_text(text, chunk_size=chunk_size)
    print('Chunks produced:', len(chunks))
    if chunks:
        preview = 300
        print('\n--- First chunk preview ---')
        print(chunks[0][:preview])
        print('\n--- Second chunk preview ---')
        print(chunks[1][:preview] if len(chunks) > 1 else '<none>')
main.py
ADDED
@@ -0,0 +1,196 @@
import streamlit as st
import os
import tempfile
import hashlib
from typing import List
from dotenv import load_dotenv
from rag_with_gemini import RAGSystem

# Load environment variables
load_dotenv()

# --- PAGE CONFIG ---
st.set_page_config(
    page_title="RAG Document Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- SESSION STATE INIT ---
def initialize_session_state():
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = None
    if 'documents_processed' not in st.session_state:
        st.session_state.documents_processed = []
    # Store SHA-256 hashes of processed files to avoid reprocessing the same file in a session
    if 'processed_hashes' not in st.session_state:
        st.session_state.processed_hashes = set()
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'processing_status' not in st.session_state:
        st.session_state.processing_status = ""
    if 'system_initialized' not in st.session_state:
        st.session_state.system_initialized = False

# --- RAG SYSTEM INIT ---
def initialize_rag_system():
    if st.session_state.system_initialized:
        return True
    try:
        gemini_api_key = os.getenv('GEMINI_API_KEY')
        qdrant_url = os.getenv('QDRANT_URL')
        qdrant_api_key = os.getenv('QDRANT_API_KEY')

        if not gemini_api_key or not qdrant_url or not qdrant_api_key:
            st.error("❌ Missing API keys in your .env file.")
            return False

        with st.spinner("🚀 Initializing RAG system..."):
            rag_system = RAGSystem(gemini_api_key, qdrant_url, qdrant_api_key)
            st.session_state.rag_system = rag_system
            st.session_state.system_initialized = True
            return True
    except Exception as e:
        st.error(f"❌ Initialization error: {e}")
        return False

# --- DOCUMENT PROCESSING ---
def process_uploaded_files(uploaded_files):
    if not uploaded_files or not st.session_state.rag_system:
        return False
    try:
        temp_paths = []
        to_process = []
        skipped = []

        # Determine which files are new by hashing their contents
        for uploaded_file in uploaded_files:
            data = uploaded_file.getvalue()
            h = hashlib.sha256(data).hexdigest()
            if h in st.session_state.processed_hashes:
                skipped.append(uploaded_file.name)
                continue

            # Write a temp file for processing
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(data)
                temp_paths.append(tmp.name)
            to_process.append((uploaded_file.name, h))

        # If there are no new files to process, short-circuit
        if not temp_paths:
            st.session_state.processing_status = (
                f"⚠️ No new files to process. Skipped: {', '.join(skipped)}" if skipped else "⚠️ No files provided."
            )
            return True

        with st.spinner("📄 Processing documents..."):
            success = st.session_state.rag_system.add_documents(temp_paths)

        # Clean up the temp files
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError:
                pass

        if success:
            # Record processed filenames and their hashes
            for name, h in to_process:
                st.session_state.documents_processed.append(name)
                st.session_state.processed_hashes.add(h)

            # If some files were skipped, include that in the status
            status_msg = f"✅ Processed {len(to_process)} documents!"
            if skipped:
                status_msg += f" Skipped {len(skipped)} duplicate(s): {', '.join(skipped)}"
            st.session_state.processing_status = status_msg
            return True
        else:
            st.session_state.processing_status = "❌ Failed to process documents."
            return False
    except Exception as e:
        st.session_state.processing_status = f"❌ Error: {str(e)}"
        return False

# --- CHAT DISPLAY ---
def display_chat_message(role: str, content: str, sources: List[str] = None):
    avatar_url = (
        "https://cdn-icons-png.flaticon.com/512/4712/4712035.png"
        if role == "assistant"
        else "https://cdn-icons-png.flaticon.com/512/1077/1077012.png"
    )
    with st.chat_message(role, avatar=avatar_url):
        st.markdown(content)

# --- MAIN ---
def main():
    initialize_session_state()
    # Note: the CSS classes used below assume matching styles are injected elsewhere
    st.markdown('<h1 class="main-header">RAG Document Assistant</h1>', unsafe_allow_html=True)

    if not initialize_rag_system():
        st.stop()

    # Sidebar
    with st.sidebar:
        st.markdown("### 📁 Upload Documents")
        uploaded_files = st.file_uploader("Choose files", type=['pdf', 'txt', 'docx'], accept_multiple_files=True)
        if uploaded_files and st.button("📤 Process Documents"):
            if process_uploaded_files(uploaded_files):
                st.rerun()

        if st.session_state.processing_status:
            msg = st.session_state.processing_status
            cls = "success-message" if "✅" in msg else "error-message"
            st.markdown(f'<div class="{cls}">{msg}</div>', unsafe_allow_html=True)

        if st.session_state.documents_processed:
            st.markdown("### ✅ Processed Files")
            for doc in st.session_state.documents_processed:
                st.write(f"- {doc}")

        if st.button("🗑️ Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

    if not st.session_state.chat_history and not st.session_state.documents_processed:
        st.markdown("""
        <div style="text-align:center; padding:3rem; color:#9ca3af;">
            <h3>👋 Welcome to your RAG Assistant</h3>
            <p>Upload documents in the sidebar, then ask me anything about their content.</p>
        </div>
        """, unsafe_allow_html=True)

    for message in st.session_state.chat_history:
        display_chat_message(message["role"], message["content"], message.get("sources", []))

    # Chat input
    if prompt := st.chat_input("💬 Ask me anything..."):
        if not st.session_state.documents_processed:
            st.warning("⚠️ Upload and process documents first!")
            return

        st.session_state.chat_history.append({"role": "user", "content": prompt})
        display_chat_message("user", prompt)

        with st.chat_message("assistant"):
            with st.spinner("🤔 Thinking..."):
                try:
                    result = st.session_state.rag_system.query(prompt)
                    st.markdown(result['answer'])
                    # if result['sources']:
                    #     with st.expander("📚 Sources"):
                    #         for i, src in enumerate(result['sources'], 1):
                    #             st.write(f"{i}. {os.path.basename(src)}")

                    st.session_state.chat_history.append({
                        "role": "assistant",
                        "content": result['answer'],
                        "sources": result['sources']
                    })
                except Exception as e:
                    error_msg = f"❌ Error: {str(e)}"
                    st.error(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})


if __name__ == "__main__":
    main()
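The app is launched with `streamlit run main.py`. For reference, the environment variables it reads via os.getenv would live in a .env file alongside the code; the values below are placeholders:

# .env — placeholder values, fill in your own credentials
GEMINI_API_KEY=your-gemini-api-key
QDRANT_URL=https://your-cluster.qdrant.io
QDRANT_API_KEY=your-qdrant-api-key
QDRANT_COLLECTION_NAME=rag_documents   # optional; this is the default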
rag_with_gemini.py
ADDED
@@ -0,0 +1,201 @@
"""
RAG (Retrieval-Augmented Generation) system with Gemini
"""

import time
import logging
from typing import List, Dict, Any, Optional

import google.generativeai as genai

from embeddings_qdrant import EmbeddingManager, QdrantVectorStore
from index_docs import extract_text_from_path, chunk_text

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RAGSystem:
    """Complete RAG system with Gemini AI"""

    def __init__(self, gemini_api_key: str, qdrant_url: str, qdrant_api_key: str):
        """Initialize the RAG system with Gemini and Qdrant"""
        # Configure Gemini
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel("models/gemini-2.5-flash")

        # Initialize components
        self.embedding_manager = EmbeddingManager(gemini_api_key)

        # Try Qdrant Cloud first, fall back to a local Qdrant instance
        try:
            self.vector_store = QdrantVectorStore(url=qdrant_url, api_key=qdrant_api_key)
            self.vector_store.create_collection(force_recreate=True)
            logger.info("✅ Connected to Qdrant Cloud")
            self.using_qdrant = True
        except Exception as e:
            logger.warning(f"❌ Qdrant Cloud connection failed: {e}")
            logger.info("🔄 Falling back to local Qdrant instance")
            # No separate in-memory store is implemented, so the fallback is a
            # local Qdrant instance (localhost:6333 by default)
            self.vector_store = QdrantVectorStore()
            self.vector_store.create_collection()
            self.using_qdrant = False

    def add_documents(self, file_paths: List[str], session_id: Optional[str] = None) -> bool:
        """Add documents to the vector store"""
        try:
            all_chunks = []

            for file_path in file_paths:
                logger.info(f"Processing {file_path}")

                # Extract text
                text = extract_text_from_path(file_path)
                if not text:
                    logger.warning(f"No text extracted from {file_path}")
                    continue

                # Chunk text
                chunks = chunk_text(text)

                # Add metadata
                for chunk in chunks:
                    all_chunks.append({
                        'text': chunk,
                        'source': file_path,
                        'chunk_id': len(all_chunks)
                    })

            if not all_chunks:
                logger.error("No chunks to process")
                return False

            # Generate embeddings
            logger.info(f"Generating embeddings for {len(all_chunks)} chunks")

            embeddings = []
            texts = []
            metadata_list = []

            for i, chunk in enumerate(all_chunks):
                try:
                    # Generate the embedding for this chunk
                    embedding = self.embedding_manager.generate_embedding(chunk['text'])

                    embeddings.append(embedding)
                    texts.append(chunk['text'])
                    metadata_list.append({
                        'source': chunk['source'],
                        'chunk_id': chunk['chunk_id']
                    })

                    logger.info(f"Generated embedding {i+1}/{len(all_chunks)}")

                    # Small delay to avoid rate limits
                    time.sleep(0.1)

                except Exception as e:
                    logger.error(f"Error processing chunk {i}: {e}")
                    continue

            # Store all embeddings in the vector database
            if embeddings and texts:
                logger.info(f"Storing {len(embeddings)} embeddings in vector database (session={session_id})")
                # Forward session_id so it is stored with each point
                self.vector_store.add_documents(texts, embeddings, metadata_list, session_id=session_id)

            logger.info("Document processing completed successfully!")
            return True

        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def make_rag_prompt(self, query: str, context_passages: List[str]) -> str:
        """Create the RAG prompt from the retrieved context passages"""
        context = "\n\n".join([f"Context {i+1}: {passage}" for i, passage in enumerate(context_passages)])

        prompt = f"""You are a helpful assistant. Answer the user's question based on the provided context. If the context doesn't contain enough information to answer the question, say so clearly.

Context:
{context}

Question: {query}

Answer:"""

        return prompt

    def generate_answer(self, prompt: str, max_retries: int = 3) -> str:
        """Generate an answer using Gemini with retry logic"""
        for attempt in range(max_retries):
            try:
                response = self.model.generate_content(prompt)

                if response and response.text:
                    return response.text.strip()
                else:
                    logger.warning(f"Empty response on attempt {attempt + 1}")

            except Exception as e:
                logger.error(f"Error generating answer (attempt {attempt + 1}): {e}")

                if "429" in str(e) or "quota" in str(e).lower():
                    if attempt < max_retries - 1:
                        wait_time = (2 ** attempt) * 2  # Exponential backoff: 2s, 4s, 8s...
                        logger.info(f"Rate limit hit, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        return "I'm sorry, I'm currently experiencing high demand. Please try again in a few minutes."
                elif attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    return f"I encountered an error while generating the answer: {str(e)}"

        return "I'm sorry, I couldn't generate an answer at this time. Please try again."

    def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Handle the complete RAG query process"""
        try:
            logger.info(f"Processing query: {question}")

            # Embed the question with the retrieval_query task type
            # (not retrieval_document, which is for indexed chunks)
            query_embedding = self.embedding_manager.generate_query_embedding(question)

            # Search for relevant passages
            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                top_k=top_k
            )

            if not search_results:
                return {
                    'answer': "I couldn't find relevant information to answer your question.",
                    'sources': [],
                    'context_used': []
                }

            # Extract context passages and sources
            context_passages = [result.get('chunk', '') for result in search_results]
            sources = [result.get('metadata', {}).get('source', 'Unknown') for result in search_results]

            # Create the RAG prompt
            rag_prompt = self.make_rag_prompt(question, context_passages)

            # Generate the answer
            answer = self.generate_answer(rag_prompt)

            return {
                'answer': answer,
                'sources': list(set(sources)),  # Remove duplicates
                'context_used': context_passages
            }

        except Exception as e:
            logger.error(f"Error in query processing: {e}")
            return {
                'answer': f"I encountered an error while processing your question: {str(e)}",
                'sources': [],
                'context_used': []
            }


def handle_query(rag_system: RAGSystem, query: str) -> Dict[str, Any]:
    """Handle a single query through the RAG system"""
    return rag_system.query(query)
requirements.txt
CHANGED
@@ -1,12 +1,8 @@
 google-generativeai>=0.3.0
-
-
-
+python-dotenv>=1.0.0
+numpy>=1.24.0
+pymupdf
 python-docx>=0.8.11
-
-
-
-streamlit>=1.18.0
-typing-extensions>=3.7.4
-tika
-pymupdf
+qdrant-client>=1.7.0
+streamlit>=1.28.0
+scikit-learn>=1.3.0
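With the updated requirements, a typical local setup (assuming Docker for Qdrant, as the test script in embeddings_qdrant.py suggests) would be:

pip install -r requirements.txt
docker run -p 6333:6333 qdrant/qdrant   # local Qdrant, per the test script
streamlit run main.py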