|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn |
|
|
import torch |
|
|
import numpy as np |
|
|
import math |
|
|
|
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
|
|
|
|
|
|
def gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """
    Generate a [1, length, channels] timing signal consisting of sinusoids.

    The first ``channels // 2`` channels hold ``sin(position * inv_timescale)``,
    the next ``channels // 2`` hold the matching cosines, and one zero channel
    is appended when ``channels`` is odd.

    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py

    Args:
        length: number of positions (rows) in the signal.
        channels: size of the feature dimension.
        min_timescale: smallest timescale of the geometric sequence.
        max_timescale: largest timescale of the geometric sequence.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, length, channels)``.
    """
    position = np.arange(length)
    num_timescales = channels // 2
    # With a single timescale (channels in {2, 3}) the denominator would be 0;
    # clamp it to 1 exactly like tensor2tensor does.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale))
        / max(float(num_timescales) - 1, 1)
    )
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(float) * -log_timescale_increment
    )
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # Pad one zero channel when `channels` is odd.
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)
|
|
|
|
|
class ACT_basic(nn.Module):
    """
    Adaptive Computation Time (ACT) loop with a linear halting unit.

    At each step a per-position halting probability ``p`` is predicted from
    the current ``state`` as ``sigmoid(Linear(state))``. Positions keep
    applying ``fn`` until their accumulated probability exceeds
    ``threshold`` or they have been updated ``max_hop`` times. The returned
    state is a running blend of the intermediate states weighted by the
    per-step update weights.
    """

    def __init__(self, hidden_size):
        super(ACT_basic, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        # Bias of 1 gives an initial halting probability of sigmoid(1) ~= 0.73.
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1
        # Std of the Gaussian noise added to `p` when noisy halting is enabled.
        self.eps = 0.1

    def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """
        Run the ACT loop.

        Args:
            *args: extra positional arguments forwarded to ``fn``.
            state: initial state fed to the halting unit and to ``fn``.
            inputs: tensor whose first two dims define the per-position
                bookkeeping shape and whose device/dtype seed the buffers.
            fn: step function; called as ``fn(state, *args, **kwargs)`` or,
                when ``encoder_output`` is given, ``fn((state, encoder_output))``.
                If it returns a tuple, the first element is the new state and
                the rest is passed back to the caller.
            time_enc, pos_enc: accepted for interface compatibility; unused.
            max_hop: maximum number of updates per position.
            encoder_output: optional extra input forwarded to ``fn``.
            **kwargs: forwarded to ``fn``; the key ``noisy_halting`` is
                consumed here and enables Gaussian noise on ``p`` in training.

        Returns:
            ``(previous_state, (remainders, n_updates))``, or
            ``((previous_state, *rest), (remainders, n_updates))`` when ``fn``
            returns a tuple.
        """
        noisy_halting = kwargs.pop('noisy_halting', False)

        batch_size, seq_len = inputs.shape[0], inputs.shape[1]
        # Allocate bookkeeping on the same device as the inputs instead of
        # hard-coding CUDA.
        device = inputs.device
        halting_probability = torch.zeros(batch_size, seq_len, device=device)
        remainders = torch.zeros(batch_size, seq_len, device=device)
        n_updates = torch.zeros(batch_size, seq_len, device=device)
        previous_state = torch.zeros_like(inputs)
        rest = None

        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            p = self.sigma(self.p(state)).squeeze(-1)
            if noisy_halting and self.training:
                p = p + torch.randn_like(p) * self.eps

            still_running = (halting_probability < 1.0).float()
            # Positions whose cumulative probability crosses the threshold now.
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            # Positions that stay below the threshold after this step.
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            halting_probability = halting_probability + p * still_running
            # Halting positions contribute the leftover probability mass.
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted

            update_weights = p * still_running + new_halted * remainders

            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            weights = update_weights.unsqueeze(-1)
            previous_state = (state * weights) + (previous_state * (1 - weights))

        if rest is None:
            return previous_state, (remainders, n_updates)
        return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
|
|
|
|
class ACT_constant_depth():
    """
    Drop-in replacement for ``ACT_basic`` that always applies ``fn`` exactly
    ``max_hop`` times (no halting unit).

    The returned ``remainders`` and ``n_updates`` are all-zero tensors of
    shape ``inputs.shape[:2]``, kept only so the return signature matches
    ``ACT_basic``.
    """

    def __init__(self):
        super(ACT_constant_depth, self).__init__()

    def __call__(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """
        Apply ``fn`` to ``state`` ``max_hop`` times.

        Args:
            *args, **kwargs: forwarded to ``fn`` when ``encoder_output`` is None.
            state: initial state.
            inputs: tensor whose first two dims / device shape the returned
                bookkeeping tensors.
            fn: step function (see ``ACT_basic.forward``).
            time_enc, pos_enc: accepted for interface compatibility; unused.
            max_hop: number of steps.
            encoder_output: optional extra input; when given, ``fn`` is called
                as ``fn((state, encoder_output))``.

        Returns:
            ``(state, (remainders, n_updates))`` or ``((state, *rest), ...)``
            when ``fn`` returns a tuple.
        """
        device = inputs.device
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        n_updates = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        while step < max_hop:
            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]
            previous_state = state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
|
class ACTForWholeARMT(nn.Module):
    """
    ACT loop with a linear halting unit that alternates between two step
    functions: ``fn_no_update`` while any position will keep running, and
    ``fn_update`` on the final step (the one after which the loop stops).
    """

    def __init__(self, hidden_size):
        super(ACTForWholeARMT, self).__init__()
        self.sigma = nn.Sigmoid()
        self.p = nn.Linear(hidden_size, 1)
        # Bias of 1 gives an initial halting probability of sigmoid(1) ~= 0.73.
        self.p.bias.data.fill_(1)
        self.threshold = 1 - 0.1

    def forward(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """
        Run the ACT loop.

        Args:
            *args, **kwargs: forwarded to the step functions when
                ``encoder_output`` is None.
            state: initial state.
            inputs: tensor whose first two dims / device shape the
                per-position bookkeeping.
            fn_no_update: step function used on every non-final step.
            fn_update: step function used on the final step.
            time_enc, pos_enc: accepted for interface compatibility; unused.
            max_hop: maximum number of updates per position.
            encoder_output: optional extra input; when given the step function
                is called as ``fn((state, encoder_output))``.

        Returns:
            ``(previous_state, (remainders, n_updates))`` or
            ``((previous_state, *rest), (remainders, n_updates))``.
        """
        batch_size, seq_len = inputs.shape[0], inputs.shape[1]
        device = inputs.device
        halting_probability = torch.zeros(batch_size, seq_len, device=device)
        remainders = torch.zeros(batch_size, seq_len, device=device)
        n_updates = torch.zeros(batch_size, seq_len, device=device)
        previous_state = torch.zeros_like(inputs)
        rest = None

        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            p = self.sigma(self.p(state)).squeeze(-1)

            still_running = (halting_probability < 1.0).float()
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running

            halting_probability = halting_probability + p * still_running
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted

            update_weights = p * still_running + new_halted * remainders

            # Same predicate as the loop condition: if another step will run,
            # use the non-updating function; otherwise this is the last step.
            more_steps = ((halting_probability < self.threshold) & (n_updates < max_hop)).any()
            step_fn = fn_no_update if more_steps else fn_update

            if encoder_output is not None:
                state, _ = step_fn((state, encoder_output))
            else:
                state = step_fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            weights = update_weights.unsqueeze(-1)
            previous_state = (state * weights) + (previous_state * (1 - weights))

        if rest is None:
            return previous_state, (remainders, n_updates)
        return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
|
class ACTForWholeARMT_constant_depth():
    """
    Constant-depth counterpart of ``ACTForWholeARMT``: runs exactly
    ``max_hop`` steps, using ``fn_no_update`` for the first ``max_hop - 1``
    steps and ``fn_update`` for the final one.
    """

    def __init__(self):
        super(ACTForWholeARMT_constant_depth, self).__init__()

    def __call__(self, *args, state, inputs, fn_no_update, fn_update, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """
        Apply the step functions ``max_hop`` times.

        Returns:
            ``(state, (remainders, n_updates))`` or ``((state, *rest), ...)``.
            ``remainders`` is all zeros and ``n_updates`` is a float tensor
            filled with ``max_hop``; both have shape ``inputs.shape[:2]``.
        """
        device = inputs.device
        remainders = torch.zeros(inputs.shape[0], inputs.shape[1], device=device)
        # Float dtype so `.mean()` works, matching ACTForWholeARMT's n_updates.
        n_updates = torch.full((inputs.shape[0], inputs.shape[1]), max_hop,
                               dtype=torch.float, device=device)
        previous_state = torch.zeros_like(inputs)
        step = 0
        rest = None

        while step < max_hop:
            # The previous check `step < max_hop` was always true inside this
            # loop, so `fn_update` never ran; the last step must use it.
            step_fn = fn_no_update if step < max_hop - 1 else fn_update
            if encoder_output is not None:
                state, _ = step_fn((state, encoder_output))
            else:
                state = step_fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]
            previous_state = state
            step += 1

        if rest is None:
            return previous_state, (remainders, n_updates)
        return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
|
|
|
|
class ACT_transformer(nn.Module):
    """
    ACT loop whose halting probability comes from a small causal
    Transformer encoder followed by a linear projection to one logit.
    """

    def __init__(self, hidden_size, num_heads=4, num_transformer_layers=1, dropout=0.1):
        super(ACT_transformer, self).__init__()

        transformer_layer = TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=dropout,
            norm_first=True
        )
        self.transformer = TransformerEncoder(transformer_layer,
                                              num_layers=num_transformer_layers)

        self.logit_ff = nn.Linear(hidden_size, 1)
        # Bias of 1 gives an initial halting probability of about sigmoid(1).
        self.logit_ff.bias.data.fill_(1)

        self.sigma = nn.Sigmoid()
        self.threshold = 1 - 0.1

    def generate_causal_mask(self, seq_len, device=None):
        """
        Return a ``(seq_len, seq_len)`` float mask with ``-inf`` strictly
        above the diagonal and ``0`` elsewhere.

        Args:
            seq_len: sequence length.
            device: optional device for the mask (defaults to torch's default).
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def forward(self, *args, state, inputs, fn, time_enc, pos_enc, max_hop, encoder_output=None, **kwargs):
        """
        Run the ACT loop (same contract as ``ACT_basic.forward`` without the
        noisy-halting option).

        ``inputs`` must be 3-D (batch, seq, hidden); the halting transformer
        is applied to ``state`` permuted to sequence-first layout.
        """
        batch_size, seq_len, _ = inputs.shape
        device = inputs.device
        halting_probability = torch.zeros(batch_size, seq_len, device=device)
        remainders = torch.zeros(batch_size, seq_len, device=device)
        n_updates = torch.zeros(batch_size, seq_len, device=device)
        previous_state = torch.zeros_like(inputs)
        rest = None

        causal_mask = self.generate_causal_mask(seq_len, device=device)

        while ((halting_probability < self.threshold) & (n_updates < max_hop)).any():
            # TransformerEncoder defaults to (seq, batch, hidden) layout.
            state_transformed = self.transformer(
                state.permute(1, 0, 2),
                mask=causal_mask
            )
            state_transformed = state_transformed.permute(1, 0, 2)

            p = self.sigma(self.logit_ff(state_transformed)).squeeze(-1)

            still_running = (halting_probability < 1.0).float()
            new_halted = (halting_probability + p * still_running > self.threshold).float() * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).float() * still_running
            halting_probability = halting_probability + p * still_running
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted
            update_weights = p * still_running + new_halted * remainders

            if encoder_output is not None:
                state, _ = fn((state, encoder_output))
            else:
                state = fn(state, *args, **kwargs)
                if isinstance(state, tuple):
                    rest = state[1:]
                    state = state[0]

            weights = update_weights.unsqueeze(-1)
            previous_state = (state * weights) + (previous_state * (1 - weights))

        if rest is None:
            return previous_state, (remainders, n_updates)
        return (previous_state, *rest), (remainders, n_updates)
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
from torch.nn import CrossEntropyLoss |
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
from torch.nn.functional import relu as r |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
|
|
|
@dataclass
class ARMTOutput(ModelOutput):
    """
    Custom output format for ARMT with all necessary fields.
    This replaces Munch in the original implementation.

    ``state`` holds the recurrent state of RWKV base models. The RWKV branch
    of ``AssociativeMemoryCell.forward_with_update`` constructs an
    ``ARMTOutput`` from every item of the model output, including ``state``,
    so the field must exist for that constructor call to succeed.
    """
    logits: torch.FloatTensor = None
    loss: torch.FloatTensor = None
    hidden_states: torch.FloatTensor = None
    attentions: tuple = None
    past_key_values: tuple = None
    # ACT statistics (remainders / number of updates).
    remainders: torch.FloatTensor = None
    n_updates: torch.FloatTensor = None
    ce_loss: torch.FloatTensor = None
    # Recurrent state (RWKV models); appended last to keep field order stable.
    state: tuple = None
|
|
|
|
|
|
|
|
# Optional fused linear + cross-entropy kernel; the flag records availability.
try:
    from cut_cross_entropy import linear_cross_entropy
except ImportError:
    CUT_CROSS_ENTROPY_AVAILABLE = False
    print("Warning: cut_cross_entropy not available, falling back to standard CrossEntropyLoss")
else:
    CUT_CROSS_ENTROPY_AVAILABLE = True

# Optional RWKV baseline; `RWKV_imported` guards every use of `RWKVModel`.
try:
    from baselines.rwkv.language_modeling import RWKVModel
except ImportError:
    print("*** Can't import RWKV model ***")
    RWKV_imported = False
else:
    RWKV_imported = True
|
|
def dpfp(x, nu=1):
    """
    Deterministic Parameter-Free Projection feature map.

    Concatenates ``relu(x)`` and ``relu(-x)`` along the last dim, then
    multiplies ``nu`` tiled copies of that vector element-wise with the
    concatenation of its rolls by ``1..nu`` positions. The last dimension
    grows from ``d`` to ``2 * nu * d``.
    """
    doubled = torch.cat((r(x), r(-x)), dim=-1)
    rolled_parts = []
    for shift in range(1, nu + 1):
        rolled_parts.append(doubled.roll(shifts=shift, dims=-1))
    rolled = torch.cat(rolled_parts, dim=-1)
    tiled = torch.cat([doubled for _ in range(nu)], dim=-1)
    return tiled * rolled
|
|
|
|
|
class DPFP:
    """
    Callable DPFP feature map with a fixed ``nu``.

    ``DPFP(nu)(x)`` expands the last dimension of ``x`` from ``d`` to
    ``2 * nu * d``: it forms ``[relu(x), relu(-x)]`` and multiplies ``nu``
    tiled copies of it element-wise with its rolls by ``1..nu`` positions.
    """

    def __init__(self, nu):
        self.nu = nu

    def __call__(self, x):
        features = torch.cat([r(x), r(-x)], dim=-1)
        rolled = torch.cat(
            [features.roll(shifts=shift, dims=-1) for shift in range(1, self.nu + 1)],
            dim=-1,
        )
        # Tile along the last dim only (same layout as concatenating copies).
        tile_sizes = [1] * (features.dim() - 1) + [self.nu]
        repeated = features.repeat(*tile_sizes)
        return repeated * rolled
|
|
def attn_mask_to_4d(attn_mask, upper, query_len):
    """
    Expand a 2-D padding mask into a 4-D triangular attention mask.

    Args:
        attn_mask: ``(batch, key_len)`` mask, or ``None``.
        upper: use an upper-triangular (``triu``) pattern instead of a
            lower-triangular (``tril``) one.
        query_len: number of query rows.

    Returns:
        ``None`` if ``attn_mask`` is ``None``; otherwise a
        ``(batch, 1, query_len, key_len)`` tensor equal to the triangle
        multiplied by each row of ``attn_mask``.
    """
    if attn_mask is None:
        return None

    key_len = attn_mask.size(-1)
    make_triangle = torch.triu if upper else torch.tril
    triangle = make_triangle(
        torch.ones(query_len, key_len, dtype=attn_mask.dtype, device=attn_mask.device)
    )
    return torch.einsum('bj,ij->bij', attn_mask, triangle)[:, None]
|
|
|
|
|
def invert_attn_mask(attn_mask, dtype):
    """
    Turn a 1/0 keep-mask into an additive mask.

    Kept positions (1) map to 0 and masked positions (0) map to the most
    negative finite value of ``dtype``. If the ``NOT_INVERT_ATTN_MASK``
    environment variable is set to a non-empty value, the mask is returned
    unchanged.
    """
    skip_inversion = os.environ.get("NOT_INVERT_ATTN_MASK")
    if skip_inversion:
        return attn_mask

    most_negative = torch.finfo(dtype).min
    keep_value = torch.tensor(1.0, dtype=attn_mask.dtype, device=attn_mask.device)
    return (keep_value - attn_mask) * most_negative
|
|
|
|
|
|
|
|
|
|
|
class AssociativeLayerWrapper(torch.nn.Module):
    """
    Wrap a layer with a DPFP-keyed associative memory matrix ``W_mem``.

    Reading: unless this is the first segment, ``associate`` reads from
    ``W_mem`` with queries derived from the hidden states, and the read is
    added to the hidden states before the wrapped layer runs.

    Writing: after the wrapped layer runs (and unless ``generate_mode`` is
    on), the last ``num_mem_tokens`` output positions are written into
    ``W_mem`` by ``update_mem`` (and into the normaliser ``z`` when
    ``use_denom`` is set).
    """

    def __init__(self, layer, d_model, num_mem_tokens, d_mem, n_heads=1, correction=True, info=None, use_denom=True, gating=False) -> None:
        super().__init__()
        self.info = info
        self.seg_num = 0
        self.d_model = d_model
        self.num_mem_tokens = num_mem_tokens
        self.d_mem = d_mem
        self.n_heads = n_heads
        self.gating = gating
        nu = 3
        # DPFP(nu) maps each d_mem feature to 2 * nu features.
        self.d_key = 2 * nu * d_mem

        assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0

        self.phi = DPFP(nu)
        self.use_denom = use_denom

        # Create projections in the wrapped layer's parameter dtype.
        layer_dtype = next(layer.parameters()).dtype

        self.W_mq = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mk = torch.nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mv = torch.nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype)
        if gating:
            self.W_mb = torch.nn.Linear(d_model, d_model, dtype=layer_dtype)
        else:
            self.W_mb = torch.nn.Linear(d_model, n_heads, dtype=layer_dtype)
        # Zero-initialised value projection: initial memory writes are zero.
        torch.nn.init.zeros_(self.W_mv.weight)

        self.layer = layer

        self.generate_mode = False
        self.first_seg = True
        self.correction = correction

        self.zero_mem()

    def _to_heads(self, x):
        """(bsz, seq_len, d) -> (bsz, n_heads, seq_len, d // n_heads)."""
        bsz, seq_len, d_model = x.shape
        x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads)
        x = x.permute(0, 2, 1, 3)
        return x

    def _from_heads(self, x):
        """(bsz, n_heads, seq_len, d_head) -> (bsz, seq_len, n_heads * d_head)."""
        bsz, n_heads, seq_len, d_head = x.shape
        x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head)
        return x

    def associate(self, hidden_states):
        """Read from ``W_mem`` using DPFP features of projected queries."""
        bsz, seq_len, d_model = hidden_states.shape

        self.W_mem = self.W_mem.to(hidden_states.device)
        if self.use_denom:
            self.z = self.z.to(hidden_states.device)

        q = self._to_heads(self.W_mq(hidden_states))
        mq = self.phi(q)
        mq = F.normalize(mq, dim=-1, p=2.0)

        num = torch.einsum('ihjk,ihkt->ihjt', mq, self.W_mem)
        if self.use_denom:
            denom = torch.einsum("ihk,ihjk->ihj", self.z, mq)[..., None] + 1e-5
            hidden_states = num / denom
        else:
            hidden_states = num
        hidden_states = self._from_heads(hidden_states)
        return hidden_states

    def forward(self, hidden_states, *args, **kwargs):
        """Read memory (if not first segment), run the layer, then write memory."""
        if not self.first_seg:
            hidden_states = self.associate(
                hidden_states
            ) + hidden_states
        out = self.layer(hidden_states, *args, **kwargs)
        if not self.generate_mode:
            # The wrapped layer may return a tuple whose first item is the
            # hidden states, or the hidden states directly.
            if isinstance(out, tuple):
                mem_tokens = out[0][:, -self.num_mem_tokens:]
            else:
                mem_tokens = out[:, -self.num_mem_tokens:]

            self.update_mem(mem_tokens)
        return out

    def forward_no_update(self, hidden_states, *args, **kwargs):
        """Like ``forward`` but never writes to memory."""
        if not self.first_seg:
            hidden_states = self.associate(
                hidden_states
            ) + hidden_states
        out = self.layer(hidden_states, *args, **kwargs)
        return out

    def update_mem(self, mem_tokens):
        """Write ``mem_tokens`` into ``W_mem`` (and ``z``) with a delta-rule style update."""
        self.W_mem = self.W_mem.to(mem_tokens.device)
        if self.use_denom:
            self.z = self.z.to(mem_tokens.device)
        k = self._to_heads(self.W_mk(mem_tokens))
        mk = self.phi(k)
        mk = F.normalize(mk, dim=-1, p=2.0)

        new_mv = self._to_heads(self.W_mv(mem_tokens))
        if not self.first_seg:
            # Value currently stored for these keys; only the difference is written.
            num = torch.einsum('ihjk,ihkt->ihjt', mk, self.W_mem)
            if self.use_denom:
                denom = torch.einsum("ihj,ihkj->ihk", self.z, mk)[..., None] + 1e-5
                prev_mv = num / denom
                if self.correction:
                    new_info_coef = (1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None])
                    new_info_coef = torch.clip(new_info_coef, 0, 1).detach()
                else:
                    new_info_coef = 1
            else:
                prev_mv = num
        else:
            prev_mv = torch.zeros_like(new_mv, device=new_mv.device)
            new_info_coef = 1

        mv = new_mv - prev_mv

        mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens)))

        # With gating, mb has a value per feature ('t'); otherwise one per head ('x').
        einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt"
        associations = torch.einsum(einop, mk, mv, mb)

        self.W_mem = self.W_mem + associations

        if self.use_denom:
            self.z = self.z + (new_info_coef*mk).sum(dim=-2)

        self.seg_num += 1
        self.first_seg = False

    def freeze_mem(self):
        """Disable gradients for all memory projection weights."""
        self.W_mb.weight.requires_grad = False
        self.W_mb.bias.requires_grad = False
        self.W_mq.weight.requires_grad = False
        self.W_mk.weight.requires_grad = False
        self.W_mv.weight.requires_grad = False

    def zero_mem(self):
        """Reset ``W_mem`` (and ``z``) to zeros and mark the next segment as first."""
        self.first_seg = True

        layer_dtype = next(self.layer.parameters()).dtype
        self.W_mem = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, self.d_model // self.n_heads, dtype=layer_dtype)
        self.W_mem.requires_grad_(False)
        if self.use_denom:
            self.z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, dtype=layer_dtype)
            self.z.requires_grad_(False)
        self.seg_num = 0

    def detach_mem(self):
        """Detach ``W_mem`` (and ``z``) from the autograd graph."""
        self.W_mem = self.W_mem.detach()
        if self.use_denom:
            self.z = self.z.detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper):
    """
    Associative layer wrapper whose memory read is repeated adaptively.

    ``associate`` runs the parent's memory read inside an ACT loop
    (``self.act``) for at most ``max_hop`` steps. After each call, the mean
    ACT remainder and mean update count are added to ``self.remainders`` and
    ``self.n_updates``, and ``self.segments_passed`` is incremented.

    NOTE(review): this module defines ``AdaptiveAssociativeLayerWrapper``
    twice; the later definition is the one bound at import time.
    """

    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 constant_depth=False,
                 act_format='linear',
                 noisy_halting=False,
                 ) -> None:
        """
        Args:
            max_hop: maximum number of ACT steps for the memory read.
            constant_depth: always run ``max_hop`` steps (``ACT_constant_depth``).
            act_format: ``'linear'`` (``ACT_basic``) or ``'transformer'``
                (``ACT_transformer``) halting unit.
            noisy_halting: add halting noise during training (``ACT_basic`` only).
            Remaining arguments are forwarded to ``AssociativeLayerWrapper``.

        Raises:
            NotImplementedError: unsupported ``act_format``.
        """
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
        if act_format == 'transformer':
            self.act = ACT_transformer(d_model)
        elif constant_depth:
            self.act = ACT_constant_depth()
        elif act_format == 'linear':
            self.act = ACT_basic(d_model)
        else:
            raise NotImplementedError(f'Unknown act_format: {act_format}')
        self.depth = max_hop
        self.max_length = 1024
        self.noisy_halting = noisy_halting

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def associate(self, hidden_states):
        """Run the parent's memory read inside the ACT loop and record statistics."""
        device = hidden_states.device
        self.remainders = self.remainders.to(device)
        self.n_updates = self.n_updates.to(device)
        self.segments_passed = self.segments_passed.to(device)

        act_kwargs = {}
        # Only ACT_basic consumes `noisy_halting`; the other ACT variants
        # would forward it to the memory read and fail.
        if self.noisy_halting and isinstance(self.act, ACT_basic):
            act_kwargs['noisy_halting'] = True

        out, (remainders, n_updates) = self.act(
            state=hidden_states,
            inputs=hidden_states,
            fn=super().associate,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth,
            **act_kwargs,
        )

        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        """Reset ACT statistics and the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()

    def detach_mem(self):
        """Reset ACT statistics and detach the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().detach_mem()
