File size: 2,256 Bytes
79f3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
from transformers.modeling_outputs import BaseModelOutput

# Select the autocast device type used in T5TextEncoder.forward below.
# If torch_npu is importable we are on an Ascend NPU box; transfer_to_npu is
# imported for its side effect (NOTE(review): presumably it redirects CUDA
# calls to the NPU backend — confirm against torch_npu.contrib docs).
# Otherwise default to "cuda" semantics.
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    DEVICE_TYPE = "npu"
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"


class TransformersTextEncoderBase(nn.Module):
    """
    Base class for text encoding using HuggingFace Transformers models.

    Wraps an ``AutoTokenizer``/``AutoModel`` pair and maps a batch of
    strings to padded token embeddings plus a boolean validity mask.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(
        self,
        text: list[str],
    ):
        """Tokenize ``text`` and run it through the encoder.

        Returns a dict with:
            output: ``last_hidden_state`` of the encoder
                    (assumed shape (batch, seq_len, hidden) — standard for
                    BaseModelOutput).
            mask:   bool tensor, True at non-padding token positions.
        """
        device = self.model.device
        # Pad to the longest sequence in the batch; truncate to the
        # tokenizer's declared maximum length.
        encoded = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        ids = encoded.input_ids.to(device)
        attn = encoded.attention_mask.to(device)
        encoder_out: BaseModelOutput = self.model(
            input_ids=ids, attention_mask=attn
        )
        hidden = encoder_out.last_hidden_state
        valid = attn.eq(1).to(device)
        return {"output": hidden, "mask": valid}


class T5TextEncoder(TransformersTextEncoderBase):
    """
    Frozen text encoder built on a T5 encoder-only model.

    All parameters are frozen at construction and inference always runs
    under ``torch.no_grad`` with autocast disabled, so embeddings come out
    in full precision regardless of any surrounding AMP context.
    """

    # Fix: the previous default was a machine-specific absolute cache path
    # (/mnt/petrelfs/zhengzihao/...); use the canonical hub ID so the class
    # works anywhere. Callers with a local copy can still pass its path.
    def __init__(self, model_name: str = "google/flan-t5-large"):
        # Deliberately bypass TransformersTextEncoderBase.__init__ so the
        # T5-specific tokenizer/encoder are loaded instead of AutoModel.
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        # Freeze: the encoder is used as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        self.eval()

    def forward(
        self,
        text: list[str],
    ):
        """Encode ``text``; returns the base class's {"output", "mask"} dict."""
        # no_grad + autocast(enabled=False): no autograd graph is built and
        # the T5 encoder runs in fp32 even inside an AMP region.
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            return super().forward(text)


if __name__ == '__main__':
    # Smoke test: embed a single caption and report the embedding shape.
    encoder = T5TextEncoder()
    encoder.eval()
    captions = ["dog barking and cat moving"]
    with torch.no_grad():
        result = encoder(captions)
    print(result["output"].shape)