"""Build a FAISS inner-product index over document-chunk embeddings.

Reads pre-processed chunks from ``merged_triple_processed_new_withID.json``,
encodes each record's ``contents`` with a HuggingFace encoder (contriever,
mean pooling) using DataParallel across all visible GPUs, then writes the
FAISS index (``faiss_index.bin``) and raw embeddings
(``document_embeddings.npy``) to disk.
"""
import faiss
import numpy as np
import json
from tqdm import tqdm
import os
from torch.nn import DataParallel
from transformers import AutoTokenizer, AutoModel, T5EncoderModel
import torch
from sentence_transformers import SentenceTransformer
from multiprocessing import Pool
import time

start_time = time.time()

with open("merged_triple_processed_new_withID.json", "r", encoding="utf-8") as fi:
    data = json.load(fi)
sentences = [record['contents'] for record in data]
print("Chunks nums: ", len(sentences))

# Alternative encoders used in earlier runs (switch model_path and pooling):
#   '/mnt/ceph_rbd/hf_models/gtr-t5-xl'         -> T5EncoderModel, mean of last_hidden_state, L2-normalize
#   '/mnt/ceph_rbd/hf_models/bge-large-en-v1.5' -> CLS pooling (output[0][:, 0]), L2-normalize
model_path = 'facebook/contriever'

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)
model = DataParallel(model)  # splits each batch across all visible GPUs
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

batch_size = 1024


def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the sequence, ignoring padding.

    Args:
        token_embeddings: (batch, seq_len, dim) tensor from the encoder.
        mask: (batch, seq_len) attention mask (1 = real token, 0 = padding).

    Returns:
        (batch, dim) tensor of mask-weighted mean sentence embeddings.
    """
    # Zero out padded positions so they contribute nothing to the sum,
    # then divide by the number of real tokens per sentence.
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings


def process_in_batches(sentences, batch_size):
    """Encode ``sentences`` in batches; return a (N, dim) CPU float tensor.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    sentence_embeddings_list = []
    for i in tqdm(range(0, len(sentences), batch_size)):
        batch_sentences = sentences[i:i + batch_size]
        encoded_input = tokenizer(batch_sentences, padding=True, truncation=True,
                                  return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        batch_sentence_embeddings = mean_pooling(model_output[0],
                                                 encoded_input['attention_mask'])
        # For BGE: CLS pooling -> model_output[0][:, 0], then L2-normalize.
        # For GTR: model_output.last_hidden_state.mean(dim=1), then L2-normalize.
        # Move each batch to CPU immediately to keep GPU memory bounded.
        sentence_embeddings_list.append(batch_sentence_embeddings.cpu())
    return torch.cat(sentence_embeddings_list, dim=0)


sentence_embeddings = process_in_batches(sentences, batch_size)
# Already on CPU (process_in_batches moves every batch there); just convert
# to a numpy array for FAISS (torch float32 -> numpy float32, as FAISS needs).
sentence_embeddings = sentence_embeddings.numpy()

# Flat inner-product index: exact search; equals cosine similarity only if
# the vectors are L2-normalized (contriever embeddings here are NOT normalized).
dim = sentence_embeddings.shape[1]
faiss_index = faiss.IndexFlatIP(dim)
faiss_index.add(sentence_embeddings)

faiss_index_file = 'faiss_index.bin'
faiss.write_index(faiss_index, faiss_index_file)
print(f"FAISS index saved to {faiss_index_file}")

embeddings_file = 'document_embeddings.npy'
np.save(embeddings_file, sentence_embeddings)
print(f"Document embeddings saved to {embeddings_file}")

end_time = time.time()
execution_time_hours = (end_time - start_time) / 3600
print(f"Total execution time: {execution_time_hours:.2f} hours")