cryogenic22 commited on
Commit
3d508bf
·
verified ·
1 Parent(s): b171d30

Update utils/vector_store.py

Browse files
Files changed (1) hide show
  1. utils/vector_store.py +82 -100
utils/vector_store.py CHANGED
@@ -4,141 +4,123 @@ from typing import List, Dict, Any
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
6
  from datetime import datetime
7
-
8
 
9
  class VectorStore:
10
- def __init__(self, storage_path: str = "data/vector_store", model_name: str = 'all-MiniLM-L6-v2'):
11
- """Initialize VectorStore with improved chunk handling."""
12
  self.storage_path = storage_path
13
  os.makedirs(storage_path, exist_ok=True)
 
 
 
 
14
 
15
- self.model = SentenceTransformer(model_name)
16
- self.vectors = self._load_vectors()
17
- self.chunk_size = 512 # Optimal size for most transformer models
18
- self.chunk_overlap = 50 # Overlap to maintain context
19
-
20
- def _load_vectors(self) -> List[Dict]:
21
- """Load vectors with error handling and versioning."""
22
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
23
  try:
24
  if os.path.exists(vector_file):
25
  with open(vector_file, "rb") as f:
26
- vectors = pickle.load(f)
27
- return vectors if isinstance(vectors, list) else []
 
28
  except Exception as e:
29
- print(f"Error loading vectors: {e}")
30
- return []
31
 
32
  def _save_vectors(self):
33
- """Save vectors with backup and atomic write."""
34
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
35
- backup_file = vector_file + ".backup"
36
-
37
- # Create backup of existing vectors
38
- if os.path.exists(vector_file):
39
- os.replace(vector_file, backup_file)
40
-
41
  try:
42
  with open(vector_file, "wb") as f:
43
  pickle.dump(self.vectors, f)
44
- # Remove backup after successful save
45
- if os.path.exists(backup_file):
46
- os.remove(backup_file)
47
  except Exception as e:
48
- print(f"Error saving vectors: {e}")
49
- # Restore from backup if save failed
50
- if os.path.exists(backup_file):
51
- os.replace(backup_file, vector_file)
52
-
53
- def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any]):
54
- """Add document with improved chunking and metadata."""
55
- # Create chunks with overlap
56
- chunks = self._create_chunks(text)
57
-
58
- # Add timestamp and chunk info to metadata
59
- base_metadata = {
60
- **metadata,
61
- "added_at": datetime.now().isoformat(),
62
- "doc_id": doc_id,
63
- "total_chunks": len(chunks)
64
- }
65
-
66
- # Process and store chunks
67
- for chunk_idx, chunk in enumerate(chunks):
68
- chunk_metadata = {
69
- **base_metadata,
70
- "chunk_idx": chunk_idx,
71
- "chunk_text": chunk[:200] # Store preview of chunk text
72
- }
73
-
74
- # Encode chunk
75
- vector = self.model.encode(chunk, convert_to_tensor=True)
76
 
77
- # Store chunk with metadata
78
- self.vectors.append({
79
- "doc_id": f"{doc_id}_chunk_{chunk_idx}",
 
 
 
 
 
 
80
  "vector": vector,
81
- "text": chunk,
82
- "metadata": chunk_metadata
83
- })
84
-
85
- self._save_vectors()
86
-
87
- def _create_chunks(self, text: str) -> List[str]:
88
- """Create overlapping chunks with improved sentence boundary handling."""
89
- # Split into sentences first
90
- sentences = [s.strip() for s in text.split('.') if s.strip()]
91
- chunks = []
92
- current_chunk = []
93
- current_size = 0
94
-
95
- for sentence in sentences:
96
- sentence_size = len(sentence.split())
97
-
98
- if current_size + sentence_size > self.chunk_size:
99
- # Save current chunk
100
- if current_chunk:
101
- chunks.append(' '.join(current_chunk))
102
- # Start new chunk with overlap
103
- overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
104
- current_chunk = current_chunk[overlap_start:] + [sentence]
105
- current_size = sum(len(s.split()) for s in current_chunk)
106
- else:
107
- current_chunk.append(sentence)
108
- current_size += sentence_size
109
-
110
- # Add final chunk
111
- if current_chunk:
112
- chunks.append(' '.join(current_chunk))
113
-
114
- return chunks
115
 
116
  def similarity_search(self, query: str, k: int = 3) -> List[Dict]:
117
  """Perform similarity search with error handling."""
118
  try:
119
- # If no vectors are stored yet, return empty list
120
  if not self.vectors:
121
  return []
122
 
 
123
  query_vector = self.model.encode(query, convert_to_tensor=True)
 
 
124
  results = []
125
-
126
  for doc in self.vectors:
