|
|
import faiss |
|
|
import numpy as np |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
from torch.nn import DataParallel |
|
|
from transformers import AutoTokenizer, AutoModel, T5EncoderModel |
|
|
from torch.utils.data import DataLoader, TensorDataset |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from multiprocessing import Pool |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Timing starts here; total wall-clock time is reported at the end. ---
start_time = time.time()

# Load the preprocessed corpus; each record carries its text under 'contents'.
corpus_path = "/mnt/ceph_rbd/hf_models/feedbackcorpus/merged_triple_processed_new_withID.json"
with open(corpus_path, "r") as fi:
    data = json.load(fi)

sentences = [record['contents'] for record in data]
print("Chunks nums: ", len(sentences))

# Local checkpoint of the pretrained Contriever encoder.
model_path = '/mnt/ceph_rbd/hf_models/contriever'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Load the Contriever encoder and prepare it for batched GPU inference. ---
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)

# FIX: switch to eval mode before inference. from_pretrained returns the model
# in training mode, so dropout would otherwise stay active and make the
# embeddings nondeterministic.
model.eval()

# Split each batch across all visible GPUs; DataParallel is transparent to the
# forward() call in the encoding loop below. (Uses the DataParallel name that
# is already imported at the top of the file.)
if torch.cuda.device_count() > 1:
    model = DataParallel(model)
model = model.to('cuda')
print("Model load success...")

# Large batch: the chunks are short and DataParallel spreads the work.
batch_size = 2048*4
print("len(sentences)//b_z: ", len(sentences)//batch_size)
|
|
|
|
|
|
|
|
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the sequence axis, ignoring padding.

    token_embeddings: float tensor of shape (batch, seq_len, dim).
    mask: attention mask of shape (batch, seq_len); nonzero marks real tokens.
    Returns a (batch, dim) tensor of per-sentence mean embeddings.
    """
    keep = mask[..., None].bool()
    # Zero out padded positions so they contribute nothing to the sum.
    summed = token_embeddings.masked_fill(~keep, 0.0).sum(dim=1)
    counts = mask.sum(dim=1)[..., None]
    return summed / counts
|
|
|
|
|
|
|
|
def sentence_generator(sentences, tokenizer, batch_size):
    """Yield (input_ids, attention_mask) tensor pairs, one batch at a time.

    Tokenizes `batch_size` sentences per step with padding/truncation so the
    whole corpus never has to be tokenized (and held in memory) at once.
    """
    total = len(sentences)
    start = 0
    while start < total:
        batch = sentences[start:start + batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
        yield encoded['input_ids'], encoded['attention_mask']
        start += batch_size
|
|
|
|
|
|
|
|
# --- Encode the whole corpus batch by batch, gathering results on the CPU. ---
n_batches = (len(sentences) + batch_size - 1) // batch_size
chunks = []
with torch.no_grad():
    batch_iter = sentence_generator(sentences, tokenizer, batch_size)
    for ids, mask in tqdm(batch_iter, total=n_batches, desc="Processing batches"):
        ids = ids.to('cuda')
        mask = mask.to('cuda')
        outputs = model(input_ids=ids, attention_mask=mask)

        # outputs[0] is the encoder's last hidden state: (batch, seq_len, dim).
        pooled = mean_pooling(outputs[0], mask)
        # Move each batch off the GPU immediately to keep device memory flat.
        chunks.append(pooled.cpu())

sentence_embeddings = torch.cat(chunks).numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Build an inner-product (dot-product) FAISS index over the embeddings. ---
print("Starting create FAISS...")
dim = sentence_embeddings.shape[1]
faiss_index = faiss.IndexFlatIP(dim)

# Add the vectors in 100k slices so progress is visible for large corpora.
batch_size = 100000
n_vectors = len(sentence_embeddings)
for start in tqdm(range(0, n_vectors, batch_size), desc="Adding embeddings to FAISS index"):
    faiss_index.add(sentence_embeddings[start:start + batch_size])
|
|
|
|
|
|
|
|
|
|
|
# --- Persist the index to disk. ---
faiss_index_file = 'results/faiss_index.bin'
# FIX: faiss.write_index does not create parent directories; ensure 'results/'
# exists so the run does not crash here after hours of embedding work.
os.makedirs(os.path.dirname(faiss_index_file), exist_ok=True)
faiss.write_index(faiss_index, faiss_index_file)
print(f"FAISS index saved to {faiss_index_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Report total wall-clock time for the whole pipeline, in hours. ---
end_time = time.time()
elapsed_hours = (end_time - start_time) / 3600
print(f"Total execution time: {elapsed_hours:.2f} hours")