cryogenic22 committed on
Commit
2b4b6b2
·
verified ·
1 Parent(s): 827fd16

Update utils/vector_store.py

Browse files
Files changed (1) hide show
  1. utils/vector_store.py +77 -62
utils/vector_store.py CHANGED
@@ -1,90 +1,105 @@
1
- from sentence_transformers import SentenceTransformer
2
- import faiss
3
- import numpy as np
4
- from typing import List, Dict
5
  import os
6
  import pickle
 
 
7
 
8
 
9
class VectorStore:
    """FAISS-backed vector store of sentence-transformer embeddings with L2 search."""

    def __init__(self, storage_path: str = "data/vector_store", dimension: int = 384):
        """
        Initialize the VectorStore.

        Args:
            storage_path (str): Path to store the FAISS index and metadata.
            dimension (int): Dimension of the embeddings (depends on the
                embedding model used).
        """
        self.storage_path = storage_path
        os.makedirs(self.storage_path, exist_ok=True)

        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")  # Pre-trained model
        self.dimension = dimension

        # Initialize FAISS index and metadata, then restore any persisted state.
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        self._load_vector_store()

    def _load_vector_store(self):
        """Load the FAISS index and metadata from persistent storage (best effort)."""
        try:
            index_path = os.path.join(self.storage_path, "faiss.index")
            metadata_path = os.path.join(self.storage_path, "metadata.pkl")
            if os.path.exists(index_path) and os.path.exists(metadata_path):
                self.index = faiss.read_index(index_path)
                # NOTE(review): pickle is unsafe on untrusted files; only load
                # stores this application wrote itself.
                with open(metadata_path, "rb") as f:
                    self.metadata = pickle.load(f)
        except Exception as e:
            # Best effort: a corrupt store falls back to the fresh empty index.
            print(f"Failed to load vector store: {e}")

    def add_texts(self, texts: List[str], metadatas: List[Dict] = None):
        """
        Add texts and their metadata to the vector store.

        Args:
            texts (List[str]): List of text chunks to be added.
            metadatas (List[Dict]): Optional metadata dicts, one per text chunk.
        """
        embeddings = self.embedder.encode(texts, show_progress_bar=True)
        self.index.add(np.array(embeddings).astype("float32"))
        # Fix: store each chunk's text alongside its metadata so that
        # similarity_search can return it. Previously the default was
        # [{}] * len(texts) — one shared, aliased dict with no "text" key,
        # which made similarity_search raise KeyError. A caller-supplied
        # "text" key still takes precedence.
        entries = metadatas if metadatas else [{} for _ in texts]
        self.metadata.extend(
            {"text": text, **meta} for text, meta in zip(texts, entries)
        )
        self._save_vector_store()

    def similarity_search(self, query: str, k: int = 5) -> List[Dict]:
        """
        Perform a similarity search for the given query.

        Args:
            query (str): The query text.
            k (int): Number of closest matches to retrieve.

        Returns:
            List[Dict]: Dicts with the matched text and its L2 distance.
        """
        query_embedding = self.embedder.encode([query]).astype("float32")
        distances, indices = self.index.search(query_embedding, k)
        # Fix: FAISS pads results with -1 when fewer than k vectors exist;
        # the old check `i < len(self.metadata)` let -1 through and wrongly
        # indexed the last metadata entry. Require 0 <= i as well.
        return [
            {"text": self.metadata[i].get("text"), "distance": distances[0][j]}
            for j, i in enumerate(indices[0]) if 0 <= i < len(self.metadata)
        ]

    def _save_vector_store(self):
        """Save the FAISS index and metadata to persistent storage (best effort)."""
        try:
            index_path = os.path.join(self.storage_path, "faiss.index")
            metadata_path = os.path.join(self.storage_path, "metadata.pkl")
            faiss.write_index(self.index, index_path)
            with open(metadata_path, "wb") as f:
                pickle.dump(self.metadata, f)
        except Exception as e:
            print(f"Failed to save vector store: {e}")

    def reset_store(self):
        """
        Reset the vector store by clearing the FAISS index and metadata.
        This is useful for starting fresh. The emptied state is persisted.
        """
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        self._save_vector_store()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import pickle
3
+ from typing import List, Dict, Any
4
+ from sentence_transformers import SentenceTransformer, util
5
 
6
 
7
class VectorStore:
    """Pickle-backed store of sentence-transformer embeddings with cosine-similarity search."""

    def __init__(self, storage_path: str = "data/vector_store"):
        """
        Initialize VectorStore.

        Args:
            storage_path (str): Path to store vectorized documents.
        """
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.vectors = self._load_vectors()

    def _load_vectors(self) -> List[Dict]:
        """
        Load stored vectors from the file system.

        Returns:
            List[Dict]: Stored vectorized documents; empty list when no store exists.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        if os.path.exists(vector_file):
            # NOTE(review): pickle is unsafe on untrusted files; only load
            # stores this application wrote itself.
            with open(vector_file, "rb") as f:
                return pickle.load(f)
        return []

    def _save_vectors(self):
        """
        Save the current vectors to <storage_path>/vectors.pkl.
        """
        vector_file = os.path.join(self.storage_path, "vectors.pkl")
        with open(vector_file, "wb") as f:
            pickle.dump(self.vectors, f)

    def add_document(self, doc_id: str, text: str, metadata: Dict[str, Any]):
        """
        Add a new document to the vector store and persist immediately.

        Args:
            doc_id (str): Unique document identifier.
            text (str): Full text of the document.
            metadata (Dict[str, Any]): Metadata associated with the document.
        """
        vector = self.model.encode(text, convert_to_tensor=True)
        self.vectors.append({"doc_id": doc_id, "vector": vector, "text": text, "metadata": metadata})
        self._save_vectors()

    def similarity_search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Perform a similarity search for the query against stored vectors.

        Args:
            query (str): Query string to search for.
            top_k (int): Number of top results to return.

        Returns:
            List[Dict]: The most similar documents, best score first.
        """
        query_vector = self.model.encode(query, convert_to_tensor=True)
        results = []
        for doc in self.vectors:
            similarity_score = util.pytorch_cos_sim(query_vector, doc["vector"]).item()
            results.append({"doc_id": doc["doc_id"], "text": doc["text"], "metadata": doc["metadata"], "score": similarity_score})
        results = sorted(results, key=lambda x: x["score"], reverse=True)
        return results[:top_k]

    def chat_with_context(self, query: str, context: str) -> str:
        """
        Generate a placeholder response to the query using stored documents.

        Args:
            query (str): Query string from the user.
            context (str): Context string from relevant documents (currently
                unused; kept for the future LLM integration).

        Returns:
            str: Generated response, or an error message on failure.
        """
        # Placeholder for LLM API integration: a real implementation would
        # prompt a language model with the query plus context. Until then,
        # surface the most relevant stored document alongside a canned reply.
        # (Fix: removed the dead `combined_input` prompt string the original
        # built but never used.)
        try:
            results = self.similarity_search(query, top_k=3)
            # Fix: an empty store used to fall through to results[0] and
            # surface as "Error generating response: list index out of range".
            if not results:
                return "No documents available to answer the query."
            return (
                f"Based on the context:\n\n"
                f"{results[0]['text'][:500]}...\n\n"
                f"Response: The query '{query}' relates to the provided context."
            )
        except Exception as e:
            return f"Error generating response: {str(e)}"