|
|
|
|
|
|
|
|
|
|
|
class AdaptiveAssociativeLayerWrapper2(AssociativeLayerWrapper):
    """
    Associative layer wrapper that repeats the *whole* wrapped layer
    (memory read + layer, without memory writes) inside an ACT loop, then
    writes the final memory-token outputs into memory once.

    After each forward, the mean ACT remainder and mean update count are
    added to ``self.remainders`` / ``self.n_updates`` and
    ``self.segments_passed`` is incremented.
    """

    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 act_format='linear',
                 noisy_halting=False,
                 constant_depth=False,
                 ) -> None:
        """
        Args:
            max_hop: maximum number of ACT steps.
            act_format: ``'linear'`` (``ACT_basic``) or ``'transformer'``
                (``ACT_transformer``).
            noisy_halting: add halting noise during training (``ACT_basic`` only).
            constant_depth: use ``ACT_constant_depth`` (ignored when
                ``act_format == 'transformer'``).
            Remaining arguments are forwarded to ``AssociativeLayerWrapper``.

        Raises:
            NotImplementedError: unsupported ``act_format``.
        """
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)

        if act_format == 'transformer':
            self.act = ACT_transformer(d_model)
        elif constant_depth:
            self.act = ACT_constant_depth()
        elif act_format == 'linear':
            self.act = ACT_basic(d_model)
        else:
            # Previously `NotImplemetedError` (typo), which raised NameError.
            raise NotImplementedError(f'Unknown act_format: {act_format}')

        self.depth = max_hop
        self.max_length = 1024
        self.noisy_halting = noisy_halting

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def forward(self, hidden_states, *args, **kwargs):
        """Run the layer under ACT, then write the memory tokens into memory."""
        self.remainders = self.remainders.to(hidden_states.device)
        self.n_updates = self.n_updates.to(hidden_states.device)
        self.segments_passed = self.segments_passed.to(hidden_states.device)

        # Only ACT_basic pops `noisy_halting`; other ACT variants would pass
        # it through to the wrapped layer and crash.
        if self.noisy_halting and isinstance(self.act, ACT_basic):
            kwargs['noisy_halting'] = self.noisy_halting
        fwd = super().forward_no_update
        out, (remainders, n_updates) = self.act(
            *args,
            state=hidden_states,
            inputs=hidden_states,
            fn=fwd,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth,
            **kwargs
        )
        if not self.generate_mode:
            # The ACT returns a tuple only when the layer does; for a plain
            # tensor, out[0] would be the first batch item, not the states.
            last_hidden = out[0] if isinstance(out, tuple) else out
            mem_tokens = last_hidden[:, -self.num_mem_tokens:]

            self.update_mem(mem_tokens)
            self.first_seg = False
        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        """Reset ACT statistics and the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()

    def detach_mem(self):
        """Reset ACT statistics and detach the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().detach_mem()
