import math
import torch
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers import PreTrainedModel

from .PreTrainedRMTConfig import PreTrainedRMTConfig

class MemoryCell(torch.nn.Module):
    """Holds memory tensors.

    Replicates memory tensor for each batch size.

    Adds memory tokens to the input tensor and returns that tensor.

    Processes the model output and returns a new memory state.



    Parameters

    ----------

    torch : _type_

        _description_

    """

    def __init__(self, base_model, num_mem_tokens):
        super().__init__()
        self.model = base_model
        self.create_memory(num_mem_tokens)
        self.config = base_model.config
        
        # Optional token_type_embeddings (currently disabled):
        #self.token_type_embeddings = torch.nn.Embedding(2, getattr(self.model.config, "n_embd", self.model.config.hidden_size))

    def create_memory(self, num_mem_tokens):
        """Randomly initializes an embedding matrix (tensor) for memory tokens and registers it for gradient computation.

           Sets read and write positions for memory tokens.



        Parameters

        ----------

        num_mem_tokens : _type_

            Number of memory tokens.

        """
        self.read_memory_position = range(num_mem_tokens)
        self.write_memory_position = range(-num_mem_tokens, 0)
        
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        memory_dim = getattr(self.model.config, "n_embd", self.model.config.hidden_size)
        # Random initialization; scaling by the input-embedding std is disabled.
        memory_weights = (
            torch.randn((num_mem_tokens, memory_dim))  # * embeddings.weight.data.std()
        )
        
        self.register_parameter(
            "memory", torch.nn.Parameter(memory_weights, requires_grad=True)
        )

    def set_memory(self, input_shape):
        """Replicates memory tensor for each batch size



        Parameters

        ----------

        input_shape : _type_

            _description_



        Returns

        -------

        _type_

            Replicated memory tensor. (batch_size, num_mem_tokens, memory_dim)

        """
        memory = self.memory.repeat(
            input_shape[0], 1, 1
        )  #  メモリテンソルをバッチサイズ分だけ複製する
        return memory  # (batch_size, num_mem_tokens, memory_dim)

    def forward(self, input_ids, memory_state=None, **kwargs):
        """Performs inference.



        Parameters

        ----------

        input_ids : torch.Tensor

            Input tensor.

        memory_state : torch.Tensor, optional

            Memory tensor, by default None (num_mem_tokens, memory_dim)



        Returns

        -------

        tuple(tuple, torch.Tensor)

            out : tuple

                Model output.

            new_memory_state : torch.Tensor

                New memory state.

        """
        if memory_state is None:
            # Replicate the memory tensor across the batch dimension.
            memory_state = self.set_memory(input_ids.shape)

        # Attach the memory tokens to the input embeddings.
        seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
        out = self.model(**seg_kwargs)

        # Process the model output and extract the new memory state.
        out, new_memory_state = self.process_output(out, **kwargs)

        return out, new_memory_state

    def process_input(self, input_ids, memory_state, **kwargs):
        """Adds memory tokens to the input tensor and returns that tensor



        Parameters

        ----------

        input_ids : _type_

            Input tensor.

        memory_state : _type_

            Memory tensor.



        Returns

        -------

        _type_

            Input tensor with added memory tokens. (batch_size, seq_len, hidden_size)

        """
        seg_kwargs = dict(**kwargs)

        inputs_embeds = kwargs.get("inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if inputs_embeds.shape[0] != memory_state.shape[0]:  # batch size changed
            memory_state = self.set_memory(inputs_embeds.shape)

        # Attach the memory tokens: [read memory | segment | write memory].
        inputs_embeds = torch.cat(
            [memory_state, inputs_embeds, memory_state], dim=1
        ).to(input_ids.device)
        """

        # token_type_idsの生成

        token_type_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)

        token_type_ids[:, self.num_mem_tokens:-self.num_mem_tokens] = 1



        # token_type_embeddingsの追加と入力の更新

        token_type_embeds = self.token_type_embeddings(token_type_ids)

        inputs_embeds = inputs_embeds + token_type_embeds

        """

        seg_kwargs["input_ids"] = None
        seg_kwargs["inputs_embeds"] = inputs_embeds
        if kwargs.get("attention_mask") is not None:
            seg_kwargs["attention_mask"] = self.pad_attention_mask(
                kwargs["attention_mask"], inputs_embeds.shape
            )
        seg_kwargs["output_hidden_states"] = True
        
        # Positional embeddings: the read memory occupies positions
        # [0, num_mem_tokens), the write memory the next block, and the segment
        # tokens follow both, even though the write memory is physically
        # appended at the end of the sequence.
        pos_mem1 = torch.arange(self.num_mem_tokens, device=input_ids.device)
        pos_mem2 = torch.arange(self.num_mem_tokens, self.num_mem_tokens * 2, device=input_ids.device)
        pos_seg = torch.arange(self.num_mem_tokens * 2, self.num_mem_tokens * 2 + input_ids.shape[1], device=input_ids.device)
        pos = torch.cat([pos_mem1, pos_seg, pos_mem2], dim=0)
        pos = pos.unsqueeze(0).expand(input_ids.shape[0], -1)
        seg_kwargs["position_ids"] = pos
        
        return seg_kwargs

    def pad_attention_mask(self, attention_mask, shape):
        """Extends the attention mask so the memory tokens on both sides are attended to."""
        if self.num_mem_tokens in {0, None}:
            return attention_mask
        else:
            mem_mask = torch.ones(
                shape[0],
                self.num_mem_tokens,
                device=attention_mask.device,
                dtype=attention_mask.dtype,
            )
            attention_mask = torch.cat([mem_mask, attention_mask, mem_mask], dim=1)
            return attention_mask
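
    # Illustration (hypothetical numbers): with num_mem_tokens=2 and an original
    # attention_mask of shape (batch, 3), the padded mask has shape (batch, 7):
    # [1, 1, m0, m1, m2, 1, 1], matching [read memory | segment | write memory].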

    @staticmethod
    def compute_logpi(mean, stddev, action):
        """Log-density of ``action`` under a diagonal Gaussian N(mean, stddev**2)."""
        a1 = -0.5 * math.log(2 * math.pi)
        a2 = -torch.log(stddev)
        a3 = -0.5 * (((action - mean) / stddev) ** 2)
        return a1 + a2 + a3
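
    # Sanity check: at the mean of a standard normal,
    # compute_logpi(torch.zeros(1), torch.ones(1), torch.zeros(1))
    # ≈ -0.5 * log(2π) ≈ -0.9189 per element.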

    def process_output(self, model_outputs, **kwargs):
        """Strips the memory positions from the output and extracts the new memory state."""
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            # The updated memory is read from the last hidden layer at the write positions.
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens :]
            out["logits"] = model_outputs.logits[
                :, self.num_mem_tokens : -self.num_mem_tokens
            ]

            if kwargs.get("output_hidden_states"):
                out["hidden_states"] = [
                    lh[:, self.num_mem_tokens : -self.num_mem_tokens]
                    for lh in model_outputs.hidden_states
                ]
            if kwargs.get("output_attentions"):
                out["attentions"] = model_outputs["attentions"]
        else:
            memory_state = None
            out = model_outputs

        return out, memory_state

    def generate(self, input_ids, memory_state, attention_mask, **generate_kwargs):
        """Generates from the base model with the memory tokens attached to the prompt."""
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask)
        out = self.model.generate(inputs_embeds=seg_kwargs['inputs_embeds'], attention_mask=seg_kwargs['attention_mask'], **generate_kwargs)
        return out
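

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): wrap a
    # Hugging Face causal LM in a MemoryCell and process a long input segment
    # by segment, carrying memory_state across segments. The "gpt2" checkpoint,
    # segment length, and num_mem_tokens are arbitrary assumptions; run as
    # `python -m <package>.<this_module>` so the relative import above resolves.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    cell = MemoryCell(base_model, num_mem_tokens=4)

    text = "Recurrent memory lets a fixed-window model read long inputs. " * 20
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    attention_mask = torch.ones_like(input_ids)

    segment_len = 64
    memory_state = None
    with torch.no_grad():
        for start in range(0, input_ids.shape[1], segment_len):
            seg_ids = input_ids[:, start : start + segment_len]
            seg_mask = attention_mask[:, start : start + segment_len]
            out, memory_state = cell(
                seg_ids, memory_state=memory_state, attention_mask=seg_mask
            )
    print(out.logits.shape)  # (1, last_segment_len, vocab_size)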