| | import faiss |
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| | from typing import List, Optional, Tuple |
| | from langchain_community.graphs import Neo4jGraph |
| | import pickle |
| |
|
class FAISSVectorStore:
    """Cosine-similarity vector store backed by a flat FAISS inner-product index.

    Embeddings are L2-normalised before they are added, so the inner product
    computed by ``faiss.IndexFlatIP`` equals cosine similarity. The text of
    each hit is fetched from a Neo4j graph: the id FAISS returns for a vector
    is matched against the ``id`` property of a node.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        dimension: int = 384,
        embedding_file: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Create an empty index and optionally load a model and embeddings.

        Args:
            model_name: SentenceTransformer model used to embed queries. When
                ``None`` no model is loaded and ``similarity_search`` cannot
                be used.
            dimension: Dimensionality of the stored vectors.
            embedding_file: Optional ``.pkl`` or ``.npy`` file whose vectors
                are added to the index right away.
            trust_remote_code: Forwarded to ``SentenceTransformer``.
        """
        self.model = (
            SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
            if model_name is not None
            else None
        )
        self.index = faiss.IndexFlatIP(dimension)
        self.dimension = dimension
        if embedding_file:
            self.load_embeddings(embedding_file)

    def load_embeddings(self, file_path: str):
        """Load embeddings from a ``.pkl`` or ``.npy`` file into the index.

        Raises:
            ValueError: If the file extension is neither ``.pkl`` nor ``.npy``.
        """
        if file_path.endswith('.pkl'):
            # SECURITY: unpickling executes arbitrary code embedded in the
            # file. Only load pickle files that come from a trusted source.
            with open(file_path, 'rb') as f:
                embeddings = pickle.load(f)
        elif file_path.endswith('.npy'):
            embeddings = np.load(file_path)
        else:
            raise ValueError("Unsupported file format. Use .pkl or .npy")

        self.add_embeddings(embeddings)

    def add_embeddings(self, embeddings: np.ndarray):
        """L2-normalise ``embeddings`` and append them to the index.

        The input is copied into a C-contiguous float32 array first, so the
        caller's array is left untouched and float64 or list input is
        accepted.

        Raises:
            ValueError: If the input is not a 2-D array whose second
                dimension equals ``self.dimension``.
        """
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.size == 0:
            return
        if vectors.ndim != 2 or vectors.shape[1] != self.dimension:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dimension}), "
                f"got {vectors.shape}"
            )
        faiss.normalize_L2(vectors)
        self.index.add(vectors)

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        use_mmr: bool = False,
        lambda_param: float = 0.5,
        doc_types: Optional[list[str]] = None,
        neo4j_graph: Optional["Neo4jGraph"] = None,
    ):
        """Embed ``query`` and return the best matching documents.

        Args:
            query: Text to search for.
            k: Maximum number of hits to return.
            use_mmr: Re-rank with Maximal Marginal Relevance when true.
            lambda_param: MMR weight of relevance versus redundancy.
            doc_types: Optional node labels to restrict the lookup to.
            neo4j_graph: Graph used to resolve hit ids to documents.

        Returns:
            A ``(results, results_idx)`` tuple: ``results`` is a list of
            ``{'document': str, 'score': float}`` dictionaries and
            ``results_idx`` holds the matching FAISS ids in the same order.

        Raises:
            RuntimeError: If the store was created without a model.
            ValueError: If ``neo4j_graph`` is not supplied.
        """
        if self.model is None:
            raise RuntimeError(
                "similarity_search requires an embedding model; "
                "pass model_name when constructing FAISSVectorStore"
            )
        if neo4j_graph is None:
            raise ValueError("similarity_search requires a neo4j_graph")

        # FAISS expects a C-contiguous float32 matrix.
        query_vector = np.ascontiguousarray(
            self.model.encode([query]), dtype=np.float32
        )
        faiss.normalize_L2(query_vector)

        if use_mmr:
            return self._mmr_search(
                query_vector, k, lambda_param, neo4j_graph, doc_types
            )
        return self._simple_search(query_vector, k, neo4j_graph, doc_types)

    def _simple_search(
        self,
        query_vector: np.ndarray,
        k: int,
        neo4j_graph: "Neo4jGraph",
        doc_types: Optional[list[str]] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Return the ``k`` nearest documents in plain similarity order.

        Hits whose node cannot be found in the graph are dropped, so fewer
        than ``k`` results may be returned.
        """
        if k <= 0:
            return [], []
        distances, indices = self.index.search(query_vector, k)

        results = []
        results_idx = []
        for score, idx in zip(distances[0], indices[0]):
            # FAISS pads the result with -1 when fewer than k vectors exist.
            if idx < 0:
                continue
            faiss_id = int(idx)
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': float(score),
                })
                results_idx.append(faiss_id)

        return results, results_idx

    def _mmr_search(
        self,
        query_vector: np.ndarray,
        k: int,
        lambda_param: float,
        neo4j_graph: "Neo4jGraph",
        doc_types: Optional[list[str]] = None,
    ) -> Tuple[List[dict], List[int]]:
        """Return up to ``k`` documents re-ranked by Maximal Marginal Relevance.

        ``2 * k`` candidates are fetched, then picked greedily by
        ``lambda_param * relevance - (1 - lambda_param) * redundancy`` where
        redundancy is the highest cosine similarity to an already picked
        candidate. ``results_idx`` contains FAISS ids, matching
        ``_simple_search``.
        """
        initial_k = min(k * 2, self.index.ntotal)
        if initial_k <= 0:
            # Empty index or non-positive k: nothing to search for.
            return [], []
        distances, indices = self.index.search(query_vector, initial_k)

        # Drop the -1 padding FAISS uses for missing neighbours.
        found = indices[0] >= 0
        candidate_ids = indices[0][found]
        relevance = distances[0][found]
        if candidate_ids.size == 0:
            return [], []

        # Pairwise cosine similarity between all candidates, computed once.
        embeddings = np.asarray(
            self._reconstruct_embeddings(candidate_ids), dtype=np.float64
        )
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # a zero vector is similar to nothing
        unit = embeddings / norms
        similarity = unit @ unit.T

        selected = []
        remaining = list(range(len(candidate_ids)))
        for _ in range(min(k, len(remaining))):
            pool = np.array(remaining)
            if selected:
                redundancy = similarity[np.ix_(pool, selected)].max(axis=1)
                scores = (
                    lambda_param * relevance[pool]
                    - (1 - lambda_param) * redundancy
                )
            else:
                scores = relevance[pool]
            # argmax keeps the earliest candidate on ties.
            best = remaining[int(np.argmax(scores))]
            selected.append(best)
            remaining.remove(best)

        results = []
        results_idx = []
        for pos in selected:
            faiss_id = int(candidate_ids[pos])
            document = self._get_text_by_index(neo4j_graph, faiss_id, doc_types)
            if document is not None:
                results.append({
                    'document': document,
                    'score': float(relevance[pos]),
                })
                results_idx.append(faiss_id)

        return results, results_idx

    def _reconstruct_embeddings(self, indices: np.ndarray) -> np.ndarray:
        """Return the stored vectors for the given FAISS ids."""
        return self.index.reconstruct_batch(
            np.ascontiguousarray(indices, dtype=np.int64)
        )

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of ``a`` and ``b`` (0.0 for a zero vector)."""
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def _get_text_by_index(self, neo4j_graph, index, doc_types):
        """Return ``"[label] node"`` for the node whose ``id`` equals ``index``.

        Args:
            neo4j_graph: Object exposing ``query(cypher, params)``.
            index: FAISS id of the hit; sent to Neo4j as a plain ``int``.
            doc_types: ``None`` to match any node, otherwise labels that are
                tried in order until one yields a node.

        Returns:
            The formatted text, or ``None`` when no node matches (including
            when ``doc_types`` is an empty list).
        """
        # Plain int: numpy scalars may not be accepted as Cypher parameters
        # by every driver version.
        params = {"index": int(index)}
        result = None
        if doc_types is None:
            query = """
                MATCH (n)
                WHERE n.id = $index
                RETURN n AS document, labels(n) AS document_type, n.id AS node_id
            """
            result = neo4j_graph.query(query, params)
        else:
            for doc_type in doc_types:
                # NOTE: the label is interpolated into the Cypher text, so
                # doc_types must come from trusted code, never from user
                # input.
                query = f"""
                    MATCH (n:{doc_type})
                    WHERE n.id = $index
                    RETURN n AS document, labels(n) AS document_type, n.id AS node_id
                """
                result = neo4j_graph.query(query, params)
                if result:
                    break

        if result:
            return f"[{result[0]['document_type'][0]}] {result[0]['document']}"
        return None
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Example usage: build a store from saved embeddings and compare plain
    # similarity search with MMR re-ranking.
    #
    # A model is required to embed the query. "all-MiniLM-L6-v2" produces
    # 384-dimensional vectors; replace it with whichever model generated the
    # embeddings stored in the file.
    vector_store = FAISSVectorStore(
        model_name="all-MiniLM-L6-v2",
        dimension=384,
        embedding_file="path/to/your/embeddings.pkl",
    )

    neo4j_graph = Neo4jGraph(
        url="bolt://localhost:7687",
        username="neo4j",
        password="password"
    )

    query = "How to start a long journey"
    # similarity_search returns (results, results_idx); only the result
    # dictionaries are printed here.
    results_simple, _ = vector_store.similarity_search(
        query, k=5, use_mmr=False, neo4j_graph=neo4j_graph
    )
    results_mmr, _ = vector_store.similarity_search(
        query, k=5, use_mmr=True, lambda_param=0.5, neo4j_graph=neo4j_graph
    )

    for mode, results in (("without MMR", results_simple), ("with MMR", results_mmr)):
        print(f"Top 5 similar texts for query: '{query}' ({mode})")
        for i, result in enumerate(results, 1):
            # Result dictionaries use the key 'document', not 'text'.
            print(f"{i}. Text: {result['document']}")
            print(f" Score: {result['score']}")
            print()