|
|
import argparse
import gc
import inspect
import json
import logging
import math
import multiprocessing as mp
import os
import queue
import re
import sys
import time
import warnings
from collections import defaultdict
from multiprocessing import Queue
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from tqdm import tqdm, trange

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
    Qwen2Config,
    Qwen2ForCausalLM,
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2DecoderLayer
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,
    replace_return_docstrings,
)

from peft import LoraConfig, TaskType, get_peft_model
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

from .configuration_c2llm import C2LLMConfig

logging.getLogger().setLevel(logging.INFO)


class MAB_POST(nn.Module):
    """Multihead Attention Block with post-layer-norm (Set Transformer style)."""

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_POST, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    def forward(self, Q, K, pad_mask=None):
        Q_ = self.fc_q(Q)
        K_, V_ = self.fc_k(K), self.fc_v(K)

        # Split the projection dim into heads and fold the heads into the batch dim.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_.split(dim_split, 2), 0)
        K_ = torch.cat(K_.split(dim_split, 2), 0)
        V_ = torch.cat(V_.split(dim_split, 2), 0)

        # Expand the padding mask over heads and query positions, then mask out
        # attention to padded key positions.
        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)
        A = A * pad_mask

        # Recombine heads, then apply the post-norm residual + feed-forward.
        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        O = Q + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class PMA(nn.Module):
    """Pooling by Multihead Attention: learned seed vectors attend over the sequence."""

    def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False, pma_mode=None):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, compressed_dim))
        nn.init.xavier_uniform_(self.S)
        if pma_mode == 'post_normal':
            self.mab = MAB_POST(compressed_dim, dim, compressed_dim, num_heads, ln=ln)
        else:
            raise ValueError(f"Error, the pma_mode {pma_mode} is not implemented !")

    def forward(self, X, pad_mask):
        # Cast inputs to float32 unless the pooler itself runs in bfloat16.
        if self.S.dtype != torch.bfloat16:
            X = X.float()
        return self.mab(self.S.repeat(X.size(0), 1, 1), X, pad_mask)


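# Illustrative shape check (not part of the library; the sizes are made up).
# With a single seed, PMA pools (batch, seq_len, dim) down to
# (batch, 1, compressed_dim); pad_mask is (batch, seq_len) with 1 = real token:
#
#   pma = PMA(dim=1024, compressed_dim=1024, num_heads=8, num_seeds=1,
#             ln=True, pma_mode='post_normal')
#   pooled = pma(torch.randn(2, 16, 1024), torch.ones(2, 16))
#   assert pooled.shape == (2, 1, 1024)

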
class MAB_POST_v2(nn.Module):
    """Variant of MAB_POST whose residual uses the projected query, so the
    output dimension (dim_V) may differ from the query dimension (dim_Q)."""

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB_POST_v2, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        nn.init.xavier_uniform_(self.fc_q.weight)
        nn.init.xavier_uniform_(self.fc_k.weight)
        nn.init.xavier_uniform_(self.fc_v.weight)
        nn.init.xavier_uniform_(self.fc_o.weight)

    def forward(self, Q, K, pad_mask=None):
        Q_tmp = self.fc_q(Q)
        K_, V_ = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q_tmp.split(dim_split, 2), 0)
        K_ = torch.cat(K_.split(dim_split, 2), 0)
        V_ = torch.cat(V_.split(dim_split, 2), 0)

        pad_mask = pad_mask.unsqueeze(1).repeat(self.num_heads, Q.size(1), 1)
        score = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        score = score.masked_fill(pad_mask == 0, -1e12)
        A = torch.softmax(score, 2)
        A = A * pad_mask

        O = torch.cat(A.bmm(V_).split(Q.size(0), 0), 2)
        # Residual on the projected query (shape dim_V), unlike MAB_POST.
        O = Q_tmp + O
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class PMA_v2(nn.Module):
    """PMA whose seed lives in the input dimension and whose attention block
    projects down to compressed_dim, so pooling and compression happen in one step."""

    def __init__(self, dim, compressed_dim, num_heads, num_seeds, ln=False):
        super(PMA_v2, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB_POST_v2(dim, dim, compressed_dim, num_heads, ln=ln)

    def forward(self, X, pad_mask):
        # Cast inputs to float32 unless the pooler itself runs in bfloat16.
        if self.S.dtype != torch.bfloat16:
            X = X.float()
        return self.mab(self.S.expand(X.size(0), -1, -1), X, pad_mask)


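# Illustrative shape check (not part of the library; the sizes are made up).
# PMA_v2 pools and compresses in one step, mapping (batch, seq_len, dim)
# to (batch, num_seeds, compressed_dim):
#
#   pma = PMA_v2(dim=1024, compressed_dim=512, num_heads=8, num_seeds=1, ln=True)
#   pooled = pma(torch.randn(2, 16, 1024), torch.ones(2, 16))
#   assert pooled.shape == (2, 1, 512)

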
class C2LLMModel(PreTrainedModel):
    config_class = C2LLMConfig
    config: C2LLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen2DecoderLayer,
        "attentions": Qwen2Attention,
    }


class C2LLMForEmbedding(C2LLMModel):

    config_class = C2LLMConfig
    model_type = "c2llm"

    def __init__(self, config):
        super().__init__(config)
        # The backbone is a plain Qwen2 causal LM built from the shared config fields.
        qwen_cfg = Qwen2Config.from_dict(config.to_dict())
        self.plm_model = AutoModelForCausalLM.from_config(qwen_cfg)
        self.embedding_method = config.embedding_method
        self.inf_seq_length = 2048
        self.padding_side = config.padding_side

        self.emb_dim = self.plm_model.model.embed_tokens.weight.size(1)
        self.keep_max_layer = self.plm_model.config.num_hidden_layers
        self.num_heads = config.pma_num_heads
        self.ln = config.pma_ln
        self.norm = config.pma_norm
        self.pma_mode = config.pma_norm_mode
        self.compressed_dim = config.compressed_dim

        # PMA pooling head with a single seed vector.
        self.mha_pma_disc = PMA_v2(self.emb_dim, self.compressed_dim, self.num_heads, 1, ln=self.ln)
        self.pool = None
        self.target_devices = self.get_target_devices(None)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, padding_side=config.padding_side) if config.tokenizer_name_or_path is not None else None
        self.config_class = C2LLMConfig

    def pma_embedding(self, mha_pma, A, mask):
        res = mha_pma(A, mask).squeeze(1)
        return res

    def get_hidden_states(self, **inputs):
        outputs = self.plm_model(inputs['input_ids'], inputs['attention_mask'], output_hidden_states=True)
        return outputs.hidden_states[self.keep_max_layer]

    def get_sentence_embedding(self, embedding_method, hidden_states, emb_type, attention_mask):
        if embedding_method == 'pma':
            if emb_type == 'disc':
                res_embedding = self.pma_embedding(self.mha_pma_disc, hidden_states, attention_mask)
                if self.norm:
                    res_embedding = torch.nn.functional.normalize(res_embedding, p=2.0, dim=-1, eps=1e-12, out=None)
                return res_embedding
            else:
                raise NotImplementedError(f"emb type {emb_type} hasn't been implemented")
        else:
            raise NotImplementedError(f"embedding method {embedding_method} hasn't been implemented")

    @staticmethod
    def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
        """
        Normalize a device specification into a list of device strings.

        Args:
            devices (Union[str, int, List[str], List[int]]): specified devices, can be
                `str`, `int`, list of `str`, or list of `int`. If None, all available
                accelerators are used, falling back to CPU.

        Raises:
            ValueError: Devices should be a string or an integer or a list of strings or a list of integers.

        Returns:
            List[str]: A list of target devices in string format, e.g. `["cuda:0", "cuda:1"]`.
        """
        if devices is None:
            if torch.cuda.is_available():
                return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
            elif is_torch_npu_available():
                return [f"npu:{i}" for i in range(torch.npu.device_count())]
            elif hasattr(torch, "musa") and torch.musa.is_available():
                return [f"musa:{i}" for i in range(torch.musa.device_count())]
            elif torch.backends.mps.is_available():
                try:
                    return [f"mps:{i}" for i in range(torch.mps.device_count())]
                except Exception:
                    return ["mps"]
            else:
                return ["cpu"]
        elif isinstance(devices, str):
            return [devices]
        elif isinstance(devices, int):
            if hasattr(torch, "musa") and torch.musa.is_available():
                return [f"musa:{devices}"]
            else:
                return [f"cuda:{devices}"]
        elif isinstance(devices, list):
            if isinstance(devices[0], str):
                return devices
            elif isinstance(devices[0], int):
                if hasattr(torch, "musa") and torch.musa.is_available():
                    return [f"musa:{device}" for device in devices]
                else:
                    return [f"cuda:{device}" for device in devices]
            else:
                raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
        else:
            raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")

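    # How the normalization above behaves (illustrative):
    #   get_target_devices("cuda:1")  -> ["cuda:1"]
    #   get_target_devices(0)         -> ["cuda:0"]  (or "musa:0" on MUSA builds)
    #   get_target_devices([0, 1])    -> ["cuda:0", "cuda:1"]
    #   get_target_devices(["npu:0"]) -> ["npu:0"]
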
    def start_multi_process_pool(
        self,
        process_target_func: Any,
    ) -> Dict[Literal["input", "output", "processes"], Any]:
        """
        Starts a multi-process pool to process the encoding with several independent processes
        via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.

        This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
        to start only one process per GPU. This method works together with encode_multi_process
        and stop_multi_process_pool.

        Returns:
            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
        """
        if self.plm_model is None or self.mha_pma_disc is None:
            raise ValueError("Model is not initialized.")

        logging.info("Start multi-process pool on devices: {}".format(", ".join(map(str, self.target_devices))))

        # Move the model to CPU and share its memory so each spawned worker can
        # copy it to its own device without re-serializing the weights.
        self.to("cpu")
        self.share_memory()
        ctx = mp.get_context("spawn")
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for device_id in tqdm(self.target_devices, desc='initial target device'):
            p = ctx.Process(
                target=process_target_func,
                args=(device_id, self, input_queue, output_queue),
                daemon=True,
            )
            p.start()
            processes.append(p)

        return {"input": input_queue, "output": output_queue, "processes": processes}

    @staticmethod
    def _encode_multi_process_worker(
        target_device: str, model: 'C2LLMForEmbedding', input_queue: Queue, results_queue: Queue
    ) -> None:
        model = model.to(target_device)
        while True:
            try:
                # Blocks until the parent enqueues a chunk; the parent shuts the
                # worker down by terminating it in stop_multi_process_pool.
                chunk_id, sentences, kwargs = input_queue.get()
                embeddings = model.encode_single_device(
                    sentences,
                    device=target_device,
                    **kwargs
                )
                results_queue.put([chunk_id, embeddings])
            except queue.Empty:
                break

    def encode_multi_process(
        self,
        sentences: List[str],
        pool: Dict[Literal["input", "output", "processes"], Any],
        **kwargs
    ):
        # Split the sentences into one chunk per worker process.
        chunk_size = math.ceil(len(sentences) / len(pool["processes"]))

        input_queue = pool["input"]
        last_chunk_id = 0
        chunk = []

        for sentence in sentences:
            chunk.append(sentence)
            if len(chunk) >= chunk_size:
                input_queue.put([last_chunk_id, chunk, kwargs])
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, chunk, kwargs])
            last_chunk_id += 1

        # Collect the results and restore the original chunk order.
        output_queue = pool["output"]
        results_list = sorted(
            [output_queue.get() for _ in trange(last_chunk_id, desc="Chunks")],
            key=lambda x: x[0],
        )
        embeddings = self._concatenate_results_from_multi_process([result[1] for result in results_list])
        return embeddings

    def _concatenate_results_from_multi_process(self, results_list: List[Union[torch.Tensor, np.ndarray, Any]]):
        """concatenate and return the results from all the processes

        Args:
            results_list (List[Union[torch.Tensor, np.ndarray, Any]]): A list of results from all the processes.

        Raises:
            NotImplementedError: Unsupported type for results_list

        Returns:
            Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
        """
        if isinstance(results_list[0], torch.Tensor):
            # Move all shards to the first target device before concatenating.
            results_list = [res.to(self.target_devices[0]) for res in results_list]
            return torch.cat(results_list, dim=0)
        elif isinstance(results_list[0], np.ndarray):
            return np.concatenate(results_list, axis=0)
        else:
            raise NotImplementedError("Unsupported type for results_list")

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, return_dict: bool = True, **kwargs):
        outputs = self.plm_model(input_ids, attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[self.keep_max_layer]
        embeddings = self.get_sentence_embedding(self.embedding_method, hidden_states, 'disc', attention_mask)
        if not return_dict:
            return (embeddings,)
        return {"sentence_embedding": embeddings}

    def encode_single_device(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 16,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = True,
        show_progress_bar: bool = True,
        max_seq_length: int = 2048,
        device: Optional[str] = None,
        **kwargs: Any
    ):
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        input_is_string = False
        if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
            sentences = [sentences]
            input_is_string = True

        all_embeddings = []
        # Sort by descending length so each batch has similar lengths and little padding.
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                # Inputs follow the backbone's device; multi-process workers move the model beforehand.
                inputs = self.tokenizer(sentences_batch, padding=True, truncation=True, max_length=max_seq_length, return_tensors='pt').to(self.plm_model.device)
                hidden_states = self.get_hidden_states(**inputs)
                embeddings = self.get_sentence_embedding(self.embedding_method, hidden_states, 'disc', inputs['attention_mask'])
                embeddings = embeddings.detach()
                if convert_to_numpy:
                    # numpy has no bfloat16, so upcast before leaving the device.
                    if embeddings.dtype == torch.bfloat16:
                        embeddings = embeddings.cpu().to(torch.float32)
                    else:
                        embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)

        # Undo the length sort so embeddings line up with the input order.
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        if input_is_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings

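    # Single-device usage (illustrative; assumes the config supplied a tokenizer):
    #   vecs = model.encode_single_device(["def add(a, b): return a + b"],
    #                                     batch_size=8, convert_to_numpy=True,
    #                                     convert_to_tensor=False)
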
    def encode(self, sentences, batch_size=16, convert_to_numpy=False,
               convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):

        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        # Exactly one output format must be selected; default to tensors otherwise.
        if convert_to_tensor == convert_to_numpy:
            convert_to_tensor = True
            convert_to_numpy = False

        # A single string or a single device skips the multi-process machinery.
        if isinstance(sentences, str) or len(self.target_devices) == 1:
            return self.encode_single_device(
                sentences,
                batch_size=batch_size,
                convert_to_numpy=convert_to_numpy,
                convert_to_tensor=convert_to_tensor,
                show_progress_bar=show_progress_bar,
                max_seq_length=max_seq_length,
                device=self.target_devices[0],
                **kwargs
            )

        if self.pool is None:
            self.pool = self.start_multi_process_pool(C2LLMForEmbedding._encode_multi_process_worker)

        all_embeddings = []
        length_sorted_idx = np.argsort([-len(s) for s in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        with torch.no_grad():
            for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
                # Each batch is further split across the worker processes.
                sentences_batch = sentences_sorted[start_index: start_index + batch_size]
                embeddings_batch = self.encode_multi_process(
                    sentences_batch,
                    self.pool,
                    convert_to_numpy=convert_to_numpy,
                    convert_to_tensor=convert_to_tensor,
                    show_progress_bar=show_progress_bar,
                    max_seq_length=max_seq_length,
                    **kwargs
                )
                embeddings_batch = embeddings_batch.detach()
                if convert_to_numpy:
                    if embeddings_batch.dtype == torch.bfloat16:
                        embeddings_batch = embeddings_batch.cpu().to(torch.float32)
                    else:
                        embeddings_batch = embeddings_batch.cpu()
                all_embeddings.extend(embeddings_batch)

        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

        return all_embeddings

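    # End-to-end usage (illustrative; the checkpoint path is a placeholder and
    # assumes the repo wires this class up for trust_remote_code loading):
    #   model = AutoModel.from_pretrained("path/to/c2llm-checkpoint", trust_remote_code=True)
    #   embeddings = model.encode(["query: binary search", "def bsearch(xs, t): ..."])
    #   scores = embeddings @ embeddings.T  # cosine similarity when pma_norm is set
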
    def encode_queries(self, sentences, batch_size=16, convert_to_numpy=False,
                       convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        if convert_to_tensor == convert_to_numpy:
            convert_to_tensor = True
            convert_to_numpy = False

        return self.encode(
            sentences=sentences,
            batch_size=batch_size,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
            show_progress_bar=show_progress_bar,
            max_seq_length=max_seq_length,
            **kwargs
        )

    def encode_corpus(self, sentences, batch_size=16, convert_to_numpy=False,
                      convert_to_tensor=True, show_progress_bar=True, max_seq_length=None, **kwargs):
        if max_seq_length is None:
            max_seq_length = self.inf_seq_length

        if convert_to_tensor == convert_to_numpy:
            convert_to_tensor = True
            convert_to_numpy = False

        # Corpus entries are dicts with 'title' and 'text' fields (BEIR-style).
        sentences = [sentence['title'] + ' ' + sentence['text'] for sentence in sentences]

        return self.encode(
            sentences=sentences,
            batch_size=batch_size,
            convert_to_numpy=convert_to_numpy,
            convert_to_tensor=convert_to_tensor,
            show_progress_bar=show_progress_bar,
            max_seq_length=max_seq_length,
            **kwargs
        )

    @staticmethod
    def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
        """
        Stops all processes started with start_multi_process_pool.

        Args:
            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.

        Returns:
            None
        """
        for p in pool["processes"]:
            p.terminate()

        for p in pool["processes"]:
            p.join()
            p.close()

        pool["input"].close()
        pool["output"].close()

    def stop_self_pool(self):
        if self.pool is not None:
            self.stop_multi_process_pool(self.pool)
            self.pool = None
        try:
            self.plm_model.to('cpu')
            torch.cuda.empty_cache()
        except Exception:
            pass
        gc.collect()

    def __del__(self):
        self.stop_self_pool()