Update buffalo_rag/vector_store/db.py
Browse files
buffalo_rag/vector_store/db.py
CHANGED
|
@@ -17,21 +17,15 @@ class VectorStore:
|
|
| 17 |
self.chunk_ids = []
|
| 18 |
self.chunks = {}
|
| 19 |
|
| 20 |
-
# Load embedding model
|
| 21 |
self.model = SentenceTransformer(model_name)
|
| 22 |
-
|
| 23 |
-
# Load reranker model
|
| 24 |
self.reranker = CrossEncoder(reranker_name)
|
| 25 |
|
| 26 |
-
# Load or create index
|
| 27 |
self.load_or_create_index()
|
| 28 |
|
| 29 |
def load_or_create_index(self) -> None:
|
| 30 |
-
"""Load existing index or create a new one."""
|
| 31 |
index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
|
| 32 |
|
| 33 |
if os.path.exists(index_path):
|
| 34 |
-
# Load existing index
|
| 35 |
with open(index_path, 'rb') as f:
|
| 36 |
data = pickle.load(f)
|
| 37 |
self.index = data['index']
|
|
@@ -39,7 +33,6 @@ class VectorStore:
|
|
| 39 |
self.chunks = data['chunks']
|
| 40 |
print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
|
| 41 |
else:
|
| 42 |
-
# Create new index
|
| 43 |
embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
|
| 44 |
if os.path.exists(embeddings_path):
|
| 45 |
self.create_index()
|
|
@@ -53,22 +46,18 @@ class VectorStore:
|
|
| 53 |
with open(embeddings_path, 'rb') as f:
|
| 54 |
embedding_map = pickle.load(f)
|
| 55 |
|
| 56 |
-
# Extract embeddings and chunk IDs
|
| 57 |
chunk_ids = list(embedding_map.keys())
|
| 58 |
embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
|
| 59 |
chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
|
| 60 |
|
| 61 |
-
# Create FAISS index
|
| 62 |
dimension = embeddings.shape[1]
|
| 63 |
index = faiss.IndexFlatL2(dimension)
|
| 64 |
index.add(embeddings.astype(np.float32))
|
| 65 |
|
| 66 |
-
# Save index and metadata
|
| 67 |
self.index = index
|
| 68 |
self.chunk_ids = chunk_ids
|
| 69 |
self.chunks = chunks
|
| 70 |
|
| 71 |
-
# Save to disk
|
| 72 |
with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
|
| 73 |
pickle.dump({
|
| 74 |
'index': index,
|
|
@@ -83,24 +72,20 @@ class VectorStore:
|
|
| 83 |
k: int = 5,
|
| 84 |
filter_categories: Optional[List[str]] = None,
|
| 85 |
rerank: bool = True) -> List[Dict[str, Any]]:
|
| 86 |
-
|
| 87 |
if self.index is None:
|
| 88 |
print("No index available. Please create an index first.")
|
| 89 |
return []
|
| 90 |
|
| 91 |
-
# Create query embedding
|
| 92 |
query_embedding = self.model.encode([query])[0]
|
| 93 |
|
| 94 |
-
# Search index
|
| 95 |
D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
|
| 96 |
|
| 97 |
-
# Get results
|
| 98 |
results = []
|
| 99 |
for i, idx in enumerate(I[0]):
|
| 100 |
chunk_id = self.chunk_ids[idx]
|
| 101 |
chunk = self.chunks[chunk_id]
|
| 102 |
|
| 103 |
-
# Apply category filter if specified
|
| 104 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
| 105 |
continue
|
| 106 |
|
|
@@ -111,22 +96,16 @@ class VectorStore:
|
|
| 111 |
}
|
| 112 |
results.append(result)
|
| 113 |
|
| 114 |
-
# Rerank results if requested
|
| 115 |
if rerank and results:
|
| 116 |
-
# Prepare pairs for reranking
|
| 117 |
pairs = [(query, result['chunk']['content']) for result in results]
|
| 118 |
-
|
| 119 |
-
# Get reranking scores
|
| 120 |
rerank_scores = self.reranker.predict(pairs)
|
| 121 |
|
| 122 |
-
# Update scores and sort
|
| 123 |
for i, score in enumerate(rerank_scores):
|
| 124 |
results[i]['rerank_score'] = float(score)
|
| 125 |
|
| 126 |
-
# Sort by rerank score
|
| 127 |
results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
| 128 |
|
| 129 |
-
# Limit to k results
|
| 130 |
results = results[:k]
|
| 131 |
|
| 132 |
return results
|
|
@@ -135,29 +114,22 @@ class VectorStore:
|
|
| 135 |
query: str,
|
| 136 |
k: int = 5,
|
| 137 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
| 138 |
-
"""Combine dense vector search with BM25-style keyword matching."""
|
| 139 |
-
# Get vector search results
|
| 140 |
vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
|
| 141 |
|
| 142 |
-
# Simple keyword matching (simulating BM25)
|
| 143 |
keywords = query.lower().split()
|
| 144 |
-
|
| 145 |
-
# Score all chunks by keyword presence
|
| 146 |
keyword_scores = {}
|
|
|
|
| 147 |
for chunk_id, chunk_data in self.chunks.items():
|
| 148 |
chunk = chunk_data
|
| 149 |
content = (chunk['title'] + " " + chunk['content']).lower()
|
| 150 |
|
| 151 |
-
# Count keyword matches
|
| 152 |
score = sum(content.count(keyword) for keyword in keywords)
|
| 153 |
|
| 154 |
-
# Apply category filter if specified
|
| 155 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
| 156 |
continue
|
| 157 |
|
| 158 |
keyword_scores[chunk_id] = score
|
| 159 |
|
| 160 |
-
# Get top keyword matches
|
| 161 |
keyword_results = sorted(
|
| 162 |
[{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
|
| 163 |
for chunk_id, score in keyword_scores.items() if score > 0],
|
|
@@ -165,49 +137,28 @@ class VectorStore:
|
|
| 165 |
reverse=True
|
| 166 |
)[:k]
|
| 167 |
|
| 168 |
-
# Combine results (remove duplicates)
|
| 169 |
seen_ids = set()
|
| 170 |
combined_results = []
|
| 171 |
|
| 172 |
-
# Add vector results first
|
| 173 |
for result in vector_results:
|
| 174 |
combined_results.append(result)
|
| 175 |
seen_ids.add(result['chunk_id'])
|
| 176 |
|
| 177 |
-
# Add keyword results if not already added
|
| 178 |
for result in keyword_results:
|
| 179 |
if result['chunk_id'] not in seen_ids:
|
| 180 |
combined_results.append(result)
|
| 181 |
seen_ids.add(result['chunk_id'])
|
| 182 |
|
| 183 |
-
# Limit to k results
|
| 184 |
combined_results = combined_results[:k]
|
| 185 |
|
| 186 |
-
# Rerank final results
|
| 187 |
if combined_results:
|
| 188 |
-
# Prepare pairs for reranking
|
| 189 |
pairs = [(query, result['chunk']['content']) for result in combined_results]
|
| 190 |
|
| 191 |
-
# Get reranking scores
|
| 192 |
rerank_scores = self.reranker.predict(pairs)
|
| 193 |
|
| 194 |
-
# Update scores and sort
|
| 195 |
for i, score in enumerate(rerank_scores):
|
| 196 |
combined_results[i]['rerank_score'] = float(score)
|
| 197 |
|
| 198 |
-
# Sort by rerank score
|
| 199 |
combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
|
| 200 |
|
| 201 |
-
return combined_results
|
| 202 |
-
|
| 203 |
-
# Example usage
|
| 204 |
-
if __name__ == "__main__":
|
| 205 |
-
vector_store = VectorStore()
|
| 206 |
-
results = vector_store.hybrid_search("How do I apply for OPT?")
|
| 207 |
-
|
| 208 |
-
print(f"Found {len(results)} results")
|
| 209 |
-
for i, result in enumerate(results[:3]):
|
| 210 |
-
print(f"Result {i+1}: {result['chunk']['title']}")
|
| 211 |
-
print(f"Score: {result.get('rerank_score', result['score'])}")
|
| 212 |
-
print(f"Content: {result['chunk']['content'][:100]}...")
|
| 213 |
-
print()
|
|
|
|
| 17 |
self.chunk_ids = []
|
| 18 |
self.chunks = {}
|
| 19 |
|
|
|
|
| 20 |
self.model = SentenceTransformer(model_name)
|
|
|
|
|
|
|
| 21 |
self.reranker = CrossEncoder(reranker_name)
|
| 22 |
|
|
|
|
| 23 |
self.load_or_create_index()
|
| 24 |
|
| 25 |
def load_or_create_index(self) -> None:
|
|
|
|
| 26 |
index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')
|
| 27 |
|
| 28 |
if os.path.exists(index_path):
|
|
|
|
| 29 |
with open(index_path, 'rb') as f:
|
| 30 |
data = pickle.load(f)
|
| 31 |
self.index = data['index']
|
|
|
|
| 33 |
self.chunks = data['chunks']
|
| 34 |
print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
|
| 35 |
else:
|
|
|
|
| 36 |
embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
|
| 37 |
if os.path.exists(embeddings_path):
|
| 38 |
self.create_index()
|
|
|
|
| 46 |
with open(embeddings_path, 'rb') as f:
|
| 47 |
embedding_map = pickle.load(f)
|
| 48 |
|
|
|
|
| 49 |
chunk_ids = list(embedding_map.keys())
|
| 50 |
embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
|
| 51 |
chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}
|
| 52 |
|
|
|
|
| 53 |
dimension = embeddings.shape[1]
|
| 54 |
index = faiss.IndexFlatL2(dimension)
|
| 55 |
index.add(embeddings.astype(np.float32))
|
| 56 |
|
|
|
|
| 57 |
self.index = index
|
| 58 |
self.chunk_ids = chunk_ids
|
| 59 |
self.chunks = chunks
|
| 60 |
|
|
|
|
| 61 |
with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
|
| 62 |
pickle.dump({
|
| 63 |
'index': index,
|
|
|
|
| 72 |
k: int = 5,
|
| 73 |
filter_categories: Optional[List[str]] = None,
|
| 74 |
rerank: bool = True) -> List[Dict[str, Any]]:
|
| 75 |
+
|
| 76 |
if self.index is None:
|
| 77 |
print("No index available. Please create an index first.")
|
| 78 |
return []
|
| 79 |
|
|
|
|
| 80 |
query_embedding = self.model.encode([query])[0]
|
| 81 |
|
|
|
|
| 82 |
D, I = self.index.search(np.array([query_embedding]).astype(np.float32), min(k * 2, len(self.chunk_ids)))
|
| 83 |
|
|
|
|
| 84 |
results = []
|
| 85 |
for i, idx in enumerate(I[0]):
|
| 86 |
chunk_id = self.chunk_ids[idx]
|
| 87 |
chunk = self.chunks[chunk_id]
|
| 88 |
|
|
|
|
| 89 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
| 90 |
continue
|
| 91 |
|
|
|
|
| 96 |
}
|
| 97 |
results.append(result)
|
| 98 |
|
|
|
|
| 99 |
if rerank and results:
|
|
|
|
| 100 |
pairs = [(query, result['chunk']['content']) for result in results]
|
| 101 |
+
|
|
|
|
| 102 |
rerank_scores = self.reranker.predict(pairs)
|
| 103 |
|
|
|
|
| 104 |
for i, score in enumerate(rerank_scores):
|
| 105 |
results[i]['rerank_score'] = float(score)
|
| 106 |
|
|
|
|
| 107 |
results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)
|
| 108 |
|
|
|
|
| 109 |
results = results[:k]
|
| 110 |
|
| 111 |
return results
|
|
|
|
| 114 |
query: str,
|
| 115 |
k: int = 5,
|
| 116 |
filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
| 117 |
vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)
|
| 118 |
|
|
|
|
| 119 |
keywords = query.lower().split()
|
|
|
|
|
|
|
| 120 |
keyword_scores = {}
|
| 121 |
+
|
| 122 |
for chunk_id, chunk_data in self.chunks.items():
|
| 123 |
chunk = chunk_data
|
| 124 |
content = (chunk['title'] + " " + chunk['content']).lower()
|
| 125 |
|
|
|
|
| 126 |
score = sum(content.count(keyword) for keyword in keywords)
|
| 127 |
|
|
|
|
| 128 |
if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
|
| 129 |
continue
|
| 130 |
|
| 131 |
keyword_scores[chunk_id] = score
|
| 132 |
|
|
|
|
| 133 |
keyword_results = sorted(
|
| 134 |
[{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
|
| 135 |
for chunk_id, score in keyword_scores.items() if score > 0],
|
|
|
|
| 137 |
reverse=True
|
| 138 |
)[:k]
|
| 139 |
|
|
|
|
| 140 |
seen_ids = set()
|
| 141 |
combined_results = []
|
| 142 |
|
|
|
|
| 143 |
for result in vector_results:
|
| 144 |
combined_results.append(result)
|
| 145 |
seen_ids.add(result['chunk_id'])
|
| 146 |
|
|
|
|
| 147 |
for result in keyword_results:
|
| 148 |
if result['chunk_id'] not in seen_ids:
|
| 149 |
combined_results.append(result)
|
| 150 |
seen_ids.add(result['chunk_id'])
|
| 151 |
|
|
|
|
| 152 |
combined_results = combined_results[:k]
|
| 153 |
|
|
|
|
| 154 |
if combined_results:
|
|
|
|
| 155 |
pairs = [(query, result['chunk']['content']) for result in combined_results]
|
| 156 |
|
|
|
|
| 157 |
rerank_scores = self.reranker.predict(pairs)
|
| 158 |
|
|
|
|
| 159 |
for i, score in enumerate(rerank_scores):
|
| 160 |
combined_results[i]['rerank_score'] = float(score)
|
| 161 |
|
|
|
|
| 162 |
combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)
|
| 163 |
|
| 164 |
+
return combined_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|