import math

import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from .MemoryCell import MemoryCell
from .PreTrainedRMTConfig import PreTrainedRMTConfig


class RecurrentWrapper(torch.nn.Module):
    """Recurrent Memory Transformer wrapper.

    Splits a long input into fixed-length segments, feeds them one by one
    through a MemoryCell, and carries the resulting memory state from one
    segment to the next.
    """

    def __init__(
        self,
        memory_cell: MemoryCell,
        is_memory_all: bool,
        max_n_segments: int,
        input_seg_len: int,
        output_seg_len: int,
        align: str = "left",
    ):
        super().__init__()
        self.memory_cell: MemoryCell = memory_cell
        # If True, the memory state persists (detached) across forward calls;
        # if False, it is reset at the start of every call.
        self.is_memory_all = is_memory_all
        self.memory_state: torch.Tensor = None
        self.config = memory_cell.config
        # Number of trailing segments that keep gradients (-1 or None: all).
        self.max_n_segments = max_n_segments
        self.input_seg_len = input_seg_len
        self.output_seg_len = output_seg_len
        self.align = align

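    # Example (illustrative only; MemoryCell construction arguments are
    # hypothetical and depend on the wrapped model):
    #
    #   cell = MemoryCell(...)
    #   wrapper = RecurrentWrapper(
    #       cell, is_memory_all=False, max_n_segments=4,
    #       input_seg_len=512, output_seg_len=512,
    #   )
    #   out = wrapper(input_ids, labels=input_ids)  # input_ids: (batch, seq_len)
    #   out["loss"].backward()
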
    def forward(
        self,
        input_ids,
        labels=None,
        labels_mask=None,
        inputs_embeds=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        """Runs the memory cell over all segments and aggregates the outputs.

        Parameters
        ----------
        input_ids : torch.Tensor
            Input tensor. (batch_size, seq_len * n_segments)
        labels : torch.Tensor, optional
            Target tensor. (batch_size, seq_len * n_segments)

        Returns
        -------
        dict
            "loss" : torch.Tensor
                Loss value.
            "logits" : torch.Tensor
                Model output.
            "{key}_{seg_num}" : torch.Tensor
                Output for each segment.
        """
        if self.memory_state is not None:
            if self.is_memory_all is False:
                self.memory_state = None
            else:
                # detach() is not in-place; reassign to actually cut the graph.
                self.memory_state = self.memory_state.detach()

        segmented = self.segment(
            self.input_seg_len,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        cell_outputs = []
        for seg_num, segment in enumerate(segmented):
            cell_out, self.memory_state = self.memory_cell(
                **segment, memory_state=self.memory_state, **kwargs
            )
            cell_outputs.append(cell_out)
            self.manage_gradients(self.memory_state, seg_num, len(segmented))

        out = self.process_outputs(
            cell_outputs,
            labels=labels,
            labels_mask=labels_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return out

    def log(self, t, eps=1e-20):
        # Numerically stable log: clamping avoids log(0).
        return torch.log(t.clamp(min=eps))

    def gumbel_noise(self, t):
        noise = torch.zeros_like(t).uniform_(0, 1)
        return -self.log(-self.log(noise))

    def gumbel_sample(self, t, temperature=1.0, dim=-1):
        # Gumbel-max trick: argmax over Gumbel-perturbed logits samples from
        # softmax(t / temperature).
        return ((t / max(float(temperature), 1e-10)) + self.gumbel_noise(t)).argmax(dim=dim)

    def top_k(self, logits, thres=0.9):
        # Keep the top (1 - thres) fraction of the vocabulary; mask the rest with -inf.
        k = math.ceil((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float("-inf"))
        probs.scatter_(1, ind, val)
        return probs

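    # Example (illustrative): for logits t of shape (batch, vocab),
    #   self.gumbel_sample(t, temperature=0.7)
    # returns one sampled token id per row, equivalent in distribution to
    #   torch.multinomial(torch.softmax(t / 0.7, dim=-1), 1).squeeze(-1)
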
    def segment(self, seg_len, **kwargs):
        """Segments input tensors and adjusts their size. Returns a list of dicts.

        Parameters
        ----------
        seg_len : int
            Length of each segment.
        **kwargs : dict
            Tensors to be segmented, passed in keyword-argument form.
            Example: segment(seg_len, input_ids=tensor1, attention_mask=tensor2)

        Returns
        -------
        segments : list of dict
            List of dictionaries containing the segmented tensors.
            Example: [{'input_ids': seg1_ids, 'attention_mask': seg1_mask},
                      {'input_ids': seg2_ids, 'attention_mask': seg2_mask}, ...]

        Notes
        -----
        - Each tensor is split by `self.split_tensor`, so all tensors are cut
          at the same indices and keys share the same segment order.
        - Entries whose value is None are skipped.
        """
        segments = []
        for k, tensor in kwargs.items():
            if tensor is not None:
                k_segments = self.split_tensor(tensor, seg_len)
                for s, k_seg in enumerate(k_segments):
                    if s < len(segments):
                        segments[s][k] = k_seg
                    else:
                        segments.append({k: k_seg})

        return segments

    def split_tensor(self, tensor, seg_len):
        if self.align in {"left", None}:
            # Cut from the left; the last segment may be shorter.
            split_inds = list(range(0, tensor.shape[1], seg_len)) + [tensor.shape[1]]
            segments = [
                tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
            ]
        elif self.align == "right":
            # Cut from the right; the first segment may be shorter.
            split_inds = (list(range(tensor.shape[1], 0, -seg_len)) + [0])[::-1]
            segments = [
                tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
            ]
        elif self.align == "center":
            # Split into ceil(seq_len / seg_len) chunks with torch.chunk.
            n_seg = math.ceil(tensor.shape[1] / seg_len)
            segments = torch.chunk(tensor, n_seg, dim=1)
        else:
            # Unknown alignment: fall back to left alignment.
            split_inds = list(range(0, tensor.shape[1], seg_len)) + [tensor.shape[1]]
            segments = [
                tensor[:, start:end] for (start, end) in zip(split_inds, split_inds[1:])
            ]
        return segments

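    # Example (illustrative): for a (batch, 10) tensor with seg_len=4, the
    # per-segment widths are
    #   align="left"   -> 4, 4, 2   (short remainder at the end)
    #   align="right"  -> 2, 4, 4   (short remainder at the front)
    #   align="center" -> 4, 4, 2   (torch.chunk with ceil(10 / 4) = 3 chunks)
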
    def process_outputs(self, cell_outputs, **kwargs):
        """Computes the loss over all segment outputs and concatenates the logits.

        Parameters
        ----------
        cell_outputs : list of torch.Tensor
            List containing the output of each segment.

        Returns
        -------
        dict
            "loss" : torch.Tensor
                Loss value.
            "logits" : torch.Tensor
                Model output.
            "{key}_{seg_num}" : torch.Tensor
                Output for each segment.
        """
        out = CausalLMOutputWithCrossAttentions()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        if kwargs.get("output_hidden_states"):
            # Concatenate the hidden states of each layer along the sequence axis.
            full_hidden_states = tuple(
                torch.cat(layer_hs, dim=1)
                for layer_hs in zip(*[o.hidden_states for o in cell_outputs])
            )

        labels = kwargs.get("labels")
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = full_logits[..., :-1, :].contiguous()

            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            labels_mask = kwargs.get("labels_mask")
            if labels_mask is not None:
                # Boolean mask selecting the positions that contribute to the loss.
                shift_mask = labels_mask[..., :-1].contiguous()

                flat_labels = flat_labels[shift_mask.view(-1)]
                flat_logits = flat_logits[shift_mask.view(-1)]
            out["loss"] = loss_fct(flat_logits, flat_labels)
        else:
            out["loss"] = 0
            print("labels is None")

        out["logits"] = full_logits
        segment_keys = ["loss", "logits"]
        if kwargs.get("output_attentions"):
            segment_keys.append("attentions")
        if kwargs.get("output_hidden_states"):
            segment_keys.append("hidden_states")
            out["hidden_states"] = full_hidden_states

        for seg_num, o in enumerate(cell_outputs):
            for key, value in o.items():
                if any(sk in key for sk in segment_keys):
                    out[f"{key}_{seg_num}"] = value

        return out

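    # Example (illustrative): with labels of shape (batch, L) and a boolean
    # labels_mask, only masked-in positions enter CrossEntropyLoss; the
    # returned dict also exposes per-segment entries such as "logits_0",
    # "logits_1", ... alongside the concatenated "logits".
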
    def manage_gradients(self, memory_state, seg_num, seg_len):
        """Controls gradient flow through the memory state.

        Only the first segment and the last `max_n_segments` segments keep the
        memory state attached to the graph; earlier segments are detached to
        bound backpropagation through time.

        Parameters
        ----------
        memory_state : torch.Tensor
            Memory state. (batch_size, num_mem_tokens, memory_dim)
        seg_num : int
            Index of the segment currently being processed.
        seg_len : int
            Total number of segments.

        Returns
        -------
        bool
            Whether gradients are kept. True: memory state stays attached,
            False: memory state is detached.
        """
        if (
            seg_num == 0
            or self.max_n_segments in {-1, None}
            or seg_len - seg_num <= self.max_n_segments
        ):
            self.memory_state = memory_state
            return True
        else:
            self.memory_state = memory_state.detach()
            return False

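    # Example (illustrative): with 5 segments and max_n_segments=2, the memory
    # state is detached after segments 1 and 2 but kept attached for segment 0
    # and the last two segments (5 - seg_num <= 2 for seg_num in {3, 4}).
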
    def generate_groq(
        self,
        input_ids,
        max_length=25,
        temperature=1.0,
        top_k=None,
        top_p=None,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        **kwargs
    ):
        """
        Generate new tokens based on the input sequence.

        Parameters
        ----------
        input_ids : torch.Tensor
            Initial input sequence. Shape: (batch_size, seq_len)
        max_length : int
            Maximum number of tokens to generate (including the initial sequence length).
        temperature : float, default 1.0
            Temperature parameter for sampling. Lower values are more deterministic.
        top_k : int, optional
            Sample only from the top k tokens.
        top_p : float, optional
            Filter tokens by cumulative probability p (nucleus sampling).
        do_sample : bool, default True
            If True, use probabilistic sampling. If False, use greedy decoding.
        pad_token_id : int, optional
            ID of the padding token.
        eos_token_id : int, optional
            ID of the end-of-sequence token.
        **kwargs : dict
            Additional arguments passed to MemoryCell.

        Returns
        -------
        torch.Tensor
            Generated token sequence. Shape: (batch_size, generated_seq_len)
        """
        # Prime the memory state on the full prompt, segment by segment.
        segmented = self.segment(self.input_seg_len, input_ids=input_ids)
        memory_state = None
        for segment in segmented:
            cell_out, memory_state = self.memory_cell(
                **segment, memory_state=memory_state, **kwargs
            )

        output_ids = input_ids
        while output_ids.shape[1] < max_length:
            # Feed only the last token; the earlier context lives in memory.
            last_token = output_ids[:, -1:]

            cell_out, memory_state = self.memory_cell(
                input_ids=last_token, memory_state=memory_state, **kwargs
            )

            logits = cell_out.logits[:, -1, :]

            next_token = self.sample_next_token(
                logits, temperature, top_k, top_p, do_sample
            )

            output_ids = torch.cat([output_ids, next_token], dim=1)

            # Note: .item() assumes batch_size == 1 for early stopping.
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break

        return output_ids

    def sample_next_token(self, logits, temperature=1, top_k=50, top_p=0.9, do_sample=False):
        """
        Samples the next token from logits.

        Parameters
        ----------
        logits : torch.Tensor
            Prediction scores for the tokens. Shape: (batch_size, vocab_size)
        temperature : float
            Temperature parameter for sampling.
        top_k : int, optional
            Sample only from the top k tokens.
        top_p : float, optional
            Filter tokens by cumulative probability p.
        do_sample : bool
            If True, use probabilistic sampling. If False, use greedy decoding.

        Returns
        -------
        torch.Tensor
            Sampled token. Shape: (batch_size, 1)
        """
        if do_sample:
            if temperature != 1.0:
                logits = logits / temperature
            if top_k is not None:
                logits = self.top_k_groq(logits, top_k)
            if top_p is not None:
                logits = self.top_p(logits, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        return next_token

    def top_k_groq(self, logits, k):
        """
        Filters logits so that only the top k tokens are considered.

        Parameters
        ----------
        logits : torch.Tensor
            Prediction scores for the tokens. Shape: (batch_size, vocab_size)
        k : int
            Number of top tokens to keep.

        Returns
        -------
        torch.Tensor
            Filtered logits. Shape: (batch_size, vocab_size)
        """
        values, indices = torch.topk(logits, k, dim=-1)
        # Everything below the k-th largest logit is masked out with -inf.
        min_values = values[:, -1].unsqueeze(-1).expand_as(logits)
        return torch.where(
            logits >= min_values, logits, torch.full_like(logits, float("-inf"))
        )

    def top_p(self, logits, p):
        """
        Filters tokens based on cumulative probability p (nucleus sampling).

        Parameters
        ----------
        logits : torch.Tensor
            Prediction scores for the tokens. Shape: (batch_size, vocab_size)
        p : float
            Cumulative probability threshold.

        Returns
        -------
        torch.Tensor
            Filtered logits. Shape: (batch_size, vocab_size)
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > p
        # Shift right so the token that crosses the threshold is kept as well.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        return logits.masked_fill(indices_to_remove, float("-inf"))

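    # Example (illustrative): with p=0.9 and row probabilities
    # [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0],
    # so the first three tokens survive (the shift keeps the token that
    # crosses the threshold) and the last is masked to -inf.
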
    def generate_default(self, input_ids, attention_mask=None, **generate_kwargs):
        # Run all but the last segment to build up the memory state, then let
        # the memory cell's own generate() handle the final segment.
        memory_state = None
        segmented = self.segment(
            self.input_seg_len, input_ids=input_ids, attention_mask=attention_mask
        )

        for seg_num, segment in enumerate(segmented[:-1]):
            cell_out, memory_state = self.memory_cell(**segment, memory_state=memory_state)

        final_segment = segmented[-1]
        out = self.memory_cell.generate(
            **final_segment, memory_state=memory_state, **generate_kwargs
        )

        return out

    def generate(self, input_ids: torch.Tensor, **generate_kwargs):
        with torch.no_grad():
            if self.is_memory_all is False:
                self.memory_state = None
            elif self.memory_state is not None:
                # detach() is not in-place; reassign to actually cut the graph.
                self.memory_state = self.memory_state.detach()

            # Run all but the last segment to build up the memory state.
            segmented = self.segment(self.input_seg_len, input_ids=input_ids)

            for seg_num, segment in enumerate(segmented[:-1]):
                cell_out, self.memory_state = self.memory_cell(
                    **segment, memory_state=self.memory_state, output_hidden_states=True
                )

            curr_segment = segmented[-1]

            # Disabled alternative: segment-wise generation via memory_cell.generate.
            """
            outs = []
            for i in range(math.ceil(generate_kwargs["max_length"] / self.input_seg_len)):
                out = self.memory_cell.generate(
                    **curr_segment,
                    memory_state=self.memory_state,
                    max_length=min(generate_kwargs["max_length"] - i * self.input_seg_len, self.input_seg_len - curr_segment["input_ids"].shape[-1]),
                    **generate_kwargs)
                outs.append(out)

            for out in outs:
                for key, value in out.items():
                    curr_segment[key] = torch.cat((curr_segment[key], value), dim=-1)
                self.memory_state = out["memory_state"]
            """

            output_ids = None
            if generate_kwargs.get("max_length") is None:
                length = generate_kwargs.get("max_new_tokens", 25)
            else:
                length = generate_kwargs.get("max_length") - curr_segment["input_ids"].shape[-1]

            for ind in range(length):
                out, next_memories = self.memory_cell(
                    **curr_segment, memory_state=self.memory_state, output_hidden_states=True
                )
                logits = out["logits"][:, -1]
                sampled = self.sample_next_token(
                    logits,
                    temperature=generate_kwargs.get("temperature", 1),
                    top_k=generate_kwargs.get("top_k", 50),
                    top_p=generate_kwargs.get("top_p", 0.9),
                    do_sample=generate_kwargs.get("do_sample", False),
                )

                output_ids = sampled if output_ids is None else torch.cat((output_ids, sampled), dim=1)

                curr_segment["input_ids"] = torch.cat((curr_segment["input_ids"], sampled), dim=-1)

                # When the working segment overflows, fold it into memory and
                # keep only the latest token as the next segment.
                if curr_segment["input_ids"].shape[-1] > self.input_seg_len:
                    for key, value in curr_segment.items():
                        curr_segment[key] = value[:, -1:]
                    self.memory_state = next_memories

            return output_ids

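    # Example (illustrative):
    #   ids = wrapper.generate(input_ids, max_new_tokens=32, do_sample=True,
    #                          temperature=0.8, top_k=50, top_p=0.9)
    #   # ids contains only the newly generated tokens, shape (batch, n_new).
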
    def generate_with_tokenizer(self, tokenizer, input_text, **generate_kwargs):
        if isinstance(input_text, str):
            tok = tokenizer(input_text, return_tensors="pt")
        else:
            # Batch of strings: left-pad so every sequence ends at the same position.
            tok = tokenizer(input_text)
            for k, v in tok.items():
                pd = tokenizer.pad_token_id if k != "attention_mask" else 0
                tok[k] = pad_sequence(
                    [torch.tensor(o) for o in v], padding_value=pd, padding_side="left"
                ).T

        output_ids = self.generate(tok["input_ids"], **generate_kwargs)

        if isinstance(input_text, str):
            return tokenizer.decode(
                torch.cat((tok["input_ids"][0], output_ids[0]), dim=0),
                skip_special_tokens=True,
            )
        else:
            return tokenizer.batch_decode(
                torch.cat((tok["input_ids"], output_ids), dim=-1),
                skip_special_tokens=True,
            )

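    # Example (illustrative, assuming a Hugging Face tokenizer):
    #   text = wrapper.generate_with_tokenizer(tokenizer, "Once upon a time",
    #                                          max_new_tokens=32)
    #   # Returns the decoded prompt plus its continuation as a single string.
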
    def can_generate(self):
        # Signals to Hugging Face utilities that this wrapper supports generation.
        return True