|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
|
# Detect the accelerator backend: prefer Ascend NPU when torch_npu is
# installed (importing transfer_to_npu patches torch's CUDA entry points
# onto the NPU), otherwise fall back to CUDA.
try:
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
|
|
|
|
|
|
|
|
class TransformersTextEncoderBase(nn.Module):
    """Base class for text encoding using HuggingFace Transformers models.

    Wraps a tokenizer/model pair loaded via ``from_pretrained`` and exposes
    a single :meth:`forward` that maps a batch of strings to token-level
    hidden states plus a padding mask.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, text: list[str]):
        """Tokenize and encode a batch of strings.

        Args:
            text: Batch of input strings.

        Returns:
            Dict with:
                ``"output"``: last hidden states of the encoder,
                    shape (batch, seq_len, hidden_dim).
                ``"mask"``: boolean tensor, True at non-padding positions,
                    shape (batch, seq_len).
        """
        device = self.model.device
        # Pad to the longest sequence in the batch and truncate to the
        # tokenizer's configured maximum length.
        batch = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        encoded: BaseModelOutput = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        hidden_states = encoded.last_hidden_state
        padding_mask = (attention_mask == 1).to(device)
        return {"output": hidden_states, "mask": padding_mask}
|
|
|
|
|
|
|
|
class T5TextEncoder(TransformersTextEncoderBase):
    """Frozen text encoder using a T5 encoder model.

    Loads a T5 tokenizer and encoder stack, freezes all parameters, and
    runs inference under ``torch.no_grad`` with autocast disabled so the
    encoder always computes in its native precision.
    """

    def __init__(self, model_name: str = "google/flan-t5-large"):
        # FIX: the previous default was a machine-specific local cache path
        # (/mnt/petrelfs/...); use the canonical hub id of the same model so
        # the default works on any machine. Callers may still pass a local
        # path explicitly.
        #
        # Deliberately skip TransformersTextEncoderBase.__init__: it would
        # load via AutoTokenizer/AutoModel, while T5 needs the dedicated
        # T5Tokenizer/T5EncoderModel classes (encoder-only stack).
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        # Freeze the encoder: it is used as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.eval()

    def forward(self, text: list[str]):
        """Encode ``text``; see ``TransformersTextEncoderBase.forward``.

        Gradients are disabled (the encoder is frozen) and autocast is
        explicitly turned off so T5 runs in full precision regardless of
        any enclosing AMP context.
        """
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            return super().forward(text)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: encode a single caption and print the hidden-state shape.
    encoder = T5TextEncoder()
    captions = ["dog barking and cat moving"]
    encoder.eval()
    with torch.no_grad():
        result = encoder(captions)
    print(result["output"].shape)
|
|
|