cryogenic22 committed on
Commit
d2ccb82
·
verified ·
1 Parent(s): 3ab44aa

Update utils/vector_store.py

Browse files
Files changed (1) hide show
  1. utils/vector_store.py +191 -81
utils/vector_store.py CHANGED
@@ -2,106 +2,216 @@ import os
2
  import pickle
3
  from typing import List, Dict, Any
4
  from sentence_transformers import SentenceTransformer, util
5
-
 
6
 
7
  class VectorStore:
8
- def __init__(self, storage_path: str = "data/vector_store"):
9
- """
10
- Initialize VectorStore.
11
-
12
- Args:
13
- storage_path (str): Path to store vectorized documents.
14
- """
15
  self.storage_path = storage_path
16
  os.makedirs(storage_path, exist_ok=True)
17
- self.model = SentenceTransformer('all-MiniLM-L6-v2')
 
18
  self.vectors = self._load_vectors()
 
 
19
 
20
  def _load_vectors(self) -> List[Dict]:
21
- """
22
- Load stored vectors from the file system.
23
-
24
- Returns:
25
- List[Dict]: List of stored vectorized documents.
26
- """
27
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
28
- if os.path.exists(vector_file):
29
- with open(vector_file, "rb") as f:
30
- return pickle.load(f)
31
- return []
 
 
 
 
32
 
33
  def _save_vectors(self):
34
- """
35
- Save the current vectors to the file system.
36
- """
37
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
38
- with open(vector_file, "wb") as f:
39
- pickle.dump(self.vectors, f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any]):
42
- """
43
- Add a new document to the vector store.
 
 
 
 
 
 
 
 
 
44
 
45
- Args:
46
- doc_id (str): Unique document identifier.
47
- text (str): Full text of the document.
48
- metadata (Dict[str, Any]): Metadata associated with the document.
49
- """
50
- vector = self.model.encode(text, convert_to_tensor=True)
51
- self.vectors.append({"doc_id": doc_id, "vector": vector, "text": text, "metadata": metadata})
 
 
 
 
 
 
 
 
 
 
 
 
52
  self._save_vectors()
53
 
54
- def similarity_search(self, query: str, top_k: int = 5) -> List[Dict]:
55
- """
56
- Perform a similarity search for the query against stored vectors.
57
-
58
- Args:
59
- query (str): Query string to search for.
60
- top_k (int): Number of top results to return.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- Returns:
63
- List[Dict]: List of the most similar documents.
64
- """
 
 
 
 
 
 
65
  query_vector = self.model.encode(query, convert_to_tensor=True)
 
 
66
  results = []
67
  for doc in self.vectors:
