|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from collections import defaultdict |
|
|
from typing import List, Dict, Tuple, Union |
|
|
import torch |
|
|
from PIL import Image |
|
|
import pickle |
|
|
from openai import OpenAI |
|
|
import os |
|
|
import torch |
|
|
import time |
|
|
import yaml |
|
|
|
|
|
class MemoryIndex:
    """In-memory document store with embedding-based nearest-neighbour search.

    Documents and their embedding vectors live in two parallel dicts keyed by
    document id.  Embeddings come either from the OpenAI embeddings API
    (``use_openai=True``) or from a local SentenceTransformer model.
    """

    def __init__(self, number_of_neighbours, use_openai=False):
        """Build an empty index.

        Args:
            number_of_neighbours: how many results ``search_by_similarity``
                returns; ``-1`` means "return every stored document".
            use_openai: when True, embed text via the OpenAI API
                (reads the ``OPENAI_API_KEY`` environment variable).
        """
        self.documents = {}          # doc_id -> raw document data
        self.document_vectors = {}   # doc_id -> embedding tensor
        self.use_openai = use_openai
        if use_openai:
            api_key = os.getenv("OPENAI_API_KEY")
            self.client = OpenAI(api_key=api_key)
        # The local model is loaded unconditionally (as in the original code)
        # so locally-pickled embeddings remain queryable even in OpenAI mode.
        self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

        # NOTE(review): config path is hard-coded, and the embedding device is
        # borrowed from the 'minigpt4_gpu_id' setting — confirm this coupling
        # is intended.
        with open('test_configs/llama2_test_config.yaml') as file:
            config = yaml.load(file, Loader=yaml.FullLoader)
        embedding_gpu_id = config['model']['minigpt4_gpu_id']
        self.device = f"cuda:{embedding_gpu_id}" if torch.cuda.is_available() else "cpu"
        self.number_of_neighbours = int(number_of_neighbours)

    def load_documents_from_json(self, file_path, emdedding_path=""):
        """Load ``{doc_id: doc}`` pairs from a JSON file, embed each document,
        and optionally pickle the resulting index.

        Args:
            file_path: path to a JSON object mapping doc ids to document data.
            emdedding_path: where to pickle ``[documents, document_vectors]``;
                when empty, nothing is written.  (Misspelled name kept for
                backward compatibility with keyword callers.)

        Returns:
            ``emdedding_path``, unchanged.
        """
        with open(file_path, 'r') as file:
            data = json.load(file)
        for doc_id, doc_data in data.items():
            self.documents[doc_id] = doc_data
            self.document_vectors[doc_id] = self._compute_sentence_embedding(doc_data)

        # Fix: the original unconditionally opened emdedding_path for writing,
        # which crashed with the default "" argument.
        if emdedding_path:
            with open(emdedding_path, 'wb') as file:
                pickle.dump([self.documents, self.document_vectors], file)
        return emdedding_path

    def load_embeddings_from_pkl(self, pkl_file_path):
        """Restore an index previously pickled by ``load_documents_from_json``.

        The pickle format is a two-element list:
        ``[documents, document_vectors]``.
        """
        with open(pkl_file_path, 'rb') as file:
            data = pickle.load(file)
        self.documents = data[0]
        self.document_vectors = data[1]

    def load_data_from_pkl(self, pkl_file_path):
        """Load a pickled ``{doc_id: value}`` mapping, storing each value both
        as the document and as its embedding vector.

        NOTE(review): values are presumably precomputed embedding tensors —
        confirm against callers.
        """
        with open(pkl_file_path, 'rb') as file:
            data = pickle.load(file)
        for doc_id, doc_data in data.items():
            self.documents[doc_id] = doc_data
            self.document_vectors[doc_id] = doc_data

    def _compute_sentence_embedding(self, text: str) -> torch.Tensor:
        """Embed ``text`` into a vector.

        Local mode returns the SentenceTransformer encoding moved to
        ``self.device``; OpenAI mode retries the API call indefinitely
        (best-effort, 5 s pause between attempts) and returns a CPU tensor.
        """
        if not self.use_openai:
            return self.model.encode(text, convert_to_tensor=True).to(self.device)
        # Deliberate best-effort retry loop: transient API failures are
        # logged and retried rather than propagated.
        while True:
            try:
                response = self.client.embeddings.create(
                    input=[text], model="text-embedding-3-small")
                return torch.tensor(response.data[0].embedding)
            except Exception as e:
                print("error", e)
                print("text", text)
                time.sleep(5)

    def search_by_similarity(self, query: str) -> Tuple[List, List[str]]:
        """Return the stored documents most similar to ``query``.

        Fix: the original annotation claimed ``List[str]`` but a tuple of
        two lists is returned.

        Returns:
            ``(documents, doc_ids)``, ranked by cosine similarity, truncated
            to ``self.number_of_neighbours`` entries.  Special cases:
            ``-1`` returns everything (insertion order); a request larger
            than the index returns all ranked entries; a top-1 request whose
            best hit is the ``'summary'`` document also returns the runner-up.
        """
        query_vector = self._compute_sentence_embedding(query)
        scores = {
            doc_id: torch.nn.functional.cosine_similarity(query_vector, doc_vector, dim=0).item()
            for doc_id, doc_vector in self.document_vectors.items()
        }
        sorted_doc_ids = sorted(scores, key=scores.get, reverse=True)
        sorted_documents = [self.documents[doc_id] for doc_id in sorted_doc_ids]

        if self.number_of_neighbours == -1:
            # -1 means "return everything", unranked (dict insertion order).
            return list(self.documents.values()), list(self.documents.keys())
        if self.number_of_neighbours > len(sorted_documents):
            return sorted_documents, sorted_doc_ids
        if self.number_of_neighbours == 1 and sorted_doc_ids[0] == 'summary':
            # If only the summary would be returned, include the next-best
            # real document as well.
            return sorted_documents[:2], sorted_doc_ids[:2]
        return (sorted_documents[:self.number_of_neighbours],
                sorted_doc_ids[:self.number_of_neighbours])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|