| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def scan(f, init, xs, out, checkpoint_group=0):
    """Emulate JAX's ``lax.scan``: thread a carry through ``f`` over a sequence.

    Args:
        f: step function called as ``f(carry, x)`` and returning ``(new_carry, y)``.
        init: initial carry value.
        xs: inputs, either a dict mapping names to indexable sequences or a
            list/tuple of indexable sequences. All of them are indexed with the
            same step index ``i``. The number of steps is taken from the first
            entry.
        out: preallocated container; ``out[i]`` is assigned the ``y`` of step ``i``.
        checkpoint_group: if > 0, the steps are split into contiguous segments of
            ``max(1, num_items // checkpoint_group)`` steps. Each segment runs under
            ``torch.utils.checkpoint.checkpoint`` to trade recomputation for
            activation memory. ``0`` disables checkpointing.

    Returns:
        ``(carry, out)``: the final carry and the filled ``out``.
    """
    is_dict = isinstance(xs, dict)

    # Number of steps; empty inputs give zero steps instead of crashing.
    if is_dict:
        num_items = len(next(iter(xs.values()))) if xs else 0
    else:
        num_items = len(xs[0]) if len(xs) > 0 else 0

    def scan_fn(carry, i_start, i_end):
        """Run steps ``i_start`` .. ``i_end - 1`` and return the resulting carry."""
        for i in range(i_start, i_end):
            if is_dict:
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                x = [seq[i] for seq in xs]
            carry, y = f(carry, x)
            out[i] = y
        return carry

    carry = init
    if checkpoint_group > 0:
        # ``import torch`` does not guarantee the checkpoint submodule is loaded.
        from torch.utils.checkpoint import checkpoint

        # Clamp to 1: checkpoint_group > num_items would otherwise give a zero step.
        ckpt_every_n = max(1, num_items // checkpoint_group)
        for k in range(0, num_items, ckpt_every_n):
            carry = checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        carry = scan_fn(carry, 0, num_items)

    return carry, out
|
|
def ln_fwd(x, gamma, beta, eps=1e-6):
    """Apply LayerNorm over the last dimension of ``x``.

    Normalizes with the biased (population) variance plus ``eps`` inside the
    square root, then applies the affine map ``gamma * x_hat + beta``.
    """
    mean = x.mean(dim=-1, keepdim=True)
    centered = x - mean
    denom = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
    return gamma * (centered / denom) + beta
|
|
def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """Gradient w.r.t. ``x`` of ``0.5 * ||LayerNorm(x) - l2_target||^2``.

    Recomputes the LayerNorm forward pass over the last dimension
    (``y = gamma * x_hat + beta``, biased variance, ``eps`` inside the sqrt).
    It then back-propagates the L2 residual ``y - l2_target`` through it in
    closed form.

    Args:
        x: input tensor; normalization is over its last dimension.
        l2_target: target for the L2 loss, broadcastable against ``y``.
        gamma: LayerNorm scale.
        beta: LayerNorm shift.
        eps: numerical-stability constant added to the variance.

    Returns:
        The gradient of the loss with respect to ``x`` (same shape as ``x``).
    """
    dim = x.shape[-1]

    mean = x.mean(dim=-1, keepdim=True)
    std = torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
    x_hat = (x - mean) / std

    # dL/dy is the residual; the affine scale carries it back to x_hat.
    grad_x_hat = (gamma * x_hat + beta - l2_target) * gamma

    # Standard LayerNorm backward: subtract the components along the mean
    # direction and along x_hat, then divide by std.
    sum_g = grad_x_hat.sum(dim=-1, keepdim=True)
    sum_gx = (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)
    return (1.0 / dim) * (dim * grad_x_hat - sum_g - x_hat * sum_gx) / std
|
|
| from torch.autograd import Function |
class MyLinearFunction(Function):
    """Custom autograd implementation of the affine map ``y = x @ W^T + b``.

    ``input`` has shape ``(*, in_features)``: any number of leading batch
    dimensions, including none. ``weight`` is ``(out_features, in_features)``.
    ``bias`` is ``(out_features,)`` or ``None``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Compute ``input @ weight^T (+ bias)``.

        Args:
            ctx: autograd context; stores the tensors needed by ``backward``.
            input: tensor of shape ``(*, in_features)``.
            weight: tensor of shape ``(out_features, in_features)``.
            bias: tensor of shape ``(out_features,)``, or ``None``.

        Returns:
            Tensor of shape ``(*, out_features)``.
        """
        ctx.save_for_backward(input, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(input, weight, bias)``.

        Args:
            grad_output: upstream gradient, same shape as the forward output
                ``(*, out_features)``.

        Returns:
            ``(grad_input, grad_weight, grad_bias)`` with the shapes of
            ``input``, ``weight`` and ``bias``. An entry is ``None`` when that
            input does not require grad (or ``bias`` is ``None``).
        """
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Collapse all leading batch dims so the weight/bias gradients reduce over
        # every position, not just dim 0 (and so 1-D / 3-D inputs work).
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])

        if ctx.needs_input_grad[0]:
            # dL/dx = dL/dy @ W, shape (*, in_features).
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            # dL/dW = sum over positions of dL/dy^T x, shape (out_features, in_features).
            grad_weight = grad_output_2d.t().matmul(input.reshape(-1, input.shape[-1]))
        if bias is not None and ctx.needs_input_grad[2]:
            # dL/db = sum over positions of dL/dy, shape (out_features,).
            grad_bias = grad_output_2d.sum(dim=0)

        return grad_input, grad_weight, grad_bias
|
|
class TTT_Cross_Layer(nn.Module):
    """Per-token test-time-training layer: linear map -> LayerNorm -> residual.

    The per-token fast weights are not stored directly. Each token carries one
    ``logit_dim``-sized coefficient vector per parameter (the ``*_tmp`` entries
    of ``params``). That vector is contracted with the learned basis tensors
    ``weight_linear``, ``weight_ln``, ``bias_linear`` and ``bias_ln`` along
    their last axis (see ``get_weight_per_token``).

    ``config`` must provide ``concept_dim``, ``logit_dim`` and
    ``initializer_range``.
    """

    # Epsilon added to the variance inside the LayerNorm square root.
    ln_eps = 1e-6

    def __init__(self, config):
        super().__init__()
        self.input_size = config.concept_dim
        self.concept_dim = config.concept_dim
        self.logit_dim = config.logit_dim

        # Basis tensors; the trailing ``logit_dim`` axis is mixed by per-token coefficients.
        self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
        self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))

        self.config = config
        self.init_weights()

    def init_params_as_logits(self, batch_size, sequence_length):
        """Return all-ones coefficient tensors of shape ``(batch_size, sequence_length, logit_dim)``.

        One tensor is returned per basis, keyed ``'<name>_tmp'``. Each is created
        on the device and dtype of ``weight_linear``.
        """
        shape = (batch_size, sequence_length, self.logit_dim)
        device, dtype = self.weight_linear.device, self.weight_linear.dtype
        return {
            name: torch.ones(shape, device=device, dtype=dtype)
            for name in ('weight_linear_tmp', 'weight_ln_tmp', 'bias_linear_tmp', 'bias_ln_tmp')
        }

    def init_weights(self):
        """Initialize every basis tensor in place."""
        std = self.config.initializer_range
        nn.init.normal_(self.weight_linear, mean=0.0, std=std)
        nn.init.constant_(self.weight_ln, 1.0 / self.logit_dim)
        nn.init.normal_(self.bias_linear, mean=0.0, std=std / self.logit_dim)
        # bias_ln must be initialized explicitly: torch.empty leaves arbitrary memory.
        nn.init.normal_(self.bias_ln, mean=0.0, std=std / self.logit_dim)

    def get_weight_per_token(self, params):
        """Contract each basis with its per-token coefficients.

        Returns ``(weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp)``
        of shapes ``(b, s, i, o)``, ``(b, s, o)``, ``(b, s, o)`` and ``(b, s, o)``.
        """
        weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
        weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
        bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
        bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])
        return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp

    def _forward_with_stats(self, x, params):
        """Run the per-token layer on ``x``.

        Returns ``(output, z_hat, std, weight_ln_tmp)``; the extra values are what
        ``learn`` needs for its gradients.
        """
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        z = torch.einsum('bsi,bsio->bso', x, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)
        std = torch.sqrt(var + self.ln_eps)
        z_hat = (z - mu) / std
        # Residual connection adds the input back onto the normalized output.
        output = weight_ln_tmp * z_hat + bias_ln_tmp + x
        return output, z_hat, std, weight_ln_tmp

    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        """Take one gradient step on the per-token coefficients toward mapping ``k`` to ``v``.

        The gradients correspond to the loss ``0.5 * ||output - v||^2``, where
        ``output`` is the layer applied to ``k``. Gradients with respect to the
        effective per-token weights are projected back onto the coefficient
        vectors through the corresponding basis tensor.

        Args:
            k: layer input; last dim must match ``concept_dim``.
            v: reconstruction target, same shape as ``k``.
            params: coefficient dict as produced by ``init_params_as_logits``.
            lr_linear: step size for the linear weight/bias coefficients.
            lr_ln: step size for the LayerNorm weight/bias coefficients.

        Returns:
            A new params dict; ``params`` itself is not modified.
        """
        output, z_hat, std, weight_ln_tmp = self._forward_with_stats(k, params)
        error = output - v  # dL/d(output)

        # LayerNorm affine coefficients.
        grad_weight_ln = error * z_hat
        new_weight_ln = params['weight_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.weight_ln, grad_weight_ln)
        grad_bias_ln = error
        new_bias_ln = params['bias_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.bias_ln, grad_bias_ln)

        # NOTE(review): this treats LayerNorm as an element-wise scale by 1/std. It
        # drops the mean/variance coupling terms of the exact LayerNorm backward
        # (compare ``ln_fused_l2_bwd``). Confirm the approximation is intended.
        grad_z = weight_ln_tmp * error / std

        # Linear-map coefficients.
        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_z)
        new_weight_linear = params['weight_linear_tmp'] - lr_linear * torch.einsum(
            'iol,bsio->bsl', self.weight_linear, grad_weight_linear
        )
        new_bias_linear = params['bias_linear_tmp'] - lr_linear * torch.einsum('ol,bso->bsl', self.bias_linear, grad_z)

        return {
            'weight_linear_tmp': new_weight_linear,
            'weight_ln_tmp': new_weight_ln,
            'bias_linear_tmp': new_bias_linear,
            'bias_ln_tmp': new_bias_ln,
        }

    def predict(self, q, params):
        """Apply the per-token layer defined by ``params`` to ``q``."""
        output, _, _, _ = self._forward_with_stats(q, params)
        return output
| |
|
|
|
|
|
|
|
|