import math
import torch
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers import PreTrainedModel
from .PreTrainedRMTConfig import PreTrainedRMTConfig
class MemoryCell(torch.nn.Module):
"""Holds memory tensors.
Replicates memory tensor for each batch size.
Adds memory tokens to the input tensor and returns that tensor.
Processes the model output and returns a new memory state.
Parameters
----------
torch : _type_
_description_
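
    Examples
    --------
    A minimal usage sketch, assuming a GPT-2 style causal LM from
    ``transformers`` (the model name and batch contents are illustrative)::

        from transformers import AutoModelForCausalLM, AutoTokenizer

        base = AutoModelForCausalLM.from_pretrained("gpt2")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        cell = MemoryCell(base, num_mem_tokens=16)

        batch = tokenizer(["an example segment"], return_tensors="pt")
        out, memory = cell(batch["input_ids"], attention_mask=batch["attention_mask"])
        # `memory` can be passed back as `memory_state` for the next segment.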
"""
def __init__(self, base_model, num_mem_tokens):
super().__init__()
self.model = base_model
self.create_memory(num_mem_tokens)
self.config = base_model.config
        # Adding token_type_embeddings (currently disabled)
#self.token_type_embeddings = torch.nn.Embedding(2, getattr(self.model.config, "n_embd", self.model.config.hidden_size))
def create_memory(self, num_mem_tokens):
"""Randomly initializes an embedding matrix (tensor) for memory tokens and registers it for gradient computation.
Sets read and write positions for memory tokens.
Parameters
----------
        num_mem_tokens : int
Number of memory tokens.
"""
self.read_memory_position = range(num_mem_tokens)
self.write_memory_position = range(-num_mem_tokens, 0)
self.num_mem_tokens = num_mem_tokens
embeddings = self.model.get_input_embeddings()
memory_dim = getattr(self.model.config, "n_embd", self.model.config.hidden_size)
memory_weights = (
torch.randn((num_mem_tokens, memory_dim))# * embeddings.weight.data.std()
)
self.register_parameter(
"memory", torch.nn.Parameter(memory_weights, requires_grad=True)
)
def set_memory(self, input_shape):
"""Replicates memory tensor for each batch size
Parameters
----------
input_shape : _type_
_description_
Returns
-------
_type_
Replicated memory tensor. (batch_size, num_mem_tokens, memory_dim)
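
        Examples
        --------
        Hypothetical shapes, assuming ``num_mem_tokens=4`` and a hidden size of 768
        (``cell`` is an instance of this class)::

            memory = cell.set_memory((2, 128))  # input_ids of shape (2, 128)
            memory.shape                        # torch.Size([2, 4, 768])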
"""
        memory = self.memory.repeat(
            input_shape[0], 1, 1
        )  # replicate the memory tensor across the batch
        return memory  # (batch_size, num_mem_tokens, memory_dim)
def forward(self, input_ids, memory_state=None, **kwargs):
"""Performs inference.
Parameters
----------
input_ids : torch.Tensor
Input tensor.
memory_state : torch.Tensor, optional
Memory tensor, by default None (num_mem_tokens, memory_dim)
Returns
-------
tuple(tuple, torch.Tensor)
out : tuple
Model output.
new_memory_state : torch.Tensor
New memory state.
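
        Examples
        --------
        Sketch of processing two segments recurrently (``segment_1_ids`` and
        ``segment_2_ids`` are illustrative tensors of token ids)::

            out1, memory = cell(segment_1_ids)                       # memory initialized internally
            out2, memory = cell(segment_2_ids, memory_state=memory)  # memory carried over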
"""
        if memory_state is None:
            # replicate the memory tensor for each element of the batch
            memory_state = self.set_memory(input_ids.shape)
        # add memory tokens to the input embeddings and build the model kwargs
        seg_kwargs = self.process_input(input_ids, memory_state, **kwargs)
        out = self.model(**seg_kwargs)
        # process the model output and extract the new memory state
        out, new_memory_state = self.process_output(out, **kwargs)
        return out, new_memory_state
def process_input(self, input_ids, memory_state, **kwargs):
"""Adds memory tokens to the input tensor and returns that tensor
Parameters
----------
input_ids : _type_
Input tensor.
memory_state : _type_
Memory tensor.
Returns
-------
_type_
Input tensor with added memory tokens. (batch_size, seq_len, hidden_size)
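
        Notes
        -----
        The resulting embedding layout is (shapes are illustrative, assuming
        ``num_mem_tokens=4`` and ``seq_len=128``)::

            [ read memory | segment | write memory ]
               4 tokens      128       4 tokens      -> inputs_embeds (batch, 136, hidden)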
"""
seg_kwargs = dict(**kwargs)
inputs_embeds = kwargs.get("inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if inputs_embeds.shape[0] != memory_state.shape[0]:  # batch sizes differ
            memory_state = self.set_memory(inputs_embeds.shape)
        # prepend and append memory tokens to the input embeddings
        inputs_embeds = torch.cat(
            [memory_state, inputs_embeds, memory_state], dim=1
        ).to(input_ids.device)
"""
# token_type_idsの生成
token_type_ids = torch.zeros_like(inputs_embeds[:, :, 0], dtype=torch.long)
token_type_ids[:, self.num_mem_tokens:-self.num_mem_tokens] = 1
# token_type_embeddingsの追加と入力の更新
token_type_embeds = self.token_type_embeddings(token_type_ids)
inputs_embeds = inputs_embeds + token_type_embeds
"""
seg_kwargs["input_ids"] = None
seg_kwargs["inputs_embeds"] = inputs_embeds
if kwargs.get("attention_mask") is not None:
seg_kwargs["attention_mask"] = self.pad_attention_mask(
kwargs["attention_mask"], inputs_embeds.shape
)
seg_kwargs["output_hidden_states"] = True
        # Explicit position ids: read memory gets 0..num_mem-1, write memory gets
        # num_mem..2*num_mem-1, and the segment tokens start at 2*num_mem, so the
        # segment positions do not depend on where the memory tokens are placed.
        pos_mem1 = torch.arange(self.num_mem_tokens, device=input_ids.device)
        pos_mem2 = torch.arange(self.num_mem_tokens, self.num_mem_tokens * 2, device=input_ids.device)
        pos_seg = torch.arange(self.num_mem_tokens * 2, self.num_mem_tokens * 2 + input_ids.shape[1], device=input_ids.device)
        pos = torch.cat([pos_mem1, pos_seg, pos_mem2], dim=0)
        pos = pos.unsqueeze(0).expand(input_ids.shape[0], -1)
        seg_kwargs["position_ids"] = pos
return seg_kwargs
    def pad_attention_mask(self, attention_mask, shape):
        """Pads the attention mask with ones at the memory positions on both sides."""
        if self.num_mem_tokens in {0, None}:
            return attention_mask
else:
attention_mask = torch.cat(
[
torch.ones(
shape[0], self.num_mem_tokens, device=attention_mask.device
),
attention_mask,
torch.ones(
shape[0], self.num_mem_tokens, device=attention_mask.device
),
],
dim=1,
)
return attention_mask
    @staticmethod
    def compute_logpi(mean, stddev, action):
        """Computes the element-wise Gaussian log-density log N(action | mean, stddev)."""
        a1 = -0.5 * math.log(2 * math.pi)
        a2 = -torch.log(stddev)
        a3 = -0.5 * (((action - mean) / stddev) ** 2)
        return a1 + a2 + a3
    def process_output(self, model_outputs, **kwargs):
        """Strips the memory positions from the model output and extracts the new memory state."""
        if self.num_mem_tokens not in {0, None}:
            out = CausalLMOutputWithCrossAttentions()
            # the new memory state is read from the write positions of the last hidden layer
            memory_state = model_outputs.hidden_states[-1][:, -self.num_mem_tokens :]
out["logits"] = model_outputs.logits[
:, self.num_mem_tokens : -self.num_mem_tokens
]
if kwargs.get("output_hidden_states"):
out["hidden_states"] = [
lh[:, self.num_mem_tokens : -self.num_mem_tokens]
for lh in model_outputs.hidden_states
]
if kwargs.get("output_attentions"):
out["attentions"] = model_outputs["attentions"]
else:
memory_state = None
out = model_outputs
return out, memory_state
    def generate(self, input_ids, memory_state, attention_mask, **generate_kwargs):
        """Generates from the base model with memory tokens attached to the prompt."""
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)
seg_kwargs = self.process_input(input_ids, memory_state, attention_mask=attention_mask)
        out = self.model.generate(
            inputs_embeds=seg_kwargs["inputs_embeds"],
            attention_mask=seg_kwargs["attention_mask"],
            **generate_kwargs,
        )
        return out