Spaces:
Build error
Build error
Create utils/vector_store.py
Browse files- utils/vector_store.py +76 -0
utils/vector_store.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/vector_store.py
|
| 2 |
+
import faiss
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
import pickle
|
| 6 |
+
import os
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
class VectorStore:
    """Per-case FAISS vector store persisted under ``data/cases/<case_id>``.

    The store keeps two structures that must stay aligned:

    * ``self.index``    -- a FAISS index holding one embedding per added text.
    * ``self.metadata`` -- a list where position ``i`` holds the metadata dict
      for the vector with FAISS row id ``i``.

    Both are written to disk after every ``add_texts`` call and reloaded by
    the constructor when both files exist. Note that the text itself is not
    stored; callers that need it back should put it in the metadata dict.
    """

    def __init__(self, case_id: str):
        """Open (or create) the vector store for ``case_id``.

        Args:
            case_id: Identifier used to build the on-disk directory
                ``data/cases/<case_id>/vector_store``.
        """
        self.case_id = case_id
        self.base_path = f"data/cases/{case_id}/vector_store"
        os.makedirs(self.base_path, exist_ok=True)

        # Keep a single embedder in the Streamlit session state so it is not
        # re-instantiated every time a VectorStore is constructed.
        if 'embedder' not in st.session_state:
            st.session_state.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        self.embedder = st.session_state.embedder
        self.dimension = 384  # Dimension of embeddings

        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load the persisted index and metadata, or start an empty store.

        The persisted state is used only when BOTH files exist; otherwise a
        fresh ``IndexFlatL2`` of ``self.dimension`` and an empty metadata
        list are created.
        """
        index_path = os.path.join(self.base_path, "faiss.index")
        metadata_path = os.path.join(self.base_path, "metadata.pkl")

        if os.path.exists(index_path) and os.path.exists(metadata_path):
            self.index = faiss.read_index(index_path)
            # SECURITY NOTE: pickle.load executes arbitrary code embedded in
            # the file. This is only safe while metadata.pkl is written
            # exclusively by _save_index and the data directory is trusted.
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)
        else:
            self.index = faiss.IndexFlatL2(self.dimension)
            self.metadata = []

    def add_texts(self, texts: List[str], metadatas: Optional[List[Dict]] = None):
        """Embed ``texts``, append them to the index and persist the store.

        Args:
            texts: Texts to embed. An empty list is a no-op.
            metadatas: Optional list with one dict per text. When omitted
                (or empty) an empty dict is stored for every text.

        Raises:
            ValueError: If ``metadatas`` is given but its length differs
                from ``texts``.
        """
        if not texts:
            return

        # Validate before touching the index: a length mismatch would
        # permanently shift every later metadata entry away from its vector.
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries "
                f"but {len(texts)} texts were given"
            )

        embeddings = self.embedder.encode(texts)

        self.index.add(np.array(embeddings).astype('float32'))

        if metadatas:
            self.metadata.extend(metadatas)
        else:
            self.metadata.extend({} for _ in texts)

        self._save_index()

    def similarity_search(self, query: str, k: int = 5) -> List[Dict]:
        """Return up to ``k`` stored entries closest to ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of results. Non-positive values yield ``[]``.

        Returns:
            A list of ``{"metadata": <dict>, "score": <float>}`` in the order
            FAISS returned them. ``score`` is the raw distance reported by
            the index (for the ``IndexFlatL2`` created by this class, a
            smaller value means a closer match).
        """
        total = self.index.ntotal
        if k <= 0 or total == 0:
            return []

        query_embedding = self.embedder.encode([query])
        # Never ask for more neighbours than there are vectors; FAISS pads
        # the surplus slots with index -1.
        D, I = self.index.search(
            np.array(query_embedding).astype('float32'), min(k, total)
        )

        results = []
        for score, idx in zip(D[0], I[0]):
            # Reject the -1 padding explicitly: a bare ``idx < len(...)``
            # check would let -1 through and return the LAST metadata entry.
            if 0 <= idx < len(self.metadata):
                results.append({
                    "metadata": self.metadata[idx],
                    "score": float(score),
                })

        return results

    def _save_index(self):
        """Write the FAISS index and the metadata list to ``self.base_path``."""
        index_path = os.path.join(self.base_path, "faiss.index")
        metadata_path = os.path.join(self.base_path, "metadata.pkl")

        faiss.write_index(self.index, index_path)
        with open(metadata_path, 'wb') as f:
            pickle.dump(self.metadata, f)
|