File size: 3,183 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from torch import nn
from torch.cuda.amp import autocast

from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertConfig,
    BertModel,
    RobertaConfig,
    RobertaModel,
)


class Bert(nn.Module):
    """Frozen BERT text encoder mapping a list of strings to token features.

    The wrapped ``BertModel`` is loaded without its pooling layer, switched to
    eval mode, has ``requires_grad`` disabled on every parameter, and has its
    parameters cast to ``dtype``. Results of ``forward_text`` can optionally be
    memoised per input list.

    NOTE(review): ``nn.Module.train()`` recurses into children, so calling
    ``.train()`` on a parent module puts ``bert_model`` back into training mode
    (dropout active) even though it is frozen. Confirm callers re-call
    ``.eval()`` if deterministic features are required.

    Args:
        pretrained_model_name_or_path: Name or local path handed to
            ``BertConfig``/``BertModel``/``AutoTokenizer.from_pretrained``.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``,
            ``"float16"``) that the model parameters are cast to.
        max_length: Keyword-only. Every text is truncated/padded to exactly
            this many tokens (default 256).
        max_batch_size: Keyword-only. Largest number of texts sent through the
            model in one forward pass; larger inputs are chunked (default 500).
        **kwargs: Forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        dtype="float32",
        *,
        max_length=256,
        max_batch_size=500,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        self.dtype = getattr(torch, dtype)
        self.max_length = max_length
        self.max_batch_size = max_batch_size

        self.config = BertConfig.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path
        )
        self.bert_model = BertModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            add_pooling_layer=False,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path
        )

        # Frozen encoder: no gradients, eval mode, parameters in the requested
        # dtype. Only parameters are cast; buffers keep their original dtype.
        self.bert_model.eval()
        for param in self.bert_model.parameters():
            param.requires_grad_(False)
            param.data = param.data.to(self.dtype)

        # Non-persistent dummy buffer: it follows ``.to(device)`` calls on this
        # module, which lets ``device`` report where the module lives.
        self.register_buffer("unused_tensor", torch.zeros(1), False)

        # Memo of forward_text results keyed by tuple(text_list); unbounded.
        self.text_list_to_feature = {}

    @property
    def device(self):
        """Device this module currently lives on."""
        return self.unused_tensor.device

    @autocast(enabled=False)
    @torch.no_grad()
    def forward_text(self, text_list, cache=False):
        """Tokenize and encode ``text_list`` with the frozen BERT model.

        Runs without gradients and with CUDA autocast disabled. Inputs larger
        than ``max_batch_size`` are encoded in chunks and concatenated.

        Args:
            text_list: Sequence of strings to encode.
            cache: If true, return / store the result in
                ``self.text_list_to_feature`` keyed by ``tuple(text_list)``.
                The cached dict is returned by reference; do not mutate it.

        Returns:
            dict with keys:
              - ``"end_token_idx"``: (bs,) index of the last non-padding token
                of each row (the trailing special token for a standard BERT
                tokenizer; assumes right-side padding).
              - ``"attention_mask"``: (bs, max_length) tokenizer attention mask.
              - ``"last_hidden_state"``: (bs, max_length, hidden_size)
                encoder output.
        """
        if cache and tuple(text_list) in self.text_list_to_feature:
            return self.text_list_to_feature[tuple(text_list)]

        tokenized = self.tokenizer.batch_encode_plus(
            text_list,
            max_length=self.max_length,
            padding="max_length",
            return_special_tokens_mask=True,
            return_tensors="pt",
            truncation=True,
        ).to(self.device)

        input_ids = tokenized.input_ids  # (bs, seq_len)
        attention_mask = tokenized.attention_mask  # (bs, seq_len)

        # Tensor.split yields ceil(bs / max_batch_size) non-empty chunks, so a
        # batch that is an exact multiple of max_batch_size never triggers a
        # forward pass on an empty slice.
        hidden_chunks = [
            self.bert_model(
                input_ids=ids_chunk,
                attention_mask=mask_chunk,
            ).last_hidden_state
            for ids_chunk, mask_chunk in zip(
                input_ids.split(self.max_batch_size),
                attention_mask.split(self.max_batch_size),
            )
        ]
        if len(hidden_chunks) == 1:
            last_hidden_state = hidden_chunks[0]
        else:
            last_hidden_state = torch.cat(hidden_chunks, dim=0)

        # Last real token per row = number of unmasked tokens - 1. Unlike an
        # argmin over token ids, this does not depend on the pad id being the
        # smallest id present, and it stays valid for rows that fill max_length.
        end_token_idx = attention_mask.sum(dim=-1) - 1

        ret = {
            "end_token_idx": end_token_idx,
            "attention_mask": attention_mask,
            "last_hidden_state": last_hidden_state,
        }

        if cache:
            self.text_list_to_feature[tuple(text_list)] = ret

        return ret