127
- similarity = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
128
- results.append({
129
- "text": doc["text"],
130
- "metadata": doc["metadata"],
131
- "score": similarity
132
- })
133
-
134
- # Sort by similarity and return top k
 
 
 
 
 
135
  results.sort(key=lambda x: x["score"], reverse=True)
136
  return results[:k]
137
 
138
  except Exception as e:
139
- print(f"Error in similarity search: {str(e)}")
140
  return []
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def _rerank_results(self, results: List[Dict], query: str) -> List[Dict]:
143
  """Re-rank results considering chunk position and metadata relevance."""
144
  for result in results:
 
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
6
  from datetime import datetime
7
+ import streamlit as st
8
 
9
  class VectorStore:
10
+ def __init__(self, storage_path: str = "data/vector_store"):
11
+ """Initialize VectorStore with storage management."""
12
  self.storage_path = storage_path
13
  os.makedirs(storage_path, exist_ok=True)
14
+
15
+ self.model = SentenceTransformer('all-MiniLM-L6-v2')
16
+ self.vectors = [] # Initialize empty list
17
+ self._load_vectors() # Load any existing vectors
18
 
19
+ def _load_vectors(self):
20
+ """Load stored vectors with error handling."""
 
 
 
 
 
21
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
22
  try:
23
  if os.path.exists(vector_file):
24
  with open(vector_file, "rb") as f:
25
+ self.vectors = pickle.load(f)
26
+ if not isinstance(self.vectors, list):
27
+ self.vectors = []
28
  except Exception as e:
29
+ st.error(f"Error loading vectors: {str(e)}")
30
+ self.vectors = []
31
 
32
  def _save_vectors(self):
33
+ """Save vectors with error handling."""
34
  vector_file = os.path.join(self.storage_path, "vectors.pkl")
 
 
 
 
 
 
35
  try:
36
  with open(vector_file, "wb") as f:
37
  pickle.dump(self.vectors, f)
 
 
 
38
  except Exception as e:
39
+ st.error(f"Error saving vectors: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any] = None):
42
+ """Add a document to the vector store."""
43
+ try:
44
+ # Create vector embedding
45
+ vector = self.model.encode(text, convert_to_tensor=True)
46
+
47
+ # Create document record
48
+ doc_record = {
49
+ "doc_id": doc_id,
50
  "vector": vector,
51
+ "text": text,
52
+ "metadata": metadata or {}
53
+ }
54
+
55
+ # Add to vectors list
56
+ if not isinstance(self.vectors, list):
57
+ self.vectors = []
58
+ self.vectors.append(doc_record)
59
+
60
+ # Save updated vectors
61
+ self._save_vectors()
62
+
63
+ except Exception as e:
64
+ st.error(f"Error adding document to vector store: {str(e)}")
65
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def similarity_search(self, query: str, k: int = 3) -> List[Dict]:
68
  """Perform similarity search with error handling."""
69
  try:
70
+ # Handle empty vectors
71
  if not self.vectors:
72
  return []
73
 
74
+ # Encode query
75
  query_vector = self.model.encode(query, convert_to_tensor=True)
76
+
77
+ # Calculate similarities
78
  results = []
 
79
  for doc in self.vectors:
80
+ try:
81
+ similarity = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
82
+ results.append({
83
+ "doc_id": doc["doc_id"],
84
+ "text": doc["text"],
85
+ "metadata": doc["metadata"],
86
+ "score": float(similarity) # Convert to float for serialization
87
+ })
88
+ except Exception as e:
89
+ st.warning(f"Skipping document due to error: {str(e)}")
90
+ continue
91
+
92
+ # Sort by similarity
93
  results.sort(key=lambda x: x["score"], reverse=True)
94
  return results[:k]
95
 
96
  except Exception as e:
97
+ st.error(f"Error in similarity search: {str(e)}")
98
  return []
99
 
100
+ def get_document(self, doc_id: str) -> Dict:
101
+ """Retrieve a document by ID."""
102
+ try:
103
+ for doc in self.vectors:
104
+ if doc["doc_id"] == doc_id:
105
+ return {
106
+ "doc_id": doc["doc_id"],
107
+ "text": doc["text"],
108
+ "metadata": doc["metadata"]
109
+ }
110
+ return None
111
+ except Exception as e:
112
+ st.error(f"Error retrieving document: {str(e)}")
113
+ return None
114
+
115
+ def clear(self):
116
+ """Clear all vectors."""
117
+ self.vectors = []
118
+ self._save_vectors()
119
+
120
+ def __len__(self):
121
+ """Get number of documents in store."""
122
+ return len(self.vectors) if self.vectors is not None else 0
123
+
124
  def _rerank_results(self, results: List[Dict], query: str) -> List[Dict]:
125
  """Re-rank results considering chunk position and metadata relevance."""
126
  for result in results: