cryogenic22 committed on
Commit
8a2ab7f
·
verified ·
1 Parent(s): 6ac8bc4

Create utils/vector_store.py

Browse files
Files changed (1) hide show
  1. utils/vector_store.py +76 -0
utils/vector_store.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/vector_store.py
2
+ import faiss
3
+ import numpy as np
4
+ from typing import List, Dict, Optional
5
+ import pickle
6
+ import os
7
+ from sentence_transformers import SentenceTransformer
8
+ import streamlit as st
9
+
10
class VectorStore:
    """Per-case FAISS vector store with a parallel, position-aligned metadata list.

    The FAISS index and the metadata list are persisted under
    ``data/cases/<case_id>/vector_store`` as ``faiss.index`` and
    ``metadata.pkl``. Entry ``i`` of ``self.metadata`` describes vector ``i``
    of the index, so the two must always grow in lockstep.
    """

    def __init__(self, case_id: str):
        """Open (or create) the vector store belonging to ``case_id``.

        Args:
            case_id: Identifier used to build the on-disk storage directory.
        """
        self.case_id = case_id
        self.base_path = f"data/cases/{case_id}/vector_store"
        os.makedirs(self.base_path, exist_ok=True)

        # Keep the embedding model in Streamlit session state so it is
        # constructed only once per session and shared by every instance.
        if 'embedder' not in st.session_state:
            st.session_state.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        self.embedder = st.session_state.embedder
        self.dimension = 384  # Dimension of embeddings

        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load existing index or create new one.

        Both files must exist for the stored state to be loaded; otherwise a
        fresh, empty L2 index and an empty metadata list are created.
        """
        index_path = os.path.join(self.base_path, "faiss.index")
        metadata_path = os.path.join(self.base_path, "metadata.pkl")

        if os.path.exists(index_path) and os.path.exists(metadata_path):
            self.index = faiss.read_index(index_path)
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file. This is only safe while metadata.pkl is written
            # exclusively by this class; never point it at untrusted data.
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)
        else:
            self.index = faiss.IndexFlatL2(self.dimension)
            self.metadata = []

    def add_texts(self, texts: List[str], metadatas: Optional[List[Dict]] = None):
        """Add texts to the vector store.

        The texts are embedded and their vectors appended to the index; the
        texts themselves are not stored, only their metadata. The updated
        index and metadata are written to disk before returning.

        Args:
            texts: Texts to embed. An empty list is a no-op.
            metadatas: Optional metadata dicts, one per text. When omitted
                (or empty), an empty dict is recorded for every text.

        Raises:
            ValueError: If ``metadatas`` is given but its length differs from
                ``texts``.
        """
        if not texts:
            return

        # Validate before touching the index: a length mismatch would shift
        # every later vector/metadata pairing and then be persisted to disk.
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries but texts has {len(texts)}"
            )

        embeddings = self.embedder.encode(texts)

        self.index.add(np.array(embeddings).astype('float32'))

        if metadatas:
            self.metadata.extend(metadatas)
        else:
            self.metadata.extend({} for _ in texts)

        self._save_index()

    def similarity_search(self, query: str, k: int = 5) -> List[Dict]:
        """Search for similar texts.

        Args:
            query: Text to embed and look up.
            k: Maximum number of neighbours to return.

        Returns:
            A list of ``{"metadata": ..., "score": ...}`` dicts in the order
            returned by the index. ``score`` is the distance reported by the
            index (for the L2 index created by this class, lower is closer).
            Empty when the index holds no vectors or ``k`` is not positive.
        """
        if k <= 0 or self.index.ntotal == 0:
            return []

        query_embedding = self.embedder.encode([query])
        D, I = self.index.search(np.array(query_embedding).astype('float32'), k)

        results = []
        for score, idx in zip(D[0], I[0]):
            # FAISS pads missing neighbours with label -1; without the lower
            # bound, -1 would index the *last* metadata entry as a bogus hit.
            if 0 <= idx < len(self.metadata):
                results.append({
                    "metadata": self.metadata[int(idx)],
                    "score": float(score),
                })

        return results

    def _save_index(self):
        """Save index and metadata to their files under ``self.base_path``."""
        index_path = os.path.join(self.base_path, "faiss.index")
        metadata_path = os.path.join(self.base_path, "metadata.pkl")

        faiss.write_index(self.index, index_path)
        with open(metadata_path, 'wb') as f:
            pickle.dump(self.metadata, f)