|
|
|
|
|
|
|
|
class AdaptiveAssociativeLayerWrapper(AssociativeLayerWrapper):
    """
    Associative layer wrapper whose memory read is repeated adaptively.

    ``associate`` runs the parent's memory read inside an ACT loop
    (``self.act``) for at most ``max_hop`` steps. After each call, the mean
    ACT remainder and mean update count are added to ``self.remainders`` and
    ``self.n_updates``, and ``self.segments_passed`` is incremented.

    NOTE(review): an earlier class with this same name exists in this
    module; this later definition is the one bound at import time.
    """

    def __init__(self,
                 layer,
                 d_model,
                 num_mem_tokens,
                 d_mem,
                 max_hop,
                 n_heads=1,
                 correction=True,
                 info=None,
                 use_denom=True,
                 gating=False,
                 constant_depth=False,
                 act_format='linear',
                 noisy_halting=False,
                 ) -> None:
        """
        Args:
            max_hop: maximum number of ACT steps for the memory read.
            constant_depth: always run ``max_hop`` steps (``ACT_constant_depth``).
            act_format: ``'linear'`` (``ACT_basic``) or ``'transformer'``
                (``ACT_transformer``) halting unit.
            noisy_halting: add halting noise during training (``ACT_basic`` only).
            Remaining arguments are forwarded to ``AssociativeLayerWrapper``.

        Raises:
            NotImplementedError: unsupported ``act_format``.
        """
        super().__init__(layer, d_model, num_mem_tokens, d_mem, n_heads, correction, info, use_denom, gating)
        if act_format == 'transformer':
            self.act = ACT_transformer(d_model)
        elif constant_depth:
            self.act = ACT_constant_depth()
        elif act_format == 'linear':
            self.act = ACT_basic(d_model)
        else:
            raise NotImplementedError(f'Unknown act_format: {act_format}')
        self.depth = max_hop
        self.max_length = 1024
        self.noisy_halting = noisy_halting

        self.timing_signal = gen_timing_signal(self.max_length, d_model)
        self.position_signal = gen_timing_signal(self.depth, d_model)

        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)

    def associate(self, hidden_states):
        """Run the parent's memory read inside the ACT loop and record statistics."""
        device = hidden_states.device
        self.remainders = self.remainders.to(device)
        self.n_updates = self.n_updates.to(device)
        self.segments_passed = self.segments_passed.to(device)

        act_kwargs = {}
        # Only ACT_basic consumes `noisy_halting`; the other ACT variants
        # would forward it to the memory read and fail.
        if self.noisy_halting and isinstance(self.act, ACT_basic):
            act_kwargs['noisy_halting'] = True

        out, (remainders, n_updates) = self.act(
            state=hidden_states,
            inputs=hidden_states,
            fn=super().associate,
            time_enc=self.timing_signal,
            pos_enc=self.position_signal,
            max_hop=self.depth,
            **act_kwargs,
        )

        # Accumulate scalar means (raw (batch, seq) tensors would not
        # broadcast across segments of different lengths).
        self.remainders = self.remainders + remainders.mean()
        self.n_updates = self.n_updates + n_updates.mean()
        self.segments_passed = self.segments_passed + 1
        return out

    def zero_mem(self):
        """Reset ACT statistics and the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().zero_mem()

    def detach_mem(self):
        """Reset ACT statistics and detach the associative memory."""
        self.remainders = torch.zeros(1,)
        self.n_updates = torch.zeros(1,)
        self.segments_passed = torch.zeros(1,)
        return super().detach_mem()
|
|
|
|
|
|
|
|
|
|
|
class AssociativeMemoryCell(torch.nn.Module): |
|
|
def __init__(self,
             base_model,
             num_mem_tokens,
             d_mem,
             layers_attr: str = 'model.layers',
             wrap_pos=False,
             correction=True,
             n_heads=1,
             use_denom=True,
             gating=False,
             freeze_mem=False,
             act_on=False,
             max_hop=4,
             act_type='layer',
             act_format='linear',
             noisy_halting=False,
             constant_depth=False,
             attend_to_previous_input=False,
             use_sink=False,
             **rmt_config
             ):
    """
    Wrap every layer of ``base_model`` with an associative memory.

    Args:
        base_model: model whose layer list (reached via ``layers_attr``) is
            replaced in place by memory-wrapped layers.
        num_mem_tokens: number of learnable memory embeddings.
        d_mem: memory key dimension passed to each layer wrapper.
        layers_attr: dot-separated attribute path from ``base_model`` to its
            layer list.
        wrap_pos: extend ``model.transformer.wpe`` by ``num_mem_tokens`` rows.
        correction, n_heads, use_denom, gating: forwarded to each wrapper.
        freeze_mem: call ``freeze_mem()`` on every wrapped layer.
        act_on: enable adaptive computation time.
        max_hop: maximum number of ACT steps.
        act_type: ``'layer'``, ``'associative'`` or ``'model'``.
        act_format: halting-unit format; ``'model'`` supports only ``'linear'``.
        noisy_halting: forwarded to the layer wrappers when ``act_on`` is set.
        constant_depth: run exactly ``max_hop`` steps instead of halting.
        attend_to_previous_input: prepend the previous segment's tokens.
        use_sink: create a learnable sink embedding.
        **rmt_config: accepted and ignored here.

    Raises:
        NotImplementedError: ``act_type == 'model'`` with a non-linear format.
        ValueError: unknown ``act_type`` while ``act_on`` is set.
    """
    super().__init__()
    self.model = base_model

    self.attend_to_previous_input = attend_to_previous_input
    self.previous_input = None
    self.use_sink = use_sink

    self.RWKV_ARMT = isinstance(self.model, RWKVModel) if RWKV_imported else False

    self.num_mem_tokens = num_mem_tokens
    self.d_mem = d_mem
    self.d_model = base_model.get_input_embeddings().embedding_dim
    self.W_mem = []

    self.constant_depth = constant_depth

    self.layers_attrs = layers_attr.split('.')

    def _get_layers_from_model(model_root):
        # Follow the dotted attribute path (e.g. 'model.layers').
        layers_obj = model_root
        for attr in self.layers_attrs:
            layers_obj = getattr(layers_obj, attr)
        return layers_obj

    layers = _get_layers_from_model(self.model)

    for i in range(len(layers)):
        kw = dict(
            layer=layers[i],
            d_model=self.d_model,
            num_mem_tokens=self.num_mem_tokens,
            d_mem=self.d_mem,
            correction=correction,
            info={'layer': i},
            n_heads=n_heads,
            use_denom=use_denom,
            gating=gating,
        )
        if act_on and act_type == 'model' and act_format != 'linear':
            raise NotImplementedError
        if act_on and (act_type != 'model'):
            kw['max_hop'] = max_hop
            kw['constant_depth'] = self.constant_depth
            kw['act_format'] = act_format
        if act_on and noisy_halting:
            kw['noisy_halting'] = noisy_halting
        if not act_on:
            layers[i] = AssociativeLayerWrapper(**kw)
        elif act_type == 'associative':
            layers[i] = AdaptiveAssociativeLayerWrapper(**kw)
        elif act_type == 'layer':
            layers[i] = AdaptiveAssociativeLayerWrapper2(**kw)
        elif act_type == 'model':
            layers[i] = AssociativeLayerWrapper(**kw)
        else:
            # Raising a plain string is itself a TypeError; raise a real exception.
            raise ValueError(f'Unknown ACT type: {act_type}')

    if act_type == 'model':
        self.act = ACTForWholeARMT(self.d_model) if not self.constant_depth else ACTForWholeARMT_constant_depth()
        self.depth = max_hop
        self.max_length = 1024
        self.timing_signal = gen_timing_signal(self.max_length, self.d_model)
        self.position_signal = gen_timing_signal(self.depth, self.d_model)
    # `forward` reads this for every act_type, so always set it.
    self.act_type = act_type

    self.create_memory(num_mem_tokens)
    self.wrap_pos = wrap_pos
    self.act_on = act_on
    if wrap_pos:
        self.wrap_positional_embeddings(num_mem_tokens)

    if freeze_mem:
        for layer in _get_layers_from_model(self.model):
            layer.freeze_mem()

    self.get_layers = lambda: _get_layers_from_model(self.model)
|
|
|
|
|
def generate_mode(self, is_on): |
|
|
for layer in self.get_layers(): |
|
|
layer.generate_mode = is_on |
|
|
|
|
|
def create_memory(self, num_mem_tokens): |
|
|
self.num_mem_tokens = num_mem_tokens |
|
|
embeddings = self.model.get_input_embeddings() |
|
|
memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size) |
|
|
memory_weights = torch.randn((num_mem_tokens, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype) * embeddings.weight.data.std() |
|
|
|
|
|
self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True)) |
|
|
if self.use_sink: |
|
|
self.sink = torch.nn.Parameter(torch.randn((1, memory_dim), device=embeddings.weight.data.device, dtype=embeddings.weight.data.dtype), requires_grad=True) |
|
|
|
|
|
|
|
|
def wrap_positional_embeddings(self, num_mem_tokens): |
|
|
num_pos_embs, emb_dim = self.model.transformer.wpe.weight.shape |
|
|
prev_embs = self.model.transformer.wpe.weight.detach() |
|
|
self.model.transformer.wpe = torch.nn.Embedding(num_mem_tokens + num_pos_embs, emb_dim) |
|
|
|
|
|
new_num_pos = num_pos_embs + num_mem_tokens |
|
|
with torch.no_grad(): |
|
|
self.model.transformer.wpe.weight[:len(self.model.transformer.wpe.weight)-num_mem_tokens] = prev_embs |
|
|
for layer in self.model.transformer.h: |
|
|
layer.layer.attn.bias = torch.tril(torch.ones((new_num_pos, new_num_pos), dtype=torch.uint8)).view( |
|
|
1, 1, new_num_pos, new_num_pos |
|
|
) |
|
|
|
|
|
def set_memory(self, input_shape): |
|
|
memory = self.memory.repeat(input_shape[0], 1, 1) |
|
|
if self.use_sink: |
|
|
sink = self.sink.repeat(input_shape[0], 1, 1) |
|
|
else: |
|
|
sink = None |
|
|
return memory, sink |
|
|
|
|
|
def zero_mem(self): |
|
|
for layer in self.get_layers(): |
|
|
layer.zero_mem() |
|
|
self.previous_input = None |
|
|
|
|
|
def detach_mem(self): |
|
|
for layer in self.get_layers(): |
|
|
layer.detach_mem() |
|
|
pass |
|
|
|
|
|
def forward(self, input_ids, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs): |
|
|
if self.act_type != 'model': |
|
|
out = self.forward_with_update(input_ids, labels, labels_mask, zero_mem, attention_mask=attention_mask, **kwargs) |
|
|
else: |
|
|
seg_kwargs = self.process_input(input_ids=input_ids, |
|
|
labels=labels, |
|
|
labels_mask=labels_mask, |
|
|
zero_mem=zero_mem, |
|
|
attention_mask=attention_mask, |
|
|
**kwargs |
|
|
) |
|
|
out = self.gptneox_forward_act(**seg_kwargs) |
|
|
out = self.process_output(out, labels=labels, labels_mask=labels_mask) |
|
|
return out |
|
|
|
|
|
def forward_with_update(self, input_ids, labels=None, labels_mask=None, zero_mem=False, **kwargs):
    """Run one segment through the base model, letting the layers update memory.

    Steps:
    * if ``attend_to_previous_input`` is set and a previous segment exists, its ids
      are prepended to ``input_ids``;
    * if ``zero_mem`` is true, ``self.zero_mem()`` is called;
    * model kwargs are built by ``process_input`` (memory/sink tokens added);
    * with ``RWKV_ARMT`` and the first layer not in generate mode, the ordinary tokens
      and the trailing memory tokens go through the model in two separate calls;
      otherwise there is a single call;
    * logits of the prepended previous-segment tokens are dropped, the output is
      post-processed by ``process_output``, and the current ids are remembered.

    Args:
        input_ids: token ids of the current segment.
        labels: forwarded to ``process_output``.
        labels_mask: forwarded to ``process_output``.
        zero_mem: reset memory before processing when true.
        **kwargs: forwarded to ``process_input`` and ``process_output``.

    Returns:
        The result of ``process_output``.
    """
    # Un-prefixed ids; stored as ``previous_input`` at the end of this call.
    current_input_ids = input_ids.clone()
    if self.attend_to_previous_input and self.previous_input is not None:
        input_ids = torch.cat([self.previous_input, input_ids], dim=1)

    if zero_mem:
        self.zero_mem()

    seg_kwargs = self.process_input(input_ids, **kwargs)

    layers = self.get_layers()
    if self.RWKV_ARMT and not layers[0].generate_mode:
        # Split every tensor kwarg along dim 1: input1 gets everything before the
        # trailing memory tokens, input2 the last ``num_mem_tokens`` positions.
        # Non-tensor kwargs are shared by both calls.
        input1 = dict()
        input2 = dict()
        for item in seg_kwargs:
            if isinstance(seg_kwargs[item], torch.Tensor):
                # NOTE(review): assumes every tensor kwarg is laid out with the
                # sequence on dim 1 and that num_mem_tokens > 0 (``:-0`` is empty).
                input1[item] = seg_kwargs[item][:, :-self.num_mem_tokens]
                input2[item] = seg_kwargs[item][:, -self.num_mem_tokens:]
            else:
                input1[item] = seg_kwargs[item]
                input2[item] = seg_kwargs[item]

        # First call (ordinary tokens) runs with generate mode switched on.
        self.generate_mode(True)
        out = self.model(**input1)
        self.generate_mode(False)
        # Snapshot the returned state before the second call can touch it.
        state_tmp = tuple([torch.clone(state) for state in out['state']])
        out = ARMTOutput(**{k: torch.clone(t) if isinstance(t, torch.Tensor) else t for k, t in out.items()})
        # Second call feeds the memory tokens starting from the first call's state;
        # its output is discarded (presumably it is run for its memory-update
        # side effects — confirm against the layer implementation).
        input2['state'] = out['state']
        _ = self.model(**input2)
        # Report the state as it was after the first call.
        out['state'] = state_tmp
    else:
        out = self.model(**seg_kwargs)

    if self.attend_to_previous_input and self.previous_input is not None:
        # Drop logits that belong to the prepended previous-segment tokens.
        out['logits'] = out['logits'][:, self.previous_input.size(1):]
    out = self.process_output(out, labels, labels_mask, **kwargs)
    self.previous_input = current_input_ids
    return out
|
|
|
|
|
def process_input(self, input_ids, **kwargs):
    """Build the base-model keyword arguments for one segment.

    The segment embeddings are followed by the memory tokens and, when
    ``self.use_sink`` is set, preceded by the sink token. The attention mask
    (and an optional mask over the previous segment) is padded to match. With
    ``self.wrap_pos``, explicit position ids send the memory tokens to the last
    ``num_mem_tokens`` rows of the position table.

    Args:
        input_ids: token ids of the segment; may be ``None`` when
            ``inputs_embeds`` is supplied in ``kwargs``.
        **kwargs: other model kwargs. ``inputs_embeds``, ``attention_mask`` and
            ``prev_attn_mask`` are handled here; ``prev_attn_mask`` is removed
            from the result.

    Returns:
        A new dict of model kwargs. ``input_ids`` is ``None``; ``inputs_embeds``
        and ``output_hidden_states=True`` are always set.
    """
    seg_kwargs = dict(**kwargs)

    inputs_embeds = kwargs.get('inputs_embeds')
    if inputs_embeds is None:
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
    # Batch size, segment length and device are taken from the embeddings so the
    # ``inputs_embeds``-only path (input_ids=None) works too.
    seg_len = inputs_embeds.size(1)
    memory_state, sink = self.set_memory(inputs_embeds.shape)

    if self.use_sink:
        inputs_embeds = torch.cat([sink, inputs_embeds, memory_state], dim=1)
    else:
        inputs_embeds = torch.cat([inputs_embeds, memory_state], dim=1)

    seg_kwargs['input_ids'] = None
    seg_kwargs['inputs_embeds'] = inputs_embeds
    if kwargs.get('attention_mask') is not None:
        seg_kwargs['attention_mask'] = self.pad_attention_mask(kwargs['attention_mask'], dtype=inputs_embeds.dtype)
        if kwargs.get('prev_attn_mask') is not None:
            # Keys of the previous segment come first along the last dimension.
            prev_seg_attn_mask = self.pad_prev_seg_attn_mask(kwargs['prev_attn_mask'], dtype=inputs_embeds.dtype)
            seg_kwargs['attention_mask'] = torch.cat([prev_seg_attn_mask, seg_kwargs['attention_mask']], dim=-1)
    seg_kwargs.pop('prev_attn_mask', None)
    seg_kwargs['output_hidden_states'] = True

    if self.wrap_pos:
        device = inputs_embeds.device
        num_pos_embs = self.model.transformer.wpe.weight.shape[0]
        ordinary_pos = torch.arange(0, seg_len, dtype=torch.long, device=device)
        # Memory tokens always use the last ``num_mem_tokens`` position rows.
        write_pos = torch.arange(num_pos_embs - self.num_mem_tokens, num_pos_embs, dtype=torch.long, device=device)
        seg_kwargs['position_ids'] = torch.cat([ordinary_pos, write_pos]).long().unsqueeze(0)

    return seg_kwargs
|
|
|
|
|
|
|
|
|
|
|
def pad_attention_mask(self, attention_mask, dtype=float):
    """Extend an attention mask to cover the sink and memory tokens.

    With no memory tokens (``num_mem_tokens`` 0 or None) the mask is returned
    untouched. Otherwise, ``num_mem_tokens + use_sink`` positions are added:

    * non-4-D masks get extra key columns; the original mask sits after the
      optional sink column, and the sink and memory columns are 1;
    * 4-D masks grow on both of the last two dims. The original block goes in
      the middle; the sink row (if any) sees only itself; rows before the memory
      rows cannot see the memory columns; memory rows see everything. Unless
      ``NOT_INVERT_ATTN_MASK`` is set, the result is passed to
      ``invert_attn_mask``.

    Returns:
        The padded mask cast to ``dtype``.
    """
    n_mem = self.num_mem_tokens
    if n_mem in {0, None}:
        return attention_mask

    start = int(self.use_sink)
    extra = n_mem + self.use_sink
    padded_shape = list(attention_mask.shape)

    if len(padded_shape) != 4:
        padded_shape[-1] += extra
        padded = torch.ones(*padded_shape, dtype=dtype).to(attention_mask.device)
        padded[..., start:-n_mem] = attention_mask
        return padded.to(dtype)

    padded_shape[-2] += extra
    padded_shape[-1] += extra
    padded = torch.ones(*padded_shape, dtype=dtype).to(attention_mask.device)
    padded[..., start:-n_mem, start:-n_mem] = attention_mask
    if self.use_sink:
        padded[..., 0, 1:] = 0
    padded[..., :-n_mem, -n_mem:] = 0
    if not os.environ.get("NOT_INVERT_ATTN_MASK"):
        padded = invert_attn_mask(padded, dtype)
    return padded.to(dtype)
|
|
|
|
|
def pad_prev_seg_attn_mask(self, prev_seg_attn_mask, dtype=float):
    """Add query rows for the sink/memory tokens to a previous-segment mask.

    With no memory tokens the mask is returned untouched. A 4-D mask gains
    ``num_mem_tokens + use_sink`` rows on dim -2. The original rows sit after the
    optional sink row, the sink row is all zeros, and the new memory rows are all
    ones. Unless ``NOT_INVERT_ATTN_MASK`` is set, the result goes through
    ``invert_attn_mask``. Any other rank is only cast to ``dtype``.

    Returns:
        The (possibly padded) mask cast to ``dtype``.
    """
    n_mem = self.num_mem_tokens
    if n_mem in {0, None}:
        return prev_seg_attn_mask

    padded_shape = list(prev_seg_attn_mask.shape)
    if len(padded_shape) != 4:
        return prev_seg_attn_mask.to(dtype)

    padded_shape[-2] += n_mem + self.use_sink
    padded = torch.ones(*padded_shape, dtype=dtype).to(prev_seg_attn_mask.device)
    padded[..., int(self.use_sink):-n_mem, :] = prev_seg_attn_mask
    if self.use_sink:
        # The sink row sees nothing of the previous segment.
        padded[..., 0, :] = 0
    if not os.environ.get("NOT_INVERT_ATTN_MASK"):
        padded = invert_attn_mask(padded, dtype)
    return padded.to(dtype)
|
|
|
|
|
def process_output(self, model_outputs, labels, labels_mask, **kwargs):
    """Strip sink/memory positions from a model output and attach the LM loss.

    When memory tokens are used (and not ``RWKV_ARMT``), a fresh
    ``CausalLMOutputWithCrossAttentions`` holds the logits (and, on request, the
    hidden states) without the sink and memory positions. Otherwise the model
    output is reused as is.

    If ``labels`` is given, a next-token cross-entropy summed over tokens is
    divided by the number of counted tokens (mask sum, or non ``-100`` labels,
    at least 1) and stored as ``ce_loss``. ``linear_cross_entropy`` on the last
    hidden state is used when available and the base model has ``embed_out``;
    otherwise ``CrossEntropyLoss`` is used on the logits.

    Returns:
        The output object, optionally with ``ce_loss``, ``past_key_values``
        (``use_cache``), and ``remainders``/``n_updates`` (model-level ACT).
    """
    sink_end = int(self.use_sink)
    has_mem = self.num_mem_tokens not in {0, None}

    if has_mem and not self.RWKV_ARMT:
        out = CausalLMOutputWithCrossAttentions()
        out['logits'] = model_outputs.logits[:, sink_end:-self.num_mem_tokens]
        if kwargs.get('output_hidden_states'):
            out['hidden_states'] = [lh[:, sink_end:-self.num_mem_tokens] for lh in model_outputs.hidden_states]
        if kwargs.get('output_attentions'):
            out['attentions'] = model_outputs['attentions']
    else:
        out = model_outputs

    if labels is not None:
        # Targets are shifted by one: position t predicts token t + 1.
        flat_labels = labels[..., 1:].contiguous().view(-1)
        if labels_mask is not None:
            flat_mask = labels_mask[..., :-1].contiguous().view(-1)
            flat_labels = flat_labels[flat_mask]

        cce_ready = (
            CUT_CROSS_ENTROPY_AVAILABLE
            and hasattr(self.model, 'embed_out')
            and 'hidden_states' in model_outputs
            and model_outputs.hidden_states is not None
        )
        if cce_ready:
            # Fused linear + CE on the last hidden state, avoiding full logits.
            last_hidden = model_outputs.hidden_states[-1]
            if has_mem:
                last_hidden = last_hidden[:, sink_end:-self.num_mem_tokens]
            last_hidden = last_hidden[..., :-1, :].contiguous()
            flat_hidden = last_hidden.view(-1, last_hidden.size(-1))
            if labels_mask is not None:
                flat_hidden = flat_hidden[flat_mask]
            ce_loss = linear_cross_entropy(
                flat_hidden,
                self.model.embed_out.weight,
                flat_labels,
                reduction='sum'
            )
        else:
            shifted_logits = out['logits'][..., :-1, :].contiguous()
            flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
            if labels_mask is not None:
                flat_logits = flat_logits[flat_mask]
            ce_loss = CrossEntropyLoss(reduction='sum')(flat_logits, flat_labels)

        if labels_mask is not None:
            denom = labels_mask[..., :-1].contiguous().view(-1).sum()
        else:
            denom = (flat_labels != -100).sum()
        out['ce_loss'] = ce_loss / torch.clamp(denom, min=1)

    if kwargs.get('use_cache', False):
        out['past_key_values'] = model_outputs.past_key_values
    if self.act_on and self.act_type == 'model':
        out['remainders'] = model_outputs['remainders']
        out['n_updates'] = model_outputs['n_updates']
    return out
|
|
|
|
|
def generate(self, input_ids, attention_mask, zero_mem=False, **generate_kwargs):
    """Generate from one segment with the base model, conditioned on current memory.

    The prompt is built with ``process_input``. The trailing memory tokens are
    removed before calling ``self.model.generate``; the sink token, if any, is
    kept. Generate mode is switched on for the call and is always switched off
    again, even if generation raises.

    Args:
        input_ids: prompt token ids.
        attention_mask: prompt attention mask (may be ``None``).
        zero_mem: if true, memory is reset first.
        **generate_kwargs: forwarded to ``self.model.generate``.

    Returns:
        Whatever ``self.model.generate`` returns.
    """
    if zero_mem:
        self.zero_mem()

    self.generate_mode(True)
    try:
        seg_kwargs = self.process_input(input_ids, attention_mask=attention_mask)
        inputs_embeds = seg_kwargs['inputs_embeds']
        # Explicit end index instead of ``[:, :-n]``: with n == 0 that slice is empty.
        keep = inputs_embeds.size(1) - (self.num_mem_tokens or 0)
        padded_mask = seg_kwargs.get('attention_mask')
        return self.model.generate(
            inputs_embeds=inputs_embeds[:, :keep],
            attention_mask=None if padded_mask is None else padded_mask[:, :keep],
            **generate_kwargs
        )
    finally:
        self.generate_mode(False)
|
|
|
|
|
def update_past_key_values_sw(self, past_key_values, window_size):
    """Trim every cached key/value tensor to its most recent positions.

    Keeps the last ``window_size + use_sink`` entries along dim -2 of each
    tensor of every layer.

    Args:
        past_key_values: cache object providing ``to_legacy_cache()``.
        window_size: number of recent positions to keep (plus one for the sink).

    Returns:
        A new ``DynamicCache`` built from the trimmed tensors.
    """
    keep = window_size + self.use_sink
    trimmed = []
    for layer_kv in past_key_values.to_legacy_cache():
        trimmed.append([tensor[..., -keep:, :] for tensor in layer_kv])
    return DynamicCache.from_legacy_cache(trimmed)
|
|
|
|
|
def greedy_generate_sw(self, input_ids, attention_mask, prev_attn_mask, **generate_kwargs):
    """Greedy decoding for sliding-window mode with a cache trimmed to a window.

    Required ``generate_kwargs``: ``window_size``, ``max_new_tokens``,
    ``past_key_values`` (cache from the previous segment) and ``eos_token_id``.
    The prompt (without memory tokens) is prefilled once; then tokens are chosen
    by argmax one at a time. After each step the cache is trimmed with
    ``update_past_key_values_sw``. Decoding stops after ``max_new_tokens`` steps
    or once every sequence emitted ``eos_token_id``.

    Args:
        input_ids: prompt token ids of the final segment.
        attention_mask: 2-D attention mask of the prompt.
        prev_attn_mask: 2-D attention mask of the previous segment.

    Returns:
        Tensor of generated token ids, concatenated along the last dim
        (``None`` if ``max_new_tokens`` is 0).

    NOTE(review): ``generate_mode(False)`` is not reached if an exception is raised.
    """
    self.generate_mode(True)
    window_size = generate_kwargs['window_size']
    max_new_tokens = generate_kwargs['max_new_tokens']
    past_key_values = self.update_past_key_values_sw(generate_kwargs['past_key_values'], window_size)
    eos_token_id = generate_kwargs['eos_token_id']
    # 2-D copies are kept for building the incremental decoding mask below.
    prev_attn_mask_2d = prev_attn_mask.clone()
    attention_mask_2d = attention_mask.clone()

    attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=attention_mask.size(-1))
    # ``attention_mask`` is already 4-D here; its last dim is the prompt length.
    prev_attn_mask = attn_mask_to_4d(prev_attn_mask, upper=True, query_len=attention_mask.size(-1))
    seg_kwargs = self.process_input(input_ids=input_ids, attention_mask=attention_mask, prev_attn_mask=prev_attn_mask, past_key_values=past_key_values)
    # Prefill without the trailing memory tokens (rows and columns).
    seg_kwargs['inputs_embeds'] = seg_kwargs['inputs_embeds'][..., :-self.num_mem_tokens, :]
    seg_kwargs['attention_mask'] = seg_kwargs['attention_mask'][..., :-self.num_mem_tokens, :-self.num_mem_tokens]
    outputs = self.model(**seg_kwargs, use_cache=True)

    next_token_logits = outputs.logits[:, -1, :]

    past_key_values = outputs.past_key_values
    past_key_values = self.update_past_key_values_sw(past_key_values, window_size)

    generated_ids = None
    # Previous-segment mask, one always-visible column (NOTE(review): presumably
    # the sink slot — confirm), then the prompt mask.
    sw_attention_mask = torch.cat([prev_attn_mask_2d, torch.ones(attention_mask_2d.size(0), 1).to(prev_attn_mask_2d.device), attention_mask_2d], dim=-1)

    for i in range(max_new_tokens):
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        if generated_ids is not None:
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
        else:
            generated_ids = next_token_id
        next_input = next_token_id

        # Append the new token and keep only the last window_size + 1 + use_sink entries.
        sw_attention_mask = torch.cat([sw_attention_mask, torch.ones_like(next_token_id).to(sw_attention_mask.device)], dim=-1)[..., -window_size-1-self.use_sink:]
        with torch.no_grad():
            outputs = self.model(
                input_ids=next_input,
                attention_mask=sw_attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                # Absolute position of the new token.
                cache_position=torch.full((1,), window_size + i + input_ids.size(-1) + self.use_sink).to(input_ids.device)
            )
        past_key_values = self.update_past_key_values_sw(outputs.past_key_values, window_size)
        next_token_logits = outputs.logits[:, -1, :]

        # Stop once every sequence just produced EOS.
        if (next_token_id[:, 0] == eos_token_id).all():
            break
    self.generate_mode(False)
    return generated_ids
|
|
|
|
|
|
|
|
def apply_layers(self, hidden_states, causal_mask, position_ids, cache_position, position_embeddings, update_mem=True):
    """Run ``hidden_states`` through every layer from ``self.get_layers()``.

    With ``update_mem=False`` each layer's ``forward`` is temporarily replaced by
    its ``forward_no_update`` for this pass. The originals are restored afterwards,
    even if a layer raises.

    Args:
        hidden_states: input to the first layer.
        causal_mask: passed to each layer as ``attention_mask``.
        position_ids: passed through to each layer.
        cache_position: passed through to each layer.
        position_embeddings: passed through to each layer.
        update_mem: when false, use ``forward_no_update`` on every layer.

    Returns:
        The first element of the last layer's output.
    """
    layers = list(self.get_layers())
    saved_forwards = None
    if not update_mem:
        # Patch exactly the layer objects that are executed (and restored) below.
        saved_forwards = [layer.forward for layer in layers]
        for layer in layers:
            layer.forward = layer.forward_no_update
    try:
        for layer in layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )[0]
    finally:
        # Always restore, so a failure cannot leave layers stuck in no-update mode.
        if saved_forwards is not None:
            for layer, forward in zip(layers, saved_forwards):
                layer.forward = forward
    return hidden_states
|
|
|
|
|
|
|
|
def gptneox_forward_act(self, inputs_embeds, labels=None, labels_mask=None, zero_mem=False, attention_mask=None, **kwargs):
    """GPT-NeoX forward pass with model-level ACT over the layer stack.

    Applies the embedding dropout, builds positions / rotary embeddings / the
    causal mask using the backbone's own helpers, then lets ``self.act`` drive
    ``apply_layers`` (with and without memory updates). The final layer norm
    and ``embed_out`` produce the logits.

    ``labels``, ``labels_mask``, ``zero_mem`` and ``**kwargs`` are accepted so
    the output of ``process_input`` can be splatted in; they are not used here.

    Returns:
        ``ARMTOutput`` with ``logits``, ``n_updates`` and ``remainders``.
    """
    neox = self.model.gpt_neox
    hidden_states = neox.emb_dropout(inputs_embeds)
    cache_position = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
    position_ids = cache_position.unsqueeze(0)
    position_embeddings = neox.rotary_emb(hidden_states, position_ids)
    # NOTE: relies on a private transformers helper; its signature may change.
    causal_mask = neox._update_causal_mask(attention_mask, hidden_states, cache_position, None, False)

    act_hidden, (remainders, n_updates) = self.act(
        state=hidden_states,
        inputs=hidden_states,
        fn_no_update=lambda *args, **kwargs: self.apply_layers(*args, **kwargs, update_mem=False),
        fn_update=self.apply_layers,
        time_enc=self.timing_signal,
        pos_enc=self.position_signal,
        max_hop=self.depth,
        causal_mask=causal_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        position_embeddings=position_embeddings
    )

    logits = self.model.embed_out(neox.final_layer_norm(act_hidden))
    return ARMTOutput(logits=logits, n_updates=n_updates, remainders=remainders)
|
|
|
|
|
class AssociativeRecurrentWrapper(torch.nn.Module):
    """Segment-level recurrence driver around an associative memory cell.

    Long inputs are split into segments (``segment`` / ``split_tensor``) and fed
    one by one to ``memory_cell``. Cache, masks and state are threaded between
    segments by ``process_segment``, and the per-segment outputs are merged by
    ``process_outputs``.

    ``rmt_kwargs`` is stored as ``self.rmt_config``. Keys read here are
    ``segment_size``, ``segment_alignment``, ``sliding_window``,
    ``attend_to_previous_input``, ``act_on`` and ``time_penalty``.
    """

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs
        # State carried across batches; this class only ever resets it to None.
        self.last_state = None

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward to the base model's ``gradient_checkpointing_enable``."""
        self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)

    def process_segment(self, segment_kwargs, next_seg_len=None):
        """Run one segment through the memory cell and prepare kwargs for the next one.

        Args:
            segment_kwargs: kwargs for ``memory_cell``; mutated in place (cache,
                mask and flag keys are filled in, a ``None`` ``state`` is removed).
                ``attention_mask`` is optional.
            next_seg_len: length of the following segment, or ``None`` for the last one.

        Returns:
            ``(cell_out, next_segment_kwargs)``.
        """
        sliding_window = self.rmt_config.get('sliding_window', False)
        attend_to_previous_input = self.rmt_config.get('attend_to_previous_input', False)
        # ``segment`` drops None tensors, so the mask key may be absent.
        attn_mask = segment_kwargs.get('attention_mask')
        seg_len = segment_kwargs['input_ids'].size(-1)

        segment_kwargs['use_cache'] = sliding_window
        if segment_kwargs.get('past_key_values') is None:
            segment_kwargs['past_key_values'] = None
        if segment_kwargs.get('prev_attn_mask') is None:
            segment_kwargs['prev_attn_mask'] = None
        segment_kwargs['zero_mem'] = False
        if sliding_window or attend_to_previous_input:
            segment_kwargs['attention_mask'] = attn_mask_to_4d(attn_mask, upper=False, query_len=seg_len)

        if 'state' in segment_kwargs and segment_kwargs['state'] is None:
            segment_kwargs.pop('state')

        num_mem_tokens = self.memory_cell.num_mem_tokens
        cell_out = self.memory_cell(**segment_kwargs)
        state = cell_out.get('state')

        if (sliding_window or attend_to_previous_input) and next_seg_len is not None:
            prev_attn_mask = attn_mask_to_4d(attn_mask, upper=True, query_len=next_seg_len)
        else:
            prev_attn_mask = None

        if sliding_window:
            # Per layer, keep the keys/values of the last ``seg_len`` positions
            # before the trailing memory tokens, detached from the graph.
            past_key_values = [
                [
                    k_or_v[..., -(num_mem_tokens + seg_len):k_or_v.size(-2) - num_mem_tokens, :].detach()
                    for k_or_v in seg_kv
                ]
                for seg_kv in cell_out['past_key_values']
            ]
            if not isinstance(cell_out['past_key_values'], (tuple, list)):
                past_key_values = cell_out['past_key_values'].from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        else:
            past_key_values = None

        next_segment_kwargs = {
            'use_cache': sliding_window,
            'past_key_values': past_key_values,
            'prev_attn_mask': prev_attn_mask,
            'zero_mem': False,
        }
        if state is not None:
            next_segment_kwargs['state'] = state
        return cell_out, next_segment_kwargs

    def forward(self,
                input_ids,
                labels=None,
                labels_mask=None,
                inputs_embeds=None,
                attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                input_segmented=False,
                output_only_last_segment=False,
                use_previous_batch_state=None,
                num_items_in_batch=None,
                **kwargs
                ):
        """Process a (possibly long) batch segment by segment.

        Args:
            input_ids, inputs_embeds, attention_mask, labels, labels_mask: either
                full sequences (split with ``segment``) or, with
                ``input_segmented=True``, tensors whose dim 1 indexes segments.
            output_attentions, output_hidden_states: forwarded to ``process_outputs``.
            output_only_last_segment: keep only the last segment's cell output.
            use_previous_batch_state: boolean tensor; when all true and a
                ``last_state`` exists, memory is detached and the state reused
                instead of being reset. ``None`` behaves like ``torch.zeros(1)``.
            num_items_in_batch: forwarded to ``process_outputs``.

        Returns:
            The merged output from ``process_outputs``.
        """
        if use_previous_batch_state is None:
            use_previous_batch_state = torch.zeros(1)

        if input_segmented:
            n_segs = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
            segmented = [dict(
                input_ids=input_ids[:, i] if input_ids is not None else None,
                inputs_embeds=inputs_embeds[:, i] if inputs_embeds is not None else None,
                attention_mask=attention_mask[:, i] if attention_mask is not None else None,
                labels=labels[:, i] if labels is not None else None,
                labels_mask=labels_mask[:, i] if labels_mask is not None else None,
            ) for i in range(n_segs)]
            # Flatten pre-segmented targets back to one sequence for the loss.
            if labels is not None:
                labels = torch.cat([labels[:, i] for i in range(n_segs)], dim=1)
            if labels_mask is not None:
                labels_mask = torch.cat([labels_mask[:, i] for i in range(n_segs)], dim=1)
        else:
            segmented = self.segment(input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, labels_mask=labels_mask)

        cell_outputs = []
        if not use_previous_batch_state.all() or self.last_state is None:
            self.memory_cell.zero_mem()
            state = None
        else:
            self.memory_cell.detach_mem()
            state = self.last_state
        next_seg_kwargs = dict(state=state)

        for seg_num, segment in enumerate(segmented):
            is_last = seg_num == len(segmented) - 1
            next_seg_len = None if is_last else segmented[seg_num + 1]['input_ids'].size(-1)

            segment_with_kwargs = dict(**segment, **next_seg_kwargs)
            # NOTE(review): ``num_items_in_batch`` is a named parameter, so it can
            # never appear in ``kwargs``; this forwarding is currently inert.
            if kwargs.get('num_items_in_batch') is not None:
                segment_with_kwargs['num_items_in_batch'] = kwargs['num_items_in_batch']
            cell_out, next_seg_kwargs = self.process_segment(segment_with_kwargs, next_seg_len=next_seg_len)
            if not output_only_last_segment or is_last:
                cell_outputs.append(cell_out)

        out = self.process_outputs(cell_outputs, labels=labels,
                                   labels_mask=labels_mask,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states,
                                   num_items_in_batch=num_items_in_batch)

        if not self.training:
            self.memory_cell.zero_mem()
            self.last_state = None
        return out

    def segment(self, **kwargs):
        """Split every non-``None`` tensor kwarg along dim 1 into segments.

        Returns:
            A list of dicts, one per segment, mapping each kwarg name to its slice.
        """
        segments = []
        for name, tensor in kwargs.items():
            if tensor is None:
                continue
            for s, piece in enumerate(self.split_tensor(tensor)):
                if s < len(segments):
                    segments[s][name] = piece
                else:
                    segments.append({name: piece})
        return segments

    def split_tensor(self, tensor):
        """Split ``tensor`` along dim 1 into pieces of at most ``segment_size``.

        ``segment_alignment``: ``'left'``/``None`` puts the short piece last,
        ``'right'`` puts it first, ``'center'`` uses ``torch.chunk``.

        Raises:
            NotImplementedError: for any other alignment.
        """
        align = self.rmt_config.get('segment_alignment')
        segment_size = self.rmt_config.get('segment_size')
        if align in {'left', None}:
            split_inds = list(range(0, tensor.shape[1], segment_size)) + [tensor.shape[1]]
        elif align == 'right':
            split_inds = (list(range(tensor.shape[1], 0, -segment_size)) + [0])[::-1]
        elif align == 'center':
            n_seg = math.ceil(tensor.shape[1] / segment_size)
            return torch.chunk(tensor, n_seg, dim=1)
        else:
            raise NotImplementedError
        return [tensor[:, start:end] for start, end in zip(split_inds, split_inds[1:])]

    def process_outputs(self, cell_outputs, **kwargs):
        """Merge per-segment outputs and compute the token-averaged loss.

        Logits are concatenated along dim 1. With labels, a summed next-token
        cross-entropy (``linear_cross_entropy`` on hidden states when available,
        else ``CrossEntropyLoss``) is divided by the counted tokens (at least 1).
        Without labels the loss is 0. With ``act_on``, ACT statistics are added
        and ``time_penalty * remainders`` is added to the loss.
        """
        out = ARMTOutput()
        full_logits = torch.cat([o.logits for o in cell_outputs], dim=1)

        labels = kwargs.get('labels')
        if labels is not None:
            labels = labels[:, -full_logits.size(1):]
            flat_labels = labels[..., 1:].contiguous().view(-1)

            labels_mask = kwargs.get('labels_mask')
            if labels_mask is not None:
                labels_mask = labels_mask[:, -full_logits.size(1):]
                shift_mask = labels_mask[..., :-1].contiguous()
                flat_labels = flat_labels[shift_mask.view(-1)]

            use_cce = (
                CUT_CROSS_ENTROPY_AVAILABLE
                and hasattr(self.memory_cell.model, 'embed_out')
                and cell_outputs
                and 'hidden_states' in cell_outputs[-1]
                and cell_outputs[-1].hidden_states is not None
            )
            if use_cce:
                full_hidden_states = torch.cat([o.hidden_states[-1] for o in cell_outputs], dim=1)
                shift_hidden_states = full_hidden_states[..., :-1, :].contiguous()
                flat_hidden_states = shift_hidden_states.view(-1, shift_hidden_states.size(-1))
                if labels_mask is not None:
                    flat_hidden_states = flat_hidden_states[shift_mask.view(-1)]
                loss = linear_cross_entropy(
                    flat_hidden_states,
                    self.memory_cell.model.embed_out.weight,
                    flat_labels,
                    reduction='sum'
                )
            else:
                shift_logits = full_logits[..., :-1, :].contiguous()
                flat_logits = shift_logits.view(-1, shift_logits.size(-1))
                if labels_mask is not None:
                    flat_logits = flat_logits[shift_mask.view(-1)]
                loss = CrossEntropyLoss(reduction='sum')(flat_logits, flat_labels)

            if labels_mask is not None:
                denom = labels_mask[..., :-1].contiguous().view(-1).sum()
            else:
                denom = (flat_labels != -100).sum()
            out['loss'] = loss / torch.clamp(denom, min=1)
        else:
            out['loss'] = 0

        report_segments = not os.environ.get('HF_Trainer')
        if report_segments:
            out['ce_loss'] = out['loss']

        out['logits'] = full_logits
        segment_keys = ['loss', 'logits']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            if all(hasattr(o, 'hidden_states') and o.hidden_states is not None for o in cell_outputs):
                full_hidden_states = tuple(torch.cat(layer_hs, dim=1) for layer_hs in zip(*[o.hidden_states for o in cell_outputs]))
                segment_keys.append('hidden_states')
                out['hidden_states'] = full_hidden_states
        if report_segments:
            # Also expose per-segment entries as ``<key>_<segment index>``.
            for seg_num, o in enumerate(cell_outputs):
                for key, value in o.items():
                    if any(sk in key for sk in segment_keys):
                        out[f'{key}_{seg_num}'] = value

        if self.rmt_config.get('act_on', False):
            if self.memory_cell.act_type != 'model':
                layers = self.memory_cell.get_layers()
                remainders = torch.mean(torch.stack([layer.remainders / layer.segments_passed for layer in layers], dim=0))
                n_updates = torch.mean(torch.stack([layer.n_updates / layer.segments_passed for layer in layers], dim=0))
            else:
                remainders = torch.mean(torch.stack([o['remainders'] for o in cell_outputs], dim=0))
                n_updates = torch.mean(torch.stack([o['n_updates'] for o in cell_outputs], dim=0))
            out['n_updates'] = n_updates.detach().cpu()
            out['remainders'] = remainders.detach().cpu()
            out['loss'] = out['loss'] + self.rmt_config['time_penalty'] * remainders

        return out

    def generate(self, input_ids, attention_mask, **generate_kwargs):
        """Feed all but the last segment as context, then generate from the last one.

        With a sliding-window cache, greedy windowed decoding
        (``greedy_generate_sw``) is used; otherwise ``memory_cell.generate``.
        """
        self.memory_cell.zero_mem()
        segmented = self.segment(input_ids=input_ids, attention_mask=attention_mask)
        next_seg_kwargs = dict()
        for seg_num, segment in enumerate(segmented[:-1]):
            next_seg_len = segmented[seg_num + 1]['input_ids'].size(-1)
            _, next_seg_kwargs = self.process_segment(dict(**segment, **next_seg_kwargs), next_seg_len=next_seg_len)

        final_segment = segmented[-1]
        assert next_seg_kwargs.get('past_key_values') is None or isinstance(next_seg_kwargs.get('past_key_values'), Cache), "Sliding Window generation is not implemented for legacy cache"
        if next_seg_kwargs.get('past_key_values') is not None:
            prev_attn_mask = segmented[-2]['attention_mask']
            legacy_cache = next_seg_kwargs['past_key_values'].to_legacy_cache()
            seg_len = segmented[-2]['input_ids'].size(-1)
            generate_kwargs['past_key_values'] = DynamicCache.from_legacy_cache(legacy_cache)
            generate_kwargs['window_size'] = seg_len
            final_segment['prev_attn_mask'] = prev_attn_mask
            return self.memory_cell.greedy_generate_sw(**final_segment, **generate_kwargs)
        return self.memory_cell.generate(**final_segment, **generate_kwargs)
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
from torch.nn import CrossEntropyLoss |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
from torch.nn.functional import relu as r |
|
|
import torch.nn.functional as F |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ARMTConfig(PretrainedConfig):
    """Configuration for ``ARMTForCausalLM``.

    The base model comes from ``base_model_name`` (pretrained weights) or from
    ``base_model_config`` (config object, dict or name/path). Passing both raises
    ``ValueError``. The remaining fields are stored verbatim; ``ARMTForCausalLM``
    forwards them to ``AssociativeMemoryCell`` (memory/ACT settings) and to
    ``AssociativeRecurrentWrapper`` (segmentation settings).
    """
    model_type = "armt"

    def __init__(self,
                 base_model_name=None,
                 base_model_config=None,
                 num_mem_tokens=16,
                 d_mem=512,
                 segment_size=512,
                 segment_alignment="left",
                 sliding_window=False,
                 attend_to_previous_input=False,
                 use_sink=False,
                 layers_attr="model.layers",
                 wrap_pos=False,
                 correction=True,
                 n_heads=1,
                 use_denom=True,
                 gating=False,
                 freeze_mem=False,
                 act_on=False,
                 max_hop=4,
                 act_type="associative",
                 act_format="linear",
                 noisy_halting=False,
                 constant_depth=False,
                 time_penalty=0.0,
                 **kwargs):
        super().__init__(**kwargs)

        both_given = base_model_name is not None and base_model_config is not None
        if both_given:
            raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided. Set the other to None.")

        # Assigned in declaration order through the normal attribute machinery.
        settings = {
            'base_model_name': base_model_name,
            'base_model_config': base_model_config,
            'num_mem_tokens': num_mem_tokens,
            'd_mem': d_mem,
            'segment_size': segment_size,
            'segment_alignment': segment_alignment,
            'sliding_window': sliding_window,
            'attend_to_previous_input': attend_to_previous_input,
            'use_sink': use_sink,
            'layers_attr': layers_attr,
            'wrap_pos': wrap_pos,
            'correction': correction,
            'n_heads': n_heads,
            'use_denom': use_denom,
            'gating': gating,
            'freeze_mem': freeze_mem,
            'act_on': act_on,
            'max_hop': max_hop,
            'act_type': act_type,
            'act_format': act_format,
            'noisy_halting': noisy_halting,
            'constant_depth': constant_depth,
            'time_penalty': time_penalty,
        }
        for field_name, field_value in settings.items():
            setattr(self, field_name, field_value)

    def get(self, attr: str, default=None):
        """Return attribute ``attr`` if it exists, else ``default`` (dict-like access)."""
        return getattr(self, attr, default)
