import json
from multiprocessing import get_context

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

model_path = 'sentence-transformers/gtr-t5-large'


def encode_sentences_on_gpu(params):
    """Encode one chunk of sentences on a single GPU and return the embeddings."""
    sentences_chunk, device_id = params
    device = torch.device(f'cuda:{device_id}')
    # Each worker loads its own copy of the model onto its assigned GPU.
    model = SentenceTransformer(model_path, device=device)
    embeddings = model.encode(
        sentences_chunk,
        batch_size=512,
        show_progress_bar=True,   # one progress bar per worker
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return embeddings


if __name__ == '__main__':
    # Overall progress reporting
    print("Loading data...")
    with open("merged_triple_processed_new_withID.json", "r") as fi:
        data = json.load(fi)

    sentences = [_['contents'] for _ in data]
    print(f"Number of chunks: {len(sentences)}")

    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    assert num_gpus > 0, "This script requires at least one CUDA device."

    # Split the sentences evenly across the available GPUs.
    sentences_chunks = np.array_split(sentences, num_gpus)
    params = [(sentences_chunks[i].tolist(), i) for i in range(num_gpus)]

    print("Starting encoding process...")
    # Use the 'spawn' start method so each worker initializes CUDA cleanly.
    with get_context('spawn').Pool(processes=num_gpus) as pool:
        embeddings_list = list(tqdm(
            pool.imap(encode_sentences_on_gpu, params),
            total=num_gpus,
            desc='Overall progress'
        ))

    print("Concatenating embeddings...")
    # FAISS expects float32 vectors.
    sentence_embeddings = np.concatenate(embeddings_list, axis=0).astype(np.float32)

    # Create a FAISS index (inner product over normalized vectors = cosine similarity)
    print("Creating FAISS index...")
    dim = sentence_embeddings.shape[1]
    faiss_index = faiss.IndexFlatIP(dim)

    # Progress message for the bulk add
    print("Adding embeddings to FAISS index...")
    faiss_index.add(sentence_embeddings)

    # Save the index and the embeddings
    print("Saving FAISS index...")
    faiss_index_file = 'faiss_index.bin'
    faiss.write_index(faiss_index, faiss_index_file)
    print(f"FAISS index saved to {faiss_index_file}")

    print("Saving embeddings...")
    embeddings_file = 'document_embeddings.npy'
    np.save(embeddings_file, sentence_embeddings)
    print(f"Document embeddings saved to {embeddings_file}")