68
- similarity_score = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
69
- results.append({"doc_id": doc["doc_id"], "text": doc["text"], "metadata": doc["metadata"], "score": similarity_score})
70
- results = sorted(results, key=lambda x: x["score"], reverse=True)
71
- return results[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- def chat_with_context(self, query: str, context: str) -> str:
74
- """
75
- Generate a response to the query using the provided context.
76
-
77
- Args:
78
- query (str): Query string from the user.
79
- context (str): Context string from relevant documents.
80
-
81
- Returns:
82
- str: Generated response.
83
- """
84
- # Combine query and context for the final prompt
85
- combined_input = f"""
86
- Context:
87
- {context}
88
-
89
- Question:
90
- {query}
91
-
92
- Please provide a detailed and accurate response.
93
- """
94
-
95
- try:
96
- results = self.similarity_search(query, top_k=3)
97
- if not results:
98
- return "No relevant context found. Please upload more documents or refine your query."
99
-
100
- # Use the top result for a response simulation or pass to an LLM
101
- return (
102
- f"Based on the context:\n\n"
103
- f"{results[0]['text'][:500]}...\n\n"
104
- f"Response: The query '{query}' relates to the provided context."
105
  )
106
- except Exception as e:
107
- return f"Error generating response: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pickle
3
  from typing import List, Dict, Any
4
  from sentence_transformers import SentenceTransformer, util
5
+ import numpy as np
6
+ from datetime import datetime
7
 
8
  class VectorStore:
9
+ def __init__(self, storage_path: str = "data/vector_store", model_name: str = 'all-MiniLM-L6-v2'):
10
+ """Initialize VectorStore with improved chunk handling."""
 
 
 
 
 
11
  self.storage_path = storage_path
12
  os.makedirs(storage_path, exist_ok=True)
13
+
14
+ self.model = SentenceTransformer(model_name)
15
  self.vectors = self._load_vectors()
16
+ self.chunk_size = 512 # Optimal size for most transformer models
17
+ self.chunk_overlap = 50 # Overlap to maintain context
18
 
19
  def _load_vectors(self) -> List[Dict]:
20
+ """Load vectors with error handling and versioning."""
 
 
 
 
 
21
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
22
+ try:
23
+ if os.path.exists(vector_file):
24
+ with open(vector_file, "rb") as f:
25
+ vectors = pickle.load(f)
26
+ return vectors if isinstance(vectors, list) else []
27
+ except Exception as e:
28
+ print(f"Error loading vectors: {e}")
29
+ return []
30
 
31
  def _save_vectors(self):
32
+ """Save vectors with backup and atomic write."""
 
 
33
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
34
+ backup_file = vector_file + ".backup"
35
+
36
+ # Create backup of existing vectors
37
+ if os.path.exists(vector_file):
38
+ os.replace(vector_file, backup_file)
39
+
40
+ try:
41
+ with open(vector_file, "wb") as f:
42
+ pickle.dump(self.vectors, f)
43
+ # Remove backup after successful save
44
+ if os.path.exists(backup_file):
45
+ os.remove(backup_file)
46
+ except Exception as e:
47
+ print(f"Error saving vectors: {e}")
48
+ # Restore from backup if save failed
49
+ if os.path.exists(backup_file):
50
+ os.replace(backup_file, vector_file)
51
 
52
  def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any]):
53
+ """Add document with improved chunking and metadata."""
54
+ # Create chunks with overlap
55
+ chunks = self._create_chunks(text)
56
+
57
+ # Add timestamp and chunk info to metadata
58
+ base_metadata = {
59
+ **metadata,
60
+ "added_at": datetime.now().isoformat(),
61
+ "doc_id": doc_id,
62
+ "total_chunks": len(chunks)
63
+ }
64
 
65
+ # Process and store chunks
66
+ for chunk_idx, chunk in enumerate(chunks):
67
+ chunk_metadata = {
68
+ **base_metadata,
69
+ "chunk_idx": chunk_idx,
70
+ "chunk_text": chunk[:200] # Store preview of chunk text
71
+ }
72
+
73
+ # Encode chunk
74
+ vector = self.model.encode(chunk, convert_to_tensor=True)
75
+
76
+ # Store chunk with metadata
77
+ self.vectors.append({
78
+ "doc_id": f"{doc_id}_chunk_{chunk_idx}",
79
+ "vector": vector,
80
+ "text": chunk,
81
+ "metadata": chunk_metadata
82
+ })
83
+
84
  self._save_vectors()
85
 
86
+ def _create_chunks(self, text: str) -> List[str]:
87
+ """Create overlapping chunks with improved sentence boundary handling."""
88
+ # Split into sentences first
89
+ sentences = [s.strip() for s in text.split('.') if s.strip()]
90
+ chunks = []
91
+ current_chunk = []
92
+ current_size = 0
93
+
94
+ for sentence in sentences:
95
+ sentence_size = len(sentence.split())
96
+
97
+ if current_size + sentence_size > self.chunk_size:
98
+ # Save current chunk
99
+ if current_chunk:
100
+ chunks.append(' '.join(current_chunk))
101
+ # Start new chunk with overlap
102
+ overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
103
+ current_chunk = current_chunk[overlap_start:] + [sentence]
104
+ current_size = sum(len(s.split()) for s in current_chunk)
105
+ else:
106
+ current_chunk.append(sentence)
107
+ current_size += sentence_size
108
+
109
+ # Add final chunk
110
+ if current_chunk:
111
+ chunks.append(' '.join(current_chunk))
112
+
113
+ return chunks
114
 
