dnj0 committed on
Commit
e4ac86d
·
verified ·
1 Parent(s): cffffd0

Update src/vector_store.py

Browse files
Files changed (1) hide show
  1. src/vector_store.py +3 -30
src/vector_store.py CHANGED
@@ -1,7 +1,4 @@
1
- """
2
- Vector Store and Embeddings Module using ChromaDB with sentence-transformers
3
- UPDATED for ChromaDB v0.4.22+ (auto-persist, no manual persist needed)
4
- """
5
  import os
6
  import json
7
  from typing import List, Dict
@@ -12,14 +9,12 @@ from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM
12
 
13
 
14
  class CLIPEmbedder:
15
- """Custom embedder using sentence-transformers for multimodal content"""
16
  def __init__(self, model_name: str = EMBEDDING_MODEL):
17
  print(f"πŸ”„ Loading embedding model: {model_name}")
18
  self.model = SentenceTransformer(model_name)
19
  print(f"βœ… Model loaded successfully")
20
 
21
  def embed(self, text: str) -> List[float]:
22
- """Generate embedding for text"""
23
  try:
24
  embedding = self.model.encode(text, convert_to_numpy=False)
25
  return embedding.tolist() if hasattr(embedding, 'tolist') else embedding
@@ -28,7 +23,6 @@ class CLIPEmbedder:
28
  return [0.0] * EMBEDDING_DIM
29
 
30
  def embed_batch(self, texts: List[str]) -> List[List[float]]:
31
- """Generate embeddings for batch of texts"""
32
  try:
33
  embeddings = self.model.encode(texts, convert_to_numpy=False)
34
  return [e.tolist() if hasattr(e, 'tolist') else e for e in embeddings]
@@ -38,14 +32,12 @@ class CLIPEmbedder:
38
 
39
 
40
  class VectorStore:
41
- """Vector store manager using ChromaDB (v0.4.22+ with auto-persist)"""
42
  def __init__(self):
43
  self.persist_directory = CHROMA_DB_PATH
44
  self.embedder = CLIPEmbedder()
45
 
46
  print(f"\nπŸ”„ Initializing ChromaDB at: {self.persist_directory}")
47
 
48
- # NEW ChromaDB v0.4.22+ - PersistentClient auto-persists
49
  try:
50
  self.client = chromadb.PersistentClient(
51
  path=self.persist_directory
@@ -58,7 +50,6 @@ class VectorStore:
58
  path=self.persist_directory
59
  )
60
 
61
- # Get or create collection
62
  try:
63
  self.collection = self.client.get_or_create_collection(
64
  name="multimodal_rag",
@@ -73,14 +64,12 @@ class VectorStore:
73
  )
74
 
75
  def add_documents(self, documents: List[Dict], doc_id: str):
76
- """Add documents to vector store"""
77
  texts = []
78
  metadatas = []
79
  ids = []
80
 
81
  print(f"\nπŸ“š Adding documents for: {doc_id}")
82
 
83
- # Add text chunks
84
  if 'text' in documents and documents['text']:
85
  chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
86
  for idx, chunk in enumerate(chunks):
@@ -93,7 +82,6 @@ class VectorStore:
93
  ids.append(f"{doc_id}_text_{idx}")
94
  print(f" βœ… Text: {len(chunks)} chunks")
95
 
96
- # Add image descriptions and OCR text
97
  if 'images' in documents:
98
  image_count = 0
99
  for idx, image_data in enumerate(documents['images']):
@@ -110,7 +98,6 @@ class VectorStore:
110
  if image_count > 0:
111
  print(f" βœ… Images: {image_count} with OCR text")
112
 
113
- # Add table content
114
  if 'tables' in documents:
115
  table_count = 0
116
  for idx, table_data in enumerate(documents['tables']):
@@ -127,11 +114,9 @@ class VectorStore:
127
  print(f" βœ… Tables: {table_count}")
128
 
129
  if texts:
130
- # Generate embeddings
131
  print(f" πŸ”„ Generating {len(texts)} embeddings...")
132
  embeddings = self.embedder.embed_batch(texts)
133
 
134
- # Add to collection
135
  try:
136
  self.collection.add(
137
  ids=ids,
@@ -140,13 +125,11 @@ class VectorStore:
140
  metadatas=metadatas
141
  )
142
  print(f"βœ… Successfully added {len(texts)} items to vector store")
143
- # Auto-persist happens here
144
  print(f"βœ… Data persisted automatically to: {self.persist_directory}")
145
  except Exception as e:
146
  print(f"❌ Error adding to collection: {e}")
147
 
148
  def search(self, query: str, n_results: int = 5) -> List[Dict]:
149
- """Search vector store for similar documents"""
150
  try:
151
  query_embedding = self.embedder.embed(query)
152
 
@@ -155,7 +138,6 @@ class VectorStore:
155
  n_results=n_results
156
  )
157
 
158
- # Format results
159
  formatted_results = []
160
  if results['documents']:
161
  for i, doc in enumerate(results['documents'][0]):
@@ -175,7 +157,6 @@ class VectorStore:
175
  return []
176
 
177
  def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
178
- """Split text into chunks with overlap"""
179
  chunks = []
180
  start = 0
181
  while start < len(text):
@@ -185,7 +166,6 @@ class VectorStore:
185
  return chunks
186
 
187
  def get_collection_info(self) -> Dict:
188
- """Get information about the collection"""
189
  try:
190
  count = self.collection.count()
191
  return {
@@ -199,7 +179,6 @@ class VectorStore:
199
  return {'status': 'error', 'message': str(e)}
200
 
201
  def delete_by_doc_id(self, doc_id: str):
202
- """Delete all documents related to a specific doc_id"""
203
  try:
204
  # Get all IDs with this doc_id
205
  results = self.collection.get(where={'doc_id': doc_id})
@@ -212,17 +191,11 @@ class VectorStore:
212
  print(f"Error deleting documents: {e}")
213
 
214
  def persist(self):
215
- """
216
- No-op for compatibility with older code.
217
- ChromaDB v0.4.22+ uses PersistentClient which auto-persists.
218
- This method kept for backward compatibility.
219
- """
220
- print("βœ… Vector store is using auto-persist (no manual persist needed)")
221
 
222
  def clear_all(self):
223
- """Clear all documents from collection"""
224
  try:
225
- # Delete collection and recreate
226
  self.client.delete_collection(name="multimodal_rag")
227
  self.collection = self.client.get_or_create_collection(
228
  name="multimodal_rag",
 
1
+
 
 
 
2
  import os
3
  import json
4
  from typing import List, Dict
 
9
 
10
 
11
  class CLIPEmbedder:
 
12
  def __init__(self, model_name: str = EMBEDDING_MODEL):
13
  print(f"πŸ”„ Loading embedding model: {model_name}")
14
  self.model = SentenceTransformer(model_name)
15
  print(f"βœ… Model loaded successfully")
16
 
17
  def embed(self, text: str) -> List[float]:
 
18
  try:
19
  embedding = self.model.encode(text, convert_to_numpy=False)
20
  return embedding.tolist() if hasattr(embedding, 'tolist') else embedding
 
23
  return [0.0] * EMBEDDING_DIM
24
 
25
  def embed_batch(self, texts: List[str]) -> List[List[float]]:
 
26
  try:
27
  embeddings = self.model.encode(texts, convert_to_numpy=False)
28
  return [e.tolist() if hasattr(e, 'tolist') else e for e in embeddings]
 
32
 
33
 
34
  class VectorStore:
 
35
  def __init__(self):
36
  self.persist_directory = CHROMA_DB_PATH
37
  self.embedder = CLIPEmbedder()
38
 
39
  print(f"\nπŸ”„ Initializing ChromaDB at: {self.persist_directory}")
40
 
 
41
  try:
42
  self.client = chromadb.PersistentClient(
43
  path=self.persist_directory
 
50
  path=self.persist_directory
51
  )
52
 
 
53
  try:
54
  self.collection = self.client.get_or_create_collection(
55
  name="multimodal_rag",
 
64
  )
65
 
66
  def add_documents(self, documents: List[Dict], doc_id: str):
 
67
  texts = []
68
  metadatas = []
69
  ids = []
70
 
71
  print(f"\nπŸ“š Adding documents for: {doc_id}")
72
 
 
73
  if 'text' in documents and documents['text']:
74
  chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
75
  for idx, chunk in enumerate(chunks):
 
82
  ids.append(f"{doc_id}_text_{idx}")
83
  print(f" βœ… Text: {len(chunks)} chunks")
84
 
 
85
  if 'images' in documents:
86
  image_count = 0
87
  for idx, image_data in enumerate(documents['images']):
 
98
  if image_count > 0:
99
  print(f" βœ… Images: {image_count} with OCR text")
100
 
 
101
  if 'tables' in documents:
102
  table_count = 0
103
  for idx, table_data in enumerate(documents['tables']):
 
114
  print(f" βœ… Tables: {table_count}")
115
 
116
  if texts:
 
117
  print(f" πŸ”„ Generating {len(texts)} embeddings...")
118
  embeddings = self.embedder.embed_batch(texts)
119
 
 
120
  try:
121
  self.collection.add(
122
  ids=ids,
 
125
  metadatas=metadatas
126
  )
127
  print(f"βœ… Successfully added {len(texts)} items to vector store")
 
128
  print(f"βœ… Data persisted automatically to: {self.persist_directory}")
129
  except Exception as e:
130
  print(f"❌ Error adding to collection: {e}")
131
 
132
  def search(self, query: str, n_results: int = 5) -> List[Dict]:
 
133
  try:
134
  query_embedding = self.embedder.embed(query)
135
 
 
138
  n_results=n_results
139
  )
140
 
 
141
  formatted_results = []
142
  if results['documents']:
143
  for i, doc in enumerate(results['documents'][0]):
 
157
  return []
158
 
159
  def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
 
160
  chunks = []
161
  start = 0
162
  while start < len(text):
 
166
  return chunks
167
 
168
  def get_collection_info(self) -> Dict:
 
169
  try:
170
  count = self.collection.count()
171
  return {
 
179
  return {'status': 'error', 'message': str(e)}
180
 
181
  def delete_by_doc_id(self, doc_id: str):
 
182
  try:
183
  # Get all IDs with this doc_id
184
  results = self.collection.get(where={'doc_id': doc_id})
 
191
  print(f"Error deleting documents: {e}")
192
 
193
  def persist(self):
194
+
195
+ print("βœ… Vector store is using auto-persist")
 
 
 
 
196
 
197
  def clear_all(self):
 
198
  try:
 
199
  self.client.delete_collection(name="multimodal_rag")
200
  self.collection = self.client.get_or_create_collection(
201
  name="multimodal_rag",