|
|
|
|
|
|
|
|
class ARMTForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` exposing ARMT as a causal LM.

    The base causal LM is built from ``config.base_model_config`` (random init)
    or loaded from ``config.base_model_name`` (pretrained). It is wrapped in an
    ``AssociativeMemoryCell`` and driven segment by segment by an
    ``AssociativeRecurrentWrapper`` stored as ``self.armt``.
    """
    config_class = ARMTConfig

    def __init__(self, config: ARMTConfig, **kwargs):
        super().__init__(config, **kwargs)
        from transformers import AutoModelForCausalLM

        if getattr(config, 'base_model_name', None) is not None and getattr(config, 'base_model_config', None) is not None:
            raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided in ARMTConfig.")

        bm_cfg = getattr(config, 'base_model_config', None)
        if bm_cfg is not None:
            base_model = AutoModelForCausalLM.from_config(self._resolve_base_model_config(bm_cfg))
        elif getattr(config, 'base_model_name', None):
            base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name)
        else:
            raise ValueError("ARMTForCausalLM requires either `base_model_config` or `base_model_name` in ARMTConfig.")

        self.armt_config = config

        memory_cell = AssociativeMemoryCell(
            base_model=base_model,
            num_mem_tokens=config.num_mem_tokens,
            d_mem=config.d_mem,
            layers_attr=config.layers_attr,
            wrap_pos=config.wrap_pos,
            correction=config.correction,
            n_heads=config.n_heads,
            use_denom=config.use_denom,
            gating=config.gating,
            freeze_mem=config.freeze_mem,
            act_on=config.act_on,
            max_hop=config.max_hop,
            act_type=config.act_type,
            # Read with defaults so configs saved without these fields still load.
            constant_depth=config.get('constant_depth', False),
            act_format=config.get('act_format', 'linear'),
            noisy_halting=config.get('noisy_halting', False),
            attend_to_previous_input=config.attend_to_previous_input,
            use_sink=config.use_sink
        )

        self.armt = AssociativeRecurrentWrapper(
            memory_cell,
            segment_size=config.segment_size,
            segment_alignment=config.segment_alignment,
            sliding_window=config.sliding_window,
            attend_to_previous_input=config.attend_to_previous_input,
            act_on=config.act_on,
            time_penalty=config.time_penalty
        )

    @staticmethod
    def _resolve_base_model_config(bm_cfg):
        """Turn ``base_model_config`` (PretrainedConfig, dict or str) into a PretrainedConfig.

        Raises:
            ValueError: a dict without ``'model_type'``.
            TypeError: any other type (including a nested ARMT config).
        """
        from transformers import AutoConfig

        if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, 'model_type', None) != ARMTConfig.model_type:
            return bm_cfg
        if isinstance(bm_cfg, dict):
            if 'model_type' not in bm_cfg:
                raise ValueError("`base_model_config` dict must include a 'model_type' key (e.g., 'gpt_neox', 'llama').")
            config_cls_or_instance = AutoConfig.for_model(bm_cfg['model_type'])
            # Depending on the transformers version this may be an instance or a class.
            if isinstance(config_cls_or_instance, PretrainedConfig):
                for k, v in bm_cfg.items():
                    setattr(config_cls_or_instance, k, v)
                return config_cls_or_instance
            return config_cls_or_instance.from_dict(bm_cfg)
        if isinstance(bm_cfg, str):
            return AutoConfig.from_pretrained(bm_cfg)
        raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str (name/path)")

    def forward(
        self,
        input_ids=None,
        labels=None,
        labels_mask=None,
        inputs_embeds=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        input_segmented=False,
        output_only_last_segment=False,
        num_items_in_batch=None,
    ):
        """Delegate to ``self.armt`` (see ``AssociativeRecurrentWrapper.forward``)."""
        return self.armt(
            input_ids=input_ids,
            labels=labels,
            labels_mask=labels_mask,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            input_segmented=input_segmented,
            output_only_last_segment=output_only_last_segment,
            num_items_in_batch=num_items_in_batch,
        )

    def generate(self, *args, **kwargs):
        """Delegate to ``self.armt.generate``."""
        return self.armt.generate(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Load weights, falling back to loading them directly into ``self.armt``.

        The fallback is tried when the regular load raises ``RuntimeError``
        (presumably for checkpoints whose keys lack the ``armt.`` prefix — verify).
        The caller's ``strict``/``assign`` are honoured on both attempts, and the
        loader's result (missing/unexpected keys) is returned in both cases.
        """
        try:
            return super().load_state_dict(state_dict, strict, assign)
        except RuntimeError:
            print("Failed to load state, retrying with ARMT loader.")
            result = self.armt.load_state_dict(state_dict, strict=strict, assign=assign)
            print("Success!")
            return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, config=None, *args, **kwargs):
        """Same as ``PreTrainedModel.from_pretrained``; ``config`` is passed as a keyword."""
        return super().from_pretrained(pretrained_model_name_or_path, *args, config=config, **kwargs)

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Enable gradient checkpointing on the wrapped base model."""
        self.armt.gradient_checkpointing_enable(*args, **kwargs)
|
|
|
|
|
|
|
|
import math |
|
|
import os |
|
|
import inspect |
|
|
from typing import Optional, Tuple, Callable |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
from torch.nn import CrossEntropyLoss |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.cache_utils import DynamicCache |
|
|
import warnings |
|
|
|
|
|
|
|
|
# Optional dependency: liger_kernel supplies fused kernels for llama models.
# A missing package just disables the feature; any other import failure is
# reported and then propagated unchanged.
try:
    from liger_kernel.transformers import apply_liger_kernel_to_llama
except ImportError:
    print("*** Can't import liger_kernel ***")
    LIGER_KERNEL_AVAILABLE = False
except Exception:
    print("*** Can't import liger_kernel ***")
    # Bare `raise` re-raises with the original traceback untouched.
    raise
else:
    LIGER_KERNEL_AVAILABLE = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reverse_invert_attn_mask(mask: torch.Tensor) -> torch.Tensor:
    """Turn an additive attention mask back into a 0/1 ``long`` mask.

    Entries greater than -1 (e.g. the 0 used for attended positions) become 1.
    Everything else (large negative or -inf fill values) becomes 0. The input
    is presumably produced by ``invert_attn_mask`` — confirm at call sites.

    If the ``NOT_INVERT_ATTN_MASK`` environment variable is set, ``mask`` is
    returned unchanged.
    """
    if os.environ.get("NOT_INVERT_ATTN_MASK"):
        return mask
    # Threshold in the mask's own dtype. Casting finfo.min / -inf to int64
    # first is an out-of-range float->int conversion. The old two-step
    # assignment also left values in (-2, -1] as -1 instead of 0.
    return (mask > -1).long()
|
|
|
|
|
def attn_mask_to_2d(mask: torch.Tensor) -> torch.Tensor:
    """Collapse a 4-D attention mask to a per-key 2-D ``long`` mask.

    The mask is first passed through ``reverse_invert_attn_mask``. A key
    position is then marked 1 if any query (axis -2) in any entry of axis 1
    may attend to it.
    """
    binary = reverse_invert_attn_mask(mask)
    # Query axis (-2) is reduced first, then axis 1 (presumably heads).
    return binary.any(dim=-2).any(dim=1).long()
|
|
|
|
|
def is_empty_past_key_values(past_key_values: Optional[DynamicCache], layer_idx: int) -> bool:
    """Return True when ``past_key_values`` holds no keys for layer ``layer_idx``.

    That is the case when the cache is None, has no layers, has fewer than
    ``layer_idx + 1`` layers, or the selected layer's ``keys`` is None.
    """
    if past_key_values is None:
        return True
    cached_layers = past_key_values.layers
    # Covers both "no layers at all" and "index beyond the last layer".
    if layer_idx >= len(cached_layers):
        return True
    return cached_layers[layer_idx].keys is None
|
|
|
|
|
def segment_tensor(t: torch.Tensor, start_idx: int, end_idx: int, seq_len: int) -> torch.Tensor:
    """Slice ``t[:, start_idx:end_idx]`` when ``t`` looks like a per-token tensor.

    A value is sliced only if it is a tensor with at least two dimensions and
    its dimension 1 equals ``seq_len``. Anything else (non-tensors, 1-D
    tensors, tensors with a different dim-1 size) is returned unchanged.
    """
    is_token_aligned = isinstance(t, torch.Tensor) and t.dim() >= 2 and t.size(1) == seq_len
    return t[:, start_idx:end_idx, ...] if is_token_aligned else t
|
|
|
|
|
class InnerLoopAssociativeLayerWrapper(nn.Module):
    """
    A per-layer wrapper that performs associative read/write inside the layer by
    iterating over fixed-size segments of the incoming sequence.

    Unlike the outer-loop design (which segments inputs before the model), this
    module receives the whole hidden sequence and walks over it in blocks of
    ``segment_size + num_mem_tokens + int(use_sink)`` positions. The caller
    presumably lays these out as ``[sink?] + real tokens + memory tokens``
    (TODO confirm against the model's ``augment_sequence``). For each block:

      1) For every segment after the first, an associative READ computed from
         the current memory (W_mem, z) is added to the segment's hidden states.
      2) The underlying transformer layer runs on this segment only, with an
         attention mask padded for the sink / memory positions.
      3) The layer outputs at the memory-token positions WRITE / update the
         associative memory.
      4) The per-segment outputs are concatenated along the sequence axis to
         form the layer output for the full sequence.

    The final (W_mem, z) is kept in ``self.memory_state``, so the next call
    continues from it until ``zero_mem()`` resets it.
    """

    def __init__(
        self,
        layer: nn.Module,
        d_model: int,
        num_mem_tokens: int,
        d_mem: int,
        segment_size: int,
        n_heads: int = 1,
        correction: bool = True,
        use_denom: bool = True,
        gating: bool = False,
        use_sink: bool = False,
        sliding_window: bool = False,
        get_memory_fn: Optional[Callable[[], torch.Tensor]] = None,
        get_sink_fn: Optional[Callable[[], Optional[torch.Tensor]]] = None,
        rotary_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = None,
        read_prev_states_fn: Optional[Callable[[int, int, torch.device, torch.dtype], Tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
        write_states_fn: Optional[Callable[[int, torch.Tensor, Optional[torch.Tensor]], None]] = None,
        info: Optional[dict] = None,
    ) -> None:
        """Wrap ``layer`` with associative-memory projections.

        Args:
            layer: the transformer block to run on each segment.
            d_model: hidden size of ``layer``'s inputs/outputs.
            num_mem_tokens: number of memory tokens per segment (None -> 0).
            d_mem: size of the memory key/query projection.
            segment_size: number of real tokens per segment.
            n_heads: number of memory heads; must divide ``d_mem`` and ``d_model``.
            correction: scale the ``z`` update by a clipped, detached
                "new information" coefficient.
            use_denom: keep a normaliser ``z`` and divide reads by it.
            gating: use a per-feature (``d_model``) write gate instead of a
                per-head gate.
            use_sink: each segment carries one leading sink position.
            sliding_window: run the layer with a KV cache trimmed to
                ``segment_size`` tokens between segments.
            get_memory_fn / get_sink_fn: callables returning the shared
                memory / sink parameters (used by ``_get_memory_tokens``).
            rotary_fn: callable ``(hidden, position_ids) -> (cos, sin)``.
            read_prev_states_fn / write_states_fn: stored but not used by
                this class.
            info: dict with at least ``'layer'`` (this layer's index), used
                for cache access and debug output.
        """
        super().__init__()
        self.info = info
        self.layer = layer
        self.d_model = d_model
        self.num_mem_tokens = int(num_mem_tokens or 0)
        self.d_mem = d_mem
        self.segment_size = int(segment_size)
        self.n_heads = n_heads
        self.gating = gating
        self.use_denom = use_denom
        self.correction = correction
        self.use_sink = bool(use_sink)
        self.sliding_window = bool(sliding_window)

        # Feature-map expansion factor handed to DPFP. The key width 2 * nu * d_mem
        # presumably matches DPFP's output size — confirm against DPFP.
        nu = 3
        self.d_key = 2 * nu * d_mem

        assert self.d_mem % n_heads == 0 and self.d_model % n_heads == 0

        # Create the memory projections in the same dtype as the wrapped layer.
        layer_dtype = next(self.layer.parameters()).dtype

        self.W_mq = nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mk = nn.Linear(d_model, d_mem, bias=False, dtype=layer_dtype)
        self.W_mv = nn.Linear(d_model, d_model, bias=False, dtype=layer_dtype)
        if gating:
            self.W_mb = nn.Linear(d_model, d_model, dtype=layer_dtype)
        else:
            self.W_mb = nn.Linear(d_model, n_heads, dtype=layer_dtype)
        # Zero-initialised values: the first memory writes contribute nothing.
        torch.nn.init.zeros_(self.W_mv.weight)

        self.phi = DPFP(nu)

        self.generate_mode = False
        self.seg_num = 0

        self._get_memory = get_memory_fn
        self._get_sink = get_sink_fn
        self._rotary_fn = rotary_fn
        self._read_prev_states = read_prev_states_fn
        self._write_states = write_states_fn

        # (W_mem, z) carried between forward calls; None means "fresh memory".
        self.memory_state = None

    def _to_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(bsz, seq, d)`` into ``(bsz, n_heads, seq, d // n_heads)``."""
        bsz, seq_len, d_model = x.shape
        x = x.reshape(bsz, seq_len, self.n_heads, d_model // self.n_heads)
        x = x.permute(0, 2, 1, 3)
        return x

    def _from_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_to_heads``: ``(bsz, h, seq, d_h)`` -> ``(bsz, seq, h * d_h)``."""
        bsz, n_heads, seq_len, d_head = x.shape
        x = x.permute(0, 2, 1, 3).reshape(bsz, seq_len, n_heads * d_head)
        return x

    def associate(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Not used by the inner-loop wrapper; see ``_associate_with_mem``."""
        raise NotImplementedError("associate() is unused in inner-loop; uses local memory helpers instead")

    def update_mem(self, mem_tokens: torch.Tensor) -> None:
        """Not used by the inner-loop wrapper; see ``_update_mem_with_mem``."""
        raise NotImplementedError("update_mem() is unused in inner-loop; uses local memory helpers instead")

    def zero_mem(self) -> None:
        """Forget the carried memory; the next forward starts from zeros."""
        self.memory_state = None

    def detach_mem(self) -> None:
        """Detach the carried memory from the autograd graph.

        ``z`` is None when ``use_denom`` is False, so it is only detached when
        present. The original code called ``.detach()`` on it unconditionally.
        """
        if self.memory_state is None:
            return
        W_mem, z = self.memory_state
        self.memory_state = (W_mem.detach(), z.detach() if z is not None else None)

    def freeze_mem(self) -> None:
        """Stop gradient updates to the memory projection weights."""
        self.W_mb.weight.requires_grad = False
        self.W_mb.bias.requires_grad = False
        self.W_mq.weight.requires_grad = False
        self.W_mk.weight.requires_grad = False
        self.W_mv.weight.requires_grad = False

    def _get_segment_positions(
        self, position_ids: Optional[torch.LongTensor], start: int, end: int, device: torch.device
    ) -> torch.LongTensor:
        """Return position ids for ``[start, end)``.

        Slices ``position_ids`` if provided, otherwise builds ``arange(start, end)``
        with shape ``(1, end - start)``.
        """
        if position_ids is not None:
            return position_ids[:, start:end]
        return torch.arange(start, end, device=device).long().unsqueeze(0)

    def pad_attention_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype):
        """Extend a 0/1 mask with the sink and memory-token positions.

        For a 2-D mask, one leading column (if ``use_sink``) and
        ``num_mem_tokens`` trailing columns are added and filled with 1.

        For a 4-D mask, both of the last two axes grow the same way:
          * the sink row attends only to itself;
          * real-token rows do not attend to the memory columns;
          * memory rows attend to every position.

        Returns ``attention_mask`` unchanged when there is neither a sink nor
        memory tokens. Raises ``ValueError`` for other ranks.
        """
        if self.num_mem_tokens in {0, None} and not self.use_sink:
            return attention_mask
        n_sink = int(self.use_sink)
        n_extra = self.num_mem_tokens + n_sink
        shape = list(attention_mask.shape)
        if len(shape) == 4:
            shape[-1] += n_extra
            shape[-2] += n_extra
            mask = torch.ones(*shape, dtype=dtype, device=attention_mask.device)
            # Use explicit end indices: a `-num_mem_tokens` bound becomes an
            # empty slice when num_mem_tokens == 0.
            row_mem_start = shape[-2] - self.num_mem_tokens
            col_mem_start = shape[-1] - self.num_mem_tokens
            mask[..., n_sink:row_mem_start, n_sink:col_mem_start] = attention_mask
            if self.use_sink:
                mask[..., 0, 1:] = 0
            mask[..., :row_mem_start, col_mem_start:] = 0
        elif len(shape) == 2:
            shape[-1] += n_extra
            mask = torch.ones(*shape, dtype=dtype, device=attention_mask.device)
            col_mem_start = shape[-1] - self.num_mem_tokens
            mask[..., n_sink:col_mem_start] = attention_mask
        else:
            raise ValueError("Attention mask must be 2D or 4D")
        return mask.to(dtype)

    def _get_memory_tokens(self, batch_size: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the shared memory (and sink, if enabled) expanded over the batch.

        Returns ``(None, None)`` when there is no memory source or no memory tokens.
        """
        if self._get_memory is None or self.num_mem_tokens == 0:
            return None, None
        memory = self._get_memory()
        sink = self._get_sink() if self.use_sink and self._get_sink is not None else None
        mem = memory.unsqueeze(0).expand(batch_size, -1, -1)
        if sink is not None:
            sink = sink.unsqueeze(0).expand(batch_size, -1, -1)
        return mem, sink

    def _alloc_initial_mem(self, device: torch.device, dtype: torch.dtype):
        """Allocate a zero memory.

        Returns ``W_mem`` of shape ``(1, h, d_key/h, d_model/h)`` and ``z`` of
        shape ``(1, h, d_key/h)``. ``z`` is None when ``use_denom`` is False.
        """
        W_mem = torch.zeros(
            1,
            self.n_heads,
            self.d_key // self.n_heads,
            self.d_model // self.n_heads,
            device=device,
            dtype=dtype,
        )
        z = torch.zeros(1, self.n_heads, self.d_key // self.n_heads, device=device, dtype=dtype) if self.use_denom else None
        return W_mem, z

    def _associate_with_mem(self, hidden_states: torch.Tensor, W_mem: torch.Tensor, z: Optional[torch.Tensor]) -> torch.Tensor:
        """Read from memory: map queries through ``W_mem``.

        Queries are feature-mapped and L2-normalised. When ``use_denom`` is set,
        the result is divided by ``z . q + 1e-5``.
        """
        q = self._to_heads(self.W_mq(hidden_states))
        mq = self.phi(q)
        mq = F.normalize(mq, dim=-1, p=2.0)
        num = torch.einsum("ihjk,ihkt->ihjt", mq, W_mem)
        if self.use_denom and z is not None:
            denom = torch.einsum("ihk,ihjk->ihj", z, mq)[..., None] + 1e-5
            hs = num / denom
        else:
            hs = num
        return self._from_heads(hs)

    def _update_mem_with_mem(
        self,
        mem_tokens: torch.Tensor,
        W_mem: torch.Tensor,
        z: Optional[torch.Tensor],
        first_seg: bool,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], bool]:
        """Write the memory-token outputs into (W_mem, z).

        The value written is the new value projection minus what the memory
        already returns for the same keys (zero on the first segment), scaled
        by a sigmoid gate.

        Returns the updated ``(W_mem, z)`` and ``False``, since later segments
        are no longer the first.
        """
        k = self._to_heads(self.W_mk(mem_tokens))
        mk = self.phi(k)
        mk = F.normalize(mk, dim=-1, p=2.0)

        new_mv = self._to_heads(self.W_mv(mem_tokens))
        if not first_seg:
            num = torch.einsum("ihjk,ihkt->ihjt", mk, W_mem)
            if self.use_denom and z is not None:
                denom = torch.einsum("ihj,ihkj->ihk", z, mk)[..., None] + 1e-5
                prev_mv = num / denom
                if self.correction:
                    # Clipped to [0, 1] and detached so it only scales the z update.
                    new_info_coef = (
                        1 - denom / (torch.linalg.norm(mk, dim=-1) ** 2)[..., None]
                    )
                    new_info_coef = torch.clip(new_info_coef, 0, 1).detach()
                else:
                    new_info_coef = 1
            else:
                prev_mv = num
                new_info_coef = 1
        else:
            prev_mv = torch.zeros_like(new_mv, device=new_mv.device)
            new_info_coef = 1

        mv = new_mv - prev_mv
        mb = self._to_heads(torch.sigmoid(self.W_mb(mem_tokens)))
        # Per-feature gate ('t') when gating, otherwise one gate per head ('x').
        einop = f"ihjk,ihjt,ihj{'t' if self.gating else 'x'}->ihkt"
        associations = torch.einsum(einop, mk, mv, mb)
        W_mem = W_mem + associations
        if self.use_denom and z is not None:
            z = z + (new_info_coef * mk).sum(dim=-2)
        return W_mem, z, False

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
        """
        Convert positional args of the wrapped HF block into keyword args by
        introspecting the block's forward signature, then run
        ``forward_horizontal``. This prevents accidental misplacement (e.g. a
        cache object being treated as attention_mask).
        """
        try:
            params = list(inspect.signature(self.layer.forward).parameters.values())
        except (TypeError, ValueError):
            params = []
        # Variadic parameters (*args / **kwargs) are not real names and must
        # not receive positional values.
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        param_names = [p.name for p in params if p.kind not in variadic]
        if param_names and param_names[0] == "self":
            param_names = param_names[1:]
        # The first remaining slot is the one that receives hidden_states.
        param_names = param_names[1:]
        for name, arg in zip(param_names, args):
            kwargs.setdefault(name, arg)

        # Older blocks use `layer_past`; normalise it to a `past_key_values` cache.
        if "layer_past" in kwargs and "past_key_values" not in kwargs:
            layer_past = kwargs.pop("layer_past")
            try:
                if isinstance(layer_past, DynamicCache):
                    kwargs["past_key_values"] = layer_past
                else:
                    kwargs["past_key_values"] = DynamicCache.from_legacy_cache(layer_past)
            except Exception:
                # Best effort: pass the object through unchanged if it cannot be converted.
                kwargs["past_key_values"] = layer_past

        attention_mask = kwargs.pop("attention_mask", None)
        return self.forward_horizontal(hidden_states, attention_mask, **kwargs)

    def forward_horizontal(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        """Run the wrapped layer segment by segment, reading/writing memory.

        Args:
            hidden_states: ``(bsz, seq_len, d)`` tensor, or a tuple/list whose
                first element is that tensor.
            attention_mask: None or a 4-D additive mask over ``seq_len``.
            **kwargs: forwarded to the wrapped layer. Per-token tensors are
                sliced per segment. ``past_key_values`` and ``past_attn_mask``
                drive sliding-window caching.

        Returns:
            The concatenated segment outputs, wrapped in a tuple shaped like
            the wrapped layer's output when that output was a tuple.
        """
        assert not self.generate_mode, "Generate mode is not supported for horizontal forward"
        assert attention_mask is None or attention_mask.dim() == 4, "Attention mask must be 4D"
        using_cache = not is_empty_past_key_values(kwargs.get("past_key_values"), self.info['layer'])
        assert not using_cache or (kwargs.get('past_attn_mask') is not None and kwargs.get('past_attn_mask').shape[-1] == self.segment_size), "When using cache, past_attn_mask must be provided and have the same length as the segment size"

        if isinstance(hidden_states, (tuple, list)):
            hidden_states = hidden_states[0]
        bsz, seq_len, _ = hidden_states.shape

        if attention_mask is None:
            attention_mask = torch.ones(bsz, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attn_mask_to_4d(attention_mask, upper=False, query_len=seq_len)
            attention_mask = invert_attn_mask(attention_mask, hidden_states.dtype)
        out_full = []

        # Continue from the carried memory, or start a fresh one.
        if self.memory_state is not None:
            W_mem, z = self.memory_state
            first_seg = False
        else:
            W_mem, z = self._alloc_initial_mem(hidden_states.device, hidden_states.dtype)
            first_seg = True

        provided_cache = kwargs.get("past_key_values")
        past_key_values = provided_cache if provided_cache is not None else DynamicCache()
        past_attn_mask = kwargs.get('past_attn_mask') if using_cache else None
        present_kv = None

        n_sink = int(self.use_sink)
        block_len = self.segment_size + self.num_mem_tokens + n_sink
        for start in range(0, seq_len, block_len):
            # Block layout: [sink?] [real tokens: real_start..real_end) [memory tokens].
            real_start = start + n_sink
            real_end = min(real_start + self.segment_size, seq_len - self.num_mem_tokens)
            end = real_end + self.num_mem_tokens
            seg_aug = hidden_states[:, start:end, :]
            seg_len = real_end - real_start

            attn_mask = attention_mask[:, :, real_start:real_end, real_start:real_end]

            # Associative READ for every segment after the first.
            if not first_seg:
                seg_aug = self._associate_with_mem(seg_aug, W_mem, z) + seg_aug

            seg_aug_len = seg_aug.size(1)

            # Pad the real-token mask for sink/memory positions (same in both modes).
            base_cur4d = reverse_invert_attn_mask(attn_mask)
            seg_mask = self.pad_attention_mask(base_cur4d, dtype=seg_aug.dtype)
            seg_mask = invert_attn_mask(seg_mask, seg_aug.dtype)

            if self.sliding_window:
                if past_attn_mask is not None:
                    # Prepend key columns for the previous segment's cached tokens.
                    base_past4d = attn_mask_to_4d(attn_mask_to_2d(past_attn_mask), upper=True, query_len=seg_aug_len)
                    if self.use_sink:
                        base_past4d[:, :, 0, :] = 0
                    base_past4d = invert_attn_mask(base_past4d, seg_aug.dtype)
                    seg_mask = torch.cat([base_past4d, seg_mask], dim=-1)
                if os.environ.get("ARMT_DEBUG_SW"):
                    print(f"[H-SEG] L{self.info['layer']} seg_len={seg_len} seg_aug_len={seg_aug_len} mask={tuple(seg_mask.shape)}")

            seg_pos_ids = self._get_segment_positions(kwargs.get("position_ids", None), start, end, seg_aug.device)

            seg_args = tuple(segment_tensor(a, start, end, seq_len) if isinstance(a, torch.Tensor) else a for a in args)
            seg_kwargs = {k: segment_tensor(v, start, end, seq_len) for k, v in kwargs.items()}

            seg_kwargs["attention_mask"] = seg_mask.to(seg_aug.dtype)
            if seg_pos_ids is not None:
                seg_kwargs["position_ids"] = seg_pos_ids

            if self.sliding_window:
                seg_kwargs["use_cache"] = True
                seg_kwargs["past_key_values"] = past_key_values
            else:
                # Without a sliding window each segment is attended in isolation.
                seg_kwargs.pop("layer_past", None)
                seg_kwargs.pop("cache_position", None)
                seg_kwargs.pop("past_key_values", None)
                seg_kwargs["use_cache"] = False

            if self._rotary_fn is not None and seg_pos_ids is not None:
                cos, sin = self._rotary_fn(seg_aug, seg_pos_ids)
                seg_kwargs["position_embeddings"] = (cos, sin)

            layer_out = self.layer(seg_aug, *seg_args, **seg_kwargs)
            if self.sliding_window:
                assert past_key_values is not None, "Past key values object must be provided"
                if os.environ.get("ARMT_DEBUG_SW"):
                    k = past_key_values.layers[self.info['layer']].keys
                    v = past_key_values.layers[self.info['layer']].values
                    print(f"[H-CACHE:pre] L{self.info['layer']} K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}")
                # Keep the cache object even if this layer left it empty
                # (the trimmer returns None in that case).
                trimmed = self.update_past_key_values_sw(past_key_values, self.segment_size)
                if trimmed is not None:
                    past_key_values = trimmed
                if os.environ.get("ARMT_DEBUG_SW"):
                    k = past_key_values.layers[self.info['layer']].keys
                    v = past_key_values.layers[self.info['layer']].values
                    print(f"[H-CACHE:post] L{self.info['layer']} K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}")

            seg_out = layer_out[0] if isinstance(layer_out, tuple) else layer_out

            # Associative WRITE from the memory-token outputs. With no memory
            # tokens there is nothing to write; `seg_out[:, -0:]` would have
            # selected the whole segment.
            if self.num_mem_tokens > 0:
                seg_mem_out = seg_out[:, -self.num_mem_tokens:, :]
                W_mem, z, first_seg = self._update_mem_with_mem(seg_mem_out, W_mem, z, first_seg)

            out_full.append(seg_out)
            past_attn_mask = attn_mask

        merged = torch.cat(out_full, dim=1)

        # Carry the memory over to the next call.
        self.memory_state = (W_mem, z)

        if isinstance(layer_out, tuple):
            YELLOW = "\033[93m"
            RESET = "\033[0m"
            if len(layer_out) == 1:
                return (merged,)
            elif len(layer_out) == 2:
                warnings.warn(f"{YELLOW}Last attention was not tested for horizontal forward{RESET}")
                return (merged, None)
            elif len(layer_out) == 3:
                warnings.warn(f"{YELLOW}Last attention and kv states were not tested for horizontal forward{RESET}")
                return (merged, None, present_kv)
            else:
                raise ValueError(f"Expected 1, 2 or 3 elements in layer output, got {len(layer_out)}")
        else:
            return merged

    def update_past_key_values_sw(self, past_key_values, window_size):
        """
        Update past key values for sliding window attention.

        Drops the trailing ``num_mem_tokens`` positions of this layer's cached
        keys/values. Of what remains, only the last ``window_size`` positions
        are kept. Returns None if this layer's cache is empty, otherwise the
        same cache object, modified in place.
        """
        layer_idx = self.info['layer']
        if is_empty_past_key_values(past_key_values, layer_idx):
            return None

        # Read the layer directly: converting the whole cache to legacy form
        # (previously done twice) is unnecessary, and `.layers` is already
        # required to write the result back.
        layer_cache = past_key_values.layers[layer_idx]
        k, v = layer_cache.keys, layer_cache.values
        # Explicit bounds: `[-w - m:-m]` becomes an empty slice when m == 0.
        end = max(k.size(-2) - self.num_mem_tokens, 0)
        begin = max(end - window_size, 0)
        layer_cache.keys = k[..., begin:end, :]
        layer_cache.values = v[..., begin:end, :]
        return past_key_values
|
|
|
|
|
|
|
|
class InnerLoopARMTForCausalLM(PreTrainedModel): |
|
|
""" |
|
|
Drop-in ARMT model that installs InnerLoopAssociativeLayerWrapper into a base |
|
|
HF Causal LM. All segmentation happens inside each wrapped layer; no outer |
|
|
recurrent driver is needed. |
|
|
""" |
|
|
|
|
|
|
|
|
config_class = ARMTConfig |
|
|
|
|
|
def __init__(self, config: PretrainedConfig, **kwargs):
    """Build the base causal LM and wrap its layers with inner-loop ARMT memory.

    Exactly one of ``config.base_model_name`` (loaded with ``from_pretrained``)
    or ``config.base_model_config`` (a PretrainedConfig, dict with
    ``model_type``, or a path/name string; built with ``from_config``) must be
    set. Every layer flagged in ``config.wrap_layers`` (default: all) is
    replaced by an ``InnerLoopAssociativeLayerWrapper``.

    Raises:
        ValueError: both or neither base-model options are given, or a dict
            ``base_model_config`` lacks ``model_type``.
        TypeError: ``base_model_config`` has an unsupported type.
        NotImplementedError: ``segment_alignment`` is not ``'left'``.
    """
    global LIGER_KERNEL_AVAILABLE
    super().__init__(config, **kwargs)
    from transformers import AutoConfig, AutoModelForCausalLM

    base_model = None
    bm_cfg = getattr(config, "base_model_config", None)
    bm_name = getattr(config, "base_model_name", None)

    # Liger kernels are only applied when a llama checkpoint is named. Without
    # a name (config-only path) they are disabled rather than crashing on
    # `'llama' in None`.
    if bm_name is None or 'llama' not in bm_name:
        LIGER_KERNEL_AVAILABLE = False
        os.environ["ARMT_DISABLE_LIGER_KERNEL"] = "1"
    if LIGER_KERNEL_AVAILABLE and not os.environ.get("ARMT_DISABLE_LIGER_KERNEL"):
        apply_liger_kernel_to_llama()

    if bm_cfg is not None and bm_name is not None:
        raise ValueError("Exactly one of `base_model_name` or `base_model_config` must be provided in config.")
    if bm_cfg is not None:
        if isinstance(bm_cfg, PretrainedConfig) and getattr(bm_cfg, "model_type", None) != getattr(config, "model_type", None):
            resolved_cfg = bm_cfg
        elif isinstance(bm_cfg, dict):
            if "model_type" not in bm_cfg:
                raise ValueError("`base_model_config` dict must include a 'model_type' key.")
            cfg_or_inst = AutoConfig.for_model(bm_cfg["model_type"])
            if isinstance(cfg_or_inst, PretrainedConfig):
                resolved_cfg = cfg_or_inst
                for k, v in bm_cfg.items():
                    setattr(resolved_cfg, k, v)
            else:
                resolved_cfg = cfg_or_inst.from_dict(bm_cfg)
        elif isinstance(bm_cfg, str):
            resolved_cfg = AutoConfig.from_pretrained(bm_cfg)
        else:
            raise TypeError("`base_model_config` must be a transformers.PretrainedConfig, dict, or str.")
        base_model = AutoModelForCausalLM.from_config(resolved_cfg)
    elif bm_name is not None:
        base_model = AutoModelForCausalLM.from_pretrained(bm_name)
    else:
        raise ValueError("InnerLoopARMTForCausalLM requires either `base_model_config` or `base_model_name` in the config.")

    self.model = base_model

    # ARMT hyper-parameters, read with defaults from the config.
    self.num_mem_tokens = int(getattr(config, "num_mem_tokens", 0) or 0)
    self.d_mem = int(getattr(config, "d_mem", 512))
    self.segment_size = int(getattr(config, "segment_size", 512))
    self.segment_alignment = getattr(config, "segment_alignment", "left")
    if self.segment_alignment != 'left':
        # Was a bare `raise` outside an except block. NotImplementedError
        # subclasses RuntimeError, which is what that bare raise produced.
        raise NotImplementedError(
            f"InnerLoopARMTForCausalLM only supports segment_alignment='left', got {self.segment_alignment!r}"
        )
    self.layers_attr = getattr(config, "layers_attr", "model.layers")
    self.correction = bool(getattr(config, "correction", True))
    self.n_heads = int(getattr(config, "n_heads", 1))
    self.use_denom = bool(getattr(config, "use_denom", True))
    self.gating = bool(getattr(config, "gating", False))
    self.freeze_mem_flag = bool(getattr(config, "freeze_mem", False))
    self.use_sink = bool(getattr(config, "use_sink", False))
    self.sliding_window = bool(getattr(config, "sliding_window", False))

    emb = self.model.get_input_embeddings()
    d_model = emb.embedding_dim
    memory_dim = getattr(self.model.config, "n_embd", getattr(self.model.config, "hidden_size", d_model))

    # Learnable memory tokens (and optional sink), on the embedding's device/dtype.
    memory_weights = torch.empty(
        (self.num_mem_tokens, memory_dim), device=emb.weight.device, dtype=emb.weight.dtype
    )
    torch.nn.init.normal_(memory_weights, mean=0.0, std=0.02)
    self.memory = nn.Parameter(memory_weights, requires_grad=True)
    if self.use_sink:
        self.sink = nn.Parameter(
            torch.randn((1, memory_dim), device=emb.weight.device, dtype=emb.weight.dtype), requires_grad=True
        )

    def _get_layers_from_model(model_root: nn.Module):
        # Follow the dotted `layers_attr` path (e.g. "model.layers") from the root.
        obj = model_root
        for attr in self.layers_attr.split("."):
            obj = getattr(obj, attr)
        return obj

    layers = _get_layers_from_model(self.model)
    # Read like every other option. PretrainedConfig has no `.get` method.
    self.wrap_layers = getattr(config, "wrap_layers", None)
    if self.wrap_layers is None:
        self.wrap_layers = [1,] * len(layers)
    assert len(self.wrap_layers) == len(layers)

    rotary_fn = None
    if hasattr(self.model, "model") and hasattr(self.model.model, "rotary_emb"):
        rotary_fn = self.model.model.rotary_emb
    elif hasattr(self.model, "gpt_neox") and hasattr(self.model.gpt_neox, "rotary_emb"):
        rotary_fn = self.model.gpt_neox.rotary_emb

    for i in range(len(layers)):
        if self.wrap_layers[i]:
            layers[i] = InnerLoopAssociativeLayerWrapper(
                layer=layers[i],
                d_model=d_model,
                num_mem_tokens=self.num_mem_tokens,
                d_mem=self.d_mem,
                segment_size=self.segment_size,
                n_heads=self.n_heads,
                correction=self.correction,
                use_denom=self.use_denom,
                gating=self.gating,
                use_sink=self.use_sink,
                sliding_window=self.sliding_window,
                get_memory_fn=lambda self_ref=self: self_ref.memory,
                get_sink_fn=lambda self_ref=self: getattr(self_ref, "sink", None),
                rotary_fn=rotary_fn,
                info={"layer": i},
            )

    if self.freeze_mem_flag:
        # Layers left unwrapped (wrap_layers[i] == 0) have no memory to freeze.
        for layer in _get_layers_from_model(self.model):
            if hasattr(layer, "freeze_mem"):
                layer.freeze_mem()

    self.get_layers = lambda: _get_layers_from_model(self.model)
    self.vertical_mode = False
|
|
|
|
|
|
|
|
def generate_mode(self, is_on: bool):
    """Set the ``generate_mode`` flag on every ARMT-wrapped layer.

    Unwrapped base-model layers (``wrap_layers[i] == 0``) do not carry the
    flag and are left alone, matching ``zero_mem`` / ``detach_mem``.
    """
    for layer in self.get_layers():
        if hasattr(layer, "generate_mode"):
            layer.generate_mode = is_on
|
|
|
|
|
def zero_mem(self):
    """Reset memory state for all ARMT-wrapped layers.

    Layers left unwrapped (``wrap_layers[i] == 0``) have no memory and are
    skipped; previously they raised AttributeError.
    """
    for layer in self.get_layers():
        if hasattr(layer, "zero_mem"):
            layer.zero_mem()
|
|
|
|
|
def detach_mem(self):
    """Detach memory state for all ARMT-wrapped layers.

    Layers left unwrapped (``wrap_layers[i] == 0``) have no memory and are
    skipped; previously they raised AttributeError.
    """
    for layer in self.get_layers():
        if hasattr(layer, "detach_mem"):
            layer.detach_mem()
|
|
|
|
|
def augment_sequence(self, hidden_states: torch.Tensor, mem: torch.Tensor, sink: torch.Tensor = None):
    """Interleave sink / memory tokens into ``hidden_states`` segment by segment.

    The sequence is cut into ``segment_size`` chunks along dim 1 (the last may
    be shorter). Each chunk becomes ``[sink] + chunk + mem`` (sink only when
    given), with sink/mem cast to the chunk's dtype and device.
    """
    pieces = []
    for chunk in torch.split(hidden_states, self.segment_size, dim=1):
        mem_chunk = mem.to(chunk.dtype).to(chunk.device)
        if sink is None:
            pieces.append(torch.cat([chunk, mem_chunk], dim=1))
        else:
            sink_chunk = sink.to(chunk.dtype).to(chunk.device)
            pieces.append(torch.cat([sink_chunk, chunk, mem_chunk], dim=1))
    return torch.cat(pieces, dim=1)
|
|
|
|
|
def clean_sequence(self, hidden_states: torch.Tensor):
    """Strip the sink and memory positions that ``augment_sequence`` inserted.

    The sequence is cut into blocks of ``segment_size + num_mem_tokens +
    int(use_sink)``. From each block, the leading sink and the trailing
    ``num_mem_tokens`` positions are removed.
    """
    n_sink = int(self.use_sink)
    block_len = self.segment_size + self.num_mem_tokens + n_sink
    kept = []
    for block in torch.split(hidden_states, block_len, dim=1):
        # Explicit end index: `:-num_mem_tokens` is `:0` (empty) when
        # num_mem_tokens == 0.
        mem_start = max(block.size(1) - self.num_mem_tokens, 0)
        kept.append(block[:, n_sink:mem_start])
    return torch.cat(kept, dim=1)
|
|
|
|
|
def augment_attention_mask(self, attention_mask: torch.Tensor):
    """Extend a 2-D attention mask with ones for the inserted sink/memory slots.

    Uses the same ``segment_size`` chunking as ``augment_sequence``. Each chunk
    gets a leading 1 when ``use_sink`` is set and ``num_mem_tokens`` trailing
    ones, all in the chunk's dtype and device.
    """
    out = []
    for chunk in torch.split(attention_mask, self.segment_size, dim=1):
        batch = chunk.shape[0]
        parts = [
            chunk,
            torch.ones(batch, self.num_mem_tokens, device=chunk.device, dtype=chunk.dtype),
        ]
        if self.use_sink:
            parts.insert(0, torch.ones(batch, 1, device=chunk.device, dtype=chunk.dtype))
        out.append(torch.cat(parts, dim=1))
    return torch.cat(out, dim=1)
|
|
|
|
|
def augment_labels(self, labels):
    """Insert ``-100`` (ignored) labels at the sink/memory slots.

    The first label is kept as-is. ``labels[:, 1:]`` is chunked by
    ``segment_size``, and each chunk gets a leading -100 (when ``use_sink``)
    and ``num_mem_tokens`` trailing -100s. The one-position offset presumably
    lines up with a shifted next-token loss.
    NOTE(review): the label chunks are offset by one from the hidden-state
    chunks in ``augment_sequence``. For some lengths the two augmented
    sequences may differ in size — confirm against the loss computation.
    Returns None when ``labels`` is None.
    """
    if labels is None:
        return None

    def _ignored(ref, width):
        return -100 * torch.ones(ref.shape[0], width, device=ref.device, dtype=ref.dtype)

    chunks = []
    for chunk in torch.split(labels[:, 1:], self.segment_size, dim=1):
        parts = [chunk, _ignored(chunk, self.num_mem_tokens)]
        if self.use_sink:
            parts.insert(0, _ignored(chunk, 1))
        chunks.append(torch.cat(parts, dim=1))
    return torch.cat([labels[:, :1], torch.cat(chunks, dim=1)], dim=1)
|
|
|
|
|
def augment(self, input_ids, inputs_embeds, attention_mask, labels):
    """Embed the inputs and insert sink/memory slots into states, mask and labels.

    Exactly one of ``input_ids`` (embedded with the base model's input
    embeddings) or ``inputs_embeds`` must be given.

    Returns:
        ``(augmented_hidden_states, augmented_attention_mask, augmented_labels)``.
    """
    if input_ids is None and inputs_embeds is None:
        raise ValueError("Either input_ids or inputs_embeds must be provided")
    if input_ids is not None:
        assert inputs_embeds is None, "input_ids and inputs_embeds cannot be provided together"
        hidden_states = self.model.get_input_embeddings()(input_ids)
    else:
        hidden_states = inputs_embeds

    batch = hidden_states.size(0)
    mem = self.memory.unsqueeze(0).expand(batch, -1, -1)
    sink = self.sink.unsqueeze(0).expand(batch, -1, -1) if self.use_sink else None

    return (
        self.augment_sequence(hidden_states, mem, sink),
        self.augment_attention_mask(attention_mask),
        self.augment_labels(labels),
    )
|
|
|
|
|
def forward(
    self,
    input_ids=None,
    labels=None,
    labels_mask=None,
    inputs_embeds=None,
    attention_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    output_only_last_segment=False,
    num_items_in_batch=None,
    use_cache=None,
    past_key_values=None,
):
    """Prepare labels and mask, then dispatch to vertical or horizontal mode.

    Labels at positions where ``labels_mask`` is false become -100. A missing
    ``attention_mask`` defaults to all ones, shaped and typed like
    ``input_ids`` (or ``inputs_embeds``' first two dims). The call then goes to
    ``forward_vertical`` (with ``past_attn_mask=None``) when
    ``self.vertical_mode`` is set, else to ``forward_horizontal``.
    """
    if labels_mask is not None:
        assert labels_mask.any(), "labels_mask must not be all zeros"

    effective_labels = labels
    if labels is not None and labels_mask is not None:
        if not isinstance(labels_mask, torch.Tensor):
            raise ValueError("labels_mask must be a torch.Tensor")
        effective_labels = labels.masked_fill(~labels_mask.to(torch.bool), -100)

    if attention_mask is None:
        reference = input_ids if input_ids is not None else inputs_embeds
        attention_mask = torch.ones(
            reference.shape[0], reference.shape[1], device=reference.device, dtype=reference.dtype
        )

    common = dict(
        input_ids=input_ids,
        labels=effective_labels,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_only_last_segment=output_only_last_segment,
        num_items_in_batch=num_items_in_batch,
        use_cache=use_cache,
        past_key_values=past_key_values,
    )
    if self.vertical_mode:
        return self.forward_vertical(**common, past_attn_mask=None)
    return self.forward_horizontal(**common)
|
|
def forward_vertical( |
|
|
self, |
|
|
input_ids=None, |
|
|
labels=None, |
|
|
inputs_embeds=None, |
|
|
attention_mask=None, |
|
|
output_attentions=None, |
|
|
output_hidden_states=None, |
|
|
output_only_last_segment=False, |
|
|
num_items_in_batch=None, |
|
|
use_cache=None, |
|
|
past_key_values=None, |
|
|
past_attn_mask=None, |
|
|
): |
|
|
assert not self.training or os.environ.get("ARMT_DISABLE_LIGER_KERNEL"), "Liger kernel is not supported for training in vertical mode, to disable liger kernel, set ARMT_DISABLE_LIGER_KERNEL=1" |
|
|
|
|
|
if input_ids is not None: |
|
|
assert inputs_embeds is None |
|
|
B, L = input_ids.shape |
|
|
device = input_ids.device |
|
|
elif inputs_embeds is not None: |
|
|
B, L, _ = inputs_embeds.shape |
|
|
device = inputs_embeds.device |
|
|
else: |
|
|
raise ValueError("Either input_ids or inputs_embeds must be provided") |
|
|
dtype = next(self.model.parameters()).dtype |
|
|
|
|
|
augmented_hidden_states, augmented_attention_mask, augmented_labels = self.augment(input_ids, inputs_embeds, attention_mask, labels) |
|
|
|
|
|
|
|
|
def split_tensor(tensor: torch.Tensor, segment_size: int): |
|
|
return torch.split(tensor, segment_size+self.num_mem_tokens+int(self.use_sink), dim=1) |
|
|
|
|
|
|
|
|
|
|
|
seg_inputs_embeds = split_tensor(augmented_hidden_states, self.segment_size) |
|
|
seg_attention_mask = split_tensor(augmented_attention_mask, self.segment_size) if attention_mask is not None else None |
|
|
seg_labels = split_tensor(augmented_labels, self.segment_size) if labels is not None else None |
|
|
|
|
|
num_segments = len(seg_inputs_embeds) |
|
|
segments = [] |
|
|
for i in range(num_segments): |
|
|
segments.append({ |
|
|
"inputs_embeds": seg_inputs_embeds[i], |
|
|
"attention_mask": None if seg_attention_mask is None else seg_attention_mask[i], |
|
|
"labels": None if seg_labels is None else seg_labels[i], |
|
|
}) |
|
|
|
|
|
|
|
|
use_sliding = bool(self.sliding_window) |
|
|
shared_cache = past_key_values if (use_sliding and past_key_values is not None) else (DynamicCache() if use_sliding else None) |
|
|
past_attn_mask = past_attn_mask if use_sliding else None |
|
|
|
|
|
pos_offset = 0 |
|
|
|
|
|
|
|
|
seg_outputs = [] |
|
|
layers = self.get_layers() |
|
|
for seg in segments: |
|
|
seg_len = seg["inputs_embeds"].size(1) |
|
|
if seg.get("attention_mask") is None: |
|
|
base_2d = torch.ones(B, seg_len, device=device, dtype=dtype) |
|
|
else: |
|
|
base_2d = seg["attention_mask"] |
|
|
cur4d = attn_mask_to_4d(base_2d, upper=False, query_len=seg_len) |
|
|
cur4d = invert_attn_mask(cur4d, dtype=dtype) |
|
|
|
|
|
|
|
|
position_ids = torch.arange(pos_offset, pos_offset + seg_len, device=device).long().unsqueeze(0) |
|
|
|
|
|
|
|
|
orig_forwards = [ly.forward for ly in layers] |
|
|
seg_past_attn_mask = past_attn_mask |
|
|
def _inject_mask(orig_fn, mask): |
|
|
def _wrapped(hs, *a, **k): |
|
|
|
|
|
if mask is not None: |
|
|
if 'past_attn_mask' not in k: |
|
|
k['past_attn_mask'] = mask |
|
|
|
|
|
if 'past_key_values' not in k or k['past_key_values'] is None: |
|
|
k['past_key_values'] = shared_cache |
|
|
|
|
|
if hasattr(k['past_key_values'], 'layers') and len(k['past_key_values'].layers) < len(layers): |
|
|
|
|
|
needed = len(layers) - len(k['past_key_values'].layers) |
|
|
k['past_key_values'].layers.extend([type(k['past_key_values'].layers[0])() for _ in range(needed)]) |
|
|
k['use_cache'] = True |
|
|
return orig_fn(hs, *a, **k) |
|
|
return _wrapped |
|
|
for i, ly in enumerate(layers): |
|
|
ly.forward = _inject_mask(orig_forwards[i], seg_past_attn_mask) |
|
|
|
|
|
out = self.model( |
|
|
input_ids=seg.get("input_ids"), |
|
|
inputs_embeds=seg.get("inputs_embeds"), |
|
|
attention_mask=cur4d, |
|
|
position_ids=position_ids, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
use_cache=use_sliding, |
|
|
past_key_values=shared_cache if use_sliding else None, |
|
|
) |
|
|
if os.environ.get("ARMT_DEBUG_SW"): |
|
|
print(f"[V-SEG] seg_len={seg_len} cur4d={tuple(cur4d.shape)} pos=({int(position_ids[0,0])},{int(position_ids[0,-1])})") |
|
|
if hasattr(out, 'past_key_values') and out.past_key_values is not None: |
|
|
try: |
|
|
k = out.past_key_values.layers[0].keys |
|
|
v = out.past_key_values.layers[0].values |
|
|
print(f"[V-CACHE:out] L0 K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
for i, ly in enumerate(layers): |
|
|
ly.forward = orig_forwards[i] |
|
|
seg_outputs.append(out) |
|
|
|
|
|
if use_sliding: |
|
|
|
|
|
shared_cache = out.past_key_values if hasattr(out, 'past_key_values') else shared_cache |
|
|
if os.environ.get("ARMT_DEBUG_SW") and shared_cache is not None: |
|
|
try: |
|
|
k = shared_cache.layers[0].keys |
|
|
v = shared_cache.layers[0].values |
|
|
print(f"[V-CACHE:posttrim] L0 K={tuple(k.shape) if k is not None else None} V={tuple(v.shape) if v is not None else None}") |
|
|
except Exception: |
|
|
pass |
|
|
past_attn_mask = cur4d[:, :, int(self.use_sink):-self.num_mem_tokens, int(self.use_sink):-self.num_mem_tokens] |
|
|
pos_offset += seg_len |
|
|
|
|
|
|
|
|
|
|
|
full_logits = torch.cat([o.logits for o in seg_outputs], dim=1) if len(seg_outputs) > 1 else seg_outputs[0].logits |
|
|
|
|
|
result = {} |
|
|
result["logits"] = self.clean_sequence(full_logits) |
|
|
|
|
|
|
|
|
if labels is not None: |
|
|
labels = labels[:, -full_logits.size(1):] |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
flat_labels = shift_labels.view(-1) |
|
|
|
|
|
if labels_mask is not None: |
|
|
labels_mask = labels_mask[:, -full_logits.size(1):] |
|
|
shift_mask = labels_mask[..., :-1].contiguous() |
|
|
else: |
|
|
shift_mask = None |
|
|
|
|
|
shift_logits = full_logits[..., :-1, :].contiguous() |
|
|
flat_logits = shift_logits.view(-1, shift_logits.size(-1)) |
|
|
if shift_mask is not None: |
|
|
flat_logits = flat_logits[shift_mask.view(-1)] |
|
|
flat_labels = flat_labels[shift_mask.view(-1)] |
|
|
loss_fct = CrossEntropyLoss(reduction='sum') |
|
|
loss = loss_fct(flat_logits, flat_labels) |
|
|
|
|
|
if labels_mask is not None: |
|
|
denom = labels_mask[..., :-1].contiguous().view(-1).sum() |
|
|
else: |
|
|
denom = (flat_labels != -100).sum() |
|
|
denom = torch.clamp(denom, min=1) |
|
|
result["loss"] = loss / denom |
|
|
|
|
|
if output_hidden_states: |
|
|
if all(getattr(o, 'hidden_states', None) is not None for o in seg_outputs): |
|
|
|
|
|
full_hidden_states = tuple([ |
|
|
torch.cat(layer_hs, dim=1) |
|
|
for layer_hs in zip(*[o.hidden_states for o in seg_outputs]) |
|
|
]) |
|
|
result["hidden_states"] = full_hidden_states |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
def forward_horizontal( |
|
|
self, |
|
|
input_ids=None, |
|
|
labels=None, |
|
|
inputs_embeds=None, |
|
|
attention_mask=None, |
|
|
output_attentions=None, |
|
|
output_hidden_states=None, |
|
|
output_only_last_segment=False, |
|
|
num_items_in_batch=None, |
|
|
use_cache=None, |
|
|
past_key_values=None, |
|
|
): |
|
|
augmented_hidden_states, augmented_attention_mask, augmented_labels = self.augment(input_ids, inputs_embeds, attention_mask, labels) |
|
|
out = self.model( |
|
|
labels=augmented_labels, |
|
|
inputs_embeds=augmented_hidden_states, |
|
|
attention_mask=augmented_attention_mask, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
use_cache=use_cache, |
|
|
past_key_values=past_key_values, |
|
|
) |
|
|
if not LIGER_KERNEL_AVAILABLE: |
|
|
out.logits = self.clean_sequence(out.logits) |
|
|
self.zero_mem() |
|
|
return out |
|
|
|
|
|
def generate(self, input_ids, attention_mask=None, **generate_kwargs): |
|
|
""" |
|
|
Generate tokens using the inner-loop model with proper sliding window attention. |
|
|
This method should produce the same logits as the forward method for alignment. |
|
|
""" |
|
|
|
|
|
warnings.warn("Efficient generation is not implemented") |
|
|
if self.sliding_window: |
|
|
return self._generate_inefficient(input_ids, attention_mask, **generate_kwargs) |
|
|
else: |
|
|
|
|
|
return self._generate_inefficient(input_ids, attention_mask, **generate_kwargs) |
|
|
|
|
|
|
|
|
def _generate_standard(self, input_ids, attention_mask=None, **generate_kwargs): |
|
|
"""Standard generation without sliding window.""" |
|
|
generate_kwargs['output_scores'] = generate_kwargs.get('return_logits', False) |
|
|
generate_kwargs['return_dict_in_generate'] = generate_kwargs.get('return_logits', False) |
|
|
generate_kwargs.pop('return_logits') |
|
|
out = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) |
|
|
if generate_kwargs.get('output_scores', False): |
|
|
print(out.scores) |
|
|
return out.sequences, out.scores |
|
|
else: |
|
|
return out.sequences |
|
|
|
|
|
def _generate_inefficient(self, input_ids, attention_mask=None, **generate_kwargs): |
|
|
""" |
|
|
Generate tokens using sliding window attention that matches the forward method. |
|
|
This ensures alignment between generate and forward methods. |
|
|
INEFFICIENT: recomputes the entire sequence on every token generation. |
|
|
Kept for reference and testing purposes. |
|
|
""" |
|
|
max_new_tokens = generate_kwargs.get('max_new_tokens', 1) |
|
|
eos_token_id = generate_kwargs.get('eos_token_id', None) |
|
|
return_logits = generate_kwargs.get('return_logits', False) |
|
|
|
|
|
generated_ids = None |
|
|
all_logits = [] |
|
|
|
|
|
|
|
|
for i in range(max_new_tokens): |
|
|
|
|
|
if generated_ids is not None: |
|
|
current_input_ids = torch.cat([input_ids, generated_ids], dim=-1) |
|
|
current_attention_mask = torch.cat([attention_mask, torch.ones_like(generated_ids)], dim=-1) |
|
|
else: |
|
|
current_input_ids = input_ids |
|
|
current_attention_mask = attention_mask |
|
|
|
|
|
|
|
|
|
|
|
self.zero_mem() |
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = self.forward( |
|
|
input_ids=current_input_ids, |
|
|
attention_mask=current_attention_mask |
|
|
) |
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
|
|
|
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) |
|
|
|
|
|
if generated_ids is not None: |
|
|
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) |
|
|
else: |
|
|
generated_ids = next_token_id |
|
|
|
|
|
|
|
|
if return_logits: |
|
|
all_logits.append(next_token_logits) |
|
|
|
|
|
|
|
|
if eos_token_id is not None and (next_token_id == eos_token_id).all(): |
|
|
break |
|
|
|
|
|
if return_logits: |
|
|
|
|
|
return generated_ids, torch.stack(all_logits, dim=1) |
|
|
else: |
|
|
return generated_ids |
|
|
|
|
|
def _generate_sliding_window(self, input_ids, attention_mask=None, **generate_kwargs): |
|
|
""" |
|
|
Generate tokens using sliding window attention with efficient caching. |
|
|
Uses the base model directly with past_key_values to avoid recomputing the entire sequence. |
|
|
This method should produce the same logits as the forward method for alignment. |
|
|
""" |
|
|
self.generate_mode(True) |
|
|
try: |
|
|
max_new_tokens = generate_kwargs.get('max_new_tokens', 1) |
|
|
eos_token_id = generate_kwargs.get('eos_token_id', None) |
|
|
return_logits = generate_kwargs.get('return_logits', False) |
|
|
|
|
|
|
|
|
self.zero_mem() |
|
|
|
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones_like(input_ids) |
|
|
|
|
|
|
|
|
initial_outputs = self.forward( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask |
|
|
) |
|
|
|
|
|
|
|
|
next_token_logits = initial_outputs.logits[:, -1, :] |
|
|
|
|
|
generated_ids = None |
|
|
all_logits = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
base_model = self.model |
|
|
window_size = self.segment_size + self.num_mem_tokens + int(self.use_sink) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
base_outputs = base_model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
use_cache=True |
|
|
) |
|
|
past_key_values = base_outputs.past_key_values |
|
|
|
|
|
|
|
|
for i in range(max_new_tokens): |
|
|
|
|
|
next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) |
|
|
|
|
|
if generated_ids is not None: |
|
|
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) |
|
|
else: |
|
|
generated_ids = next_token_id |
|
|
|
|
|
|
|
|
if return_logits: |
|
|
all_logits.append(next_token_logits) |
|
|
|
|
|
|
|
|
if eos_token_id is not None and (next_token_id == eos_token_id).all(): |
|
|
break |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
next_outputs = base_model( |
|
|
input_ids=next_token_id, |
|
|
attention_mask=torch.ones_like(next_token_id), |
|
|
past_key_values=past_key_values, |
|
|
use_cache=True |
|
|
) |
|
|
next_token_logits = next_outputs.logits[:, -1, :] |
|
|
past_key_values = next_outputs.past_key_values |
|
|
|
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_values = self.update_past_key_values_sw(past_key_values, window_size) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"Error implementing efficient generation: {e}") |
|
|
print("This suggests the base model doesn't support the expected interface") |
|
|
print("Why could this happen?") |
|
|
print("1. The base model might not support past_key_values") |
|
|
print("2. The attention mask handling might be incompatible") |
|
|
print("3. The memory tokens might interfere with caching") |
|
|
print("4. The inner loop wrapper might not be compatible with base model caching") |
|
|
raise RuntimeError(f"Efficient generation failed: {e}") |
|
|
|
|
|
if return_logits: |
|
|
return generated_ids, torch.stack(all_logits, dim=1) |
|
|
else: |
|
|
return generated_ids |
|
|
finally: |
|
|
self.generate_mode(False) |
|
|
|
|
|
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): |
|
|
try: |
|
|
return super().load_state_dict(state_dict, strict, assign) |
|
|
except RuntimeError: |
|
|
|
|
|
self.model.load_state_dict(state_dict, strict=True) |
|
|
return |
|
|
|
|
|
def zero_mem(self): |
|
|
for layer in self.get_layers(): |
|
|
layer.zero_mem() |
|
|
|
|
|
def detach_mem(self): |
|
|
for layer in self.get_layers(): |
|
|
layer.detach_mem() |
|
|
|
|
|
def freeze_mem(self): |
|
|
for layer in self.get_layers(): |
|
|
layer.freeze_mem() |
|
|
|
|
|
|