115
+ def similarity_search(
116
+ self,
117
+ query: str,
118
+ k: int = 5,
119
+ threshold: float = 0.5,
120
+ filter_criteria: Dict[str, List] = None
121
+ ) -> List[Dict]:
122
+ """Enhanced similarity search with filtering and re-ranking."""
123
+ # Encode query
124
  query_vector = self.model.encode(query, convert_to_tensor=True)
125
+
126
+ # Calculate similarities and filter results
127
  results = []
128
  for doc in self.vectors:
129
+ # Apply filters if specified
130
+ if filter_criteria:
131
+ skip = False
132
+ for key, values in filter_criteria.items():
133
+ doc_value = self._get_nested_dict_value(doc["metadata"], key)
134
+ if doc_value not in values:
135
+ skip = True
136
+ break
137
+ if skip:
138
+ continue
139
+
140
+ # Calculate similarity
141
+ similarity = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
142
+ if similarity >= threshold:
143
+ results.append({
144
+ **doc,
145
+ "score": similarity
146
+ })
147
+
148
+ # Sort by similarity score
149
+ results.sort(key=lambda x: x["score"], reverse=True)
150
+
151
+ # Re-rank results based on chunk position and metadata
152
+ reranked_results = self._rerank_results(results[:k*2], query)
153
+
154
+ return reranked_results[:k]
155
 
156
+ def _rerank_results(self, results: List[Dict], query: str) -> List[Dict]:
157
+ """Re-rank results considering chunk position and metadata relevance."""
158
+ for result in results:
159
+ # Adjust score based on chunk position
160
+ chunk_idx = result["metadata"].get("chunk_idx", 0)
161
+ total_chunks = result["metadata"].get("total_chunks", 1)
162
+ position_score = 1 - (chunk_idx / total_chunks) # Favor earlier chunks
163
+
164
+ # Adjust score based on metadata relevance
165
+ metadata_score = self._calculate_metadata_relevance(result["metadata"], query)
166
+
167
+ # Combine scores
168
+ result["final_score"] = (
169
+ result["score"] * 0.6 + # Base similarity
170
+ position_score * 0.2 + # Position importance
171
+ metadata_score * 0.2 # Metadata relevance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  )
173
+
174
+ return sorted(results, key=lambda x: x["final_score"], reverse=True)
175
+
176
+ def _calculate_metadata_relevance(self, metadata: Dict, query: str) -> float:
177
+ """Calculate relevance score based on metadata matching."""
178
+ relevance_score = 0.0
179
+ query_lower = query.lower()
180
+
181
+ # Check for metadata field matches
182
+ for key, value in metadata.items():
183
+ if isinstance(value, str):
184
+ if value.lower() in query_lower:
185
+ relevance_score += 0.2
186
+ elif query_lower in value.lower():
187
+ relevance_score += 0.1
188
+
189
+ return min(1.0, relevance_score) # Normalize to [0,1]
190
+
191
+ def _get_nested_dict_value(self, d: Dict, key_path: str):
192
+ """Get value from nested dictionary using dot notation."""
193
+ keys = key_path.split('.')
194
+ value = d
195
+ for key in keys:
196
+ if isinstance(value, dict):
197
+ value = value.get(key)
198
+ else:
199
+ return None
200
+ return value
201
+
202
+ def get_document_embeddings(self, doc_id: str) -> List[Dict]:
203
+ """Retrieve all embeddings for a specific document."""
204
+ return [doc for doc in self.vectors if doc["metadata"]["doc_id"] == doc_id]
205
+
206
+ def delete_document(self, doc_id: str):
207
+ """Delete all chunks associated with a document."""
208
+ self.vectors = [doc for doc in self.vectors
209
+ if doc["metadata"]["doc_id"] != doc_id]
210
+ self._save_vectors()
211
+
212
+ def update_metadata(self, doc_id: str, metadata_updates: Dict):
213
+ """Update metadata for all chunks of a document."""
214
+ for doc in self.vectors:
215
+ if doc["metadata"]["doc_id"] == doc_id:
216
+ doc["metadata"].update(metadata_updates)
217
+ self._save_vectors()