| |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from fla.ops.titans.log_impl import combine_params_log |
|
|
|
|
def cal_n(theta, eta, seq_len):
    """Build the upper-triangular momentum-coefficient matrix ``n`` for one chunk.

    For every leading index the trailing ``(seq_len, seq_len)`` matrix is::

        n[..., i, i] = theta[..., i]
        n[..., i, j] = theta[..., i] * prod_{k=i+1}^{j} eta[..., k]   for j > i
        n[..., i, j] = 0                                               for j < i

    Under the recurrence ``S_t = eta_t * S_{t-1} - theta_t * u_t`` this is the
    weight with which the step-``i`` update is still present in ``S`` after
    step ``j``.

    Args:
        theta: Tensor whose last axis has length ``seq_len``.
        eta: Gate tensor; assumed to have the same shape as ``theta``.
        seq_len: Number of time steps in the chunk.

    Returns:
        Tensor of shape ``(*theta.shape, seq_len)``. Its dtype follows
        ``torch.where`` type promotion of ``theta`` and ``eta``.
    """
    device = theta.device
    pos = torch.arange(seq_len, device=device)

    # Zeros everywhere, theta on the diagonal.
    n = torch.zeros(*theta.shape, seq_len, dtype=theta.dtype, device=device)
    n[..., pos, pos] = theta[..., pos]

    # True strictly above the diagonal (j > i).
    above_diag = torch.ones(seq_len, seq_len, device=device).triu(diagonal=1).bool()

    # Row i holds eta_j for j > i and 1 elsewhere, so a running product along
    # the last axis gives prod_{k=i+1}^{j} eta_k for every j > i.
    eta_rows = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len)
    factors = torch.where(above_diag, eta_rows, torch.ones_like(eta_rows))
    decay = torch.cumprod(factors, dim=-1)

    # Scale row i by theta_i above the diagonal; keep the diagonal/zeros from n.
    theta_cols = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len)
    return torch.where(above_diag, theta_cols * decay, n)
|
|
|
|
def cal_f(beta, seq_len, m):
    """Compute ``f[..., i] = beta_i * sum_{j <= i} m_j / beta_j``.

    Under the recurrences ``M_t = (1 - alpha_t) M_{t-1} + S_t`` and
    ``S_t = eta_t S_{t-1} - ...``, with ``beta`` the cumulative product of
    ``1 - alpha`` and ``m`` the cumulative product of ``eta``, this is the
    coefficient of the chunk-start momentum ``S_0`` in the memory after step ``i``.

    Args:
        beta: Tensor whose last axis indexes time steps.
        seq_len: Chunk length (number of columns of the lower-triangular weights).
        m: Tensor broadcastable against ``beta``.

    Returns:
        Tensor shaped like ``beta``. It is computed in float32 and cast back to
        ``beta.dtype``.
    """
    beta_f32 = beta.to(torch.float32)
    # weights[..., i, j] = beta_i for j <= i, 0 otherwise.
    weights = beta_f32.unsqueeze(-1).expand(*beta.shape, seq_len).tril()
    ratio = (m.to(torch.float32) / beta_f32).unsqueeze(-1)
    f = (weights @ ratio).squeeze(-1)
    return f.to(beta.dtype)
|
|
|
|
def cal_G(beta, n, seq_len):
    """Compute the intra-chunk update coefficients ``G``.

    ``G[..., i, j] = sum_{k=j}^{i} (beta_i / beta_k) * n[..., j, k]`` for
    ``j <= i``. The entry is 0 when ``j > i`` because the summation window is
    empty.

    The computation broadcasts a ``(seq_len, seq_len, seq_len)`` boolean window,
    so intermediate memory grows cubically with ``seq_len``.

    Args:
        beta: Tensor whose last axis has length ``seq_len``.
        n: Tensor of shape ``(*beta.shape, seq_len)``, e.g. the output of ``cal_n``.
        seq_len: Number of time steps in the chunk.

    Returns:
        Tensor of shape ``(*beta.shape, seq_len)``.
    """
    pos = torch.arange(seq_len, device=beta.device)
    rows = pos.view(-1, 1, 1)   # i
    cols = pos.view(1, -1, 1)   # j
    inner = pos.view(1, 1, -1)  # k
    # window[i, j, k] is True exactly when j <= k <= i.
    window = (inner >= cols) & (inner <= rows)

    # ratio[..., i, k] = beta_i / beta_k
    ratio = beta.unsqueeze(-1) / beta.unsqueeze(-2)

    weighted_ratio = ratio.unsqueeze(-2) * window  # [..., i, j, k] = beta_i / beta_k
    weighted_n = n.unsqueeze(-3) * window          # [..., i, j, k] = n[j, k]
    return (weighted_ratio * weighted_n).sum(dim=-1)
|
|
|
|
def combine_params(theta, alpha, eta, seq_len):
    """Compute the closed-form per-chunk coefficients of the Titans recurrence.

    For ``M_t = (1 - alpha_t) M_{t-1} + S_t`` and
    ``S_t = eta_t S_{t-1} - theta_t u_t`` unrolled over one chunk:

    * ``beta_i = prod_{k<=i} (1 - alpha_k)`` is the decay of the chunk-start memory.
    * ``m_i = prod_{k<=i} eta_k`` is the decay of the chunk-start momentum.
    * ``n``, ``f`` and ``G`` come from ``cal_n``, ``cal_f`` and ``cal_G``.

    Args:
        theta: Learning-rate gate, presumably shaped ``(..., seq_len, 1)``.
            The trailing axis is squeezed; ``squeeze(-1)`` is a no-op if it is
            not size 1.
        alpha: Forgetting gate, same layout as ``theta``.
        eta: Momentum decay gate, same layout as ``theta``.
        seq_len: Number of time steps in the chunk.

    Returns:
        Tuple ``(beta, beta_T, f, f_T, g, G, m_T, n_T)``:

        * ``beta``: cumulative product of ``1 - alpha`` along time.
        * ``beta_T``: ``beta`` at the last step.
        * ``f``: coefficient of the chunk-start momentum in each step's memory.
        * ``f_T``: ``f`` at the last step.
        * ``g``: last row of ``G``, i.e. ``G[..., -1, :]``.
        * ``G``: intra-chunk update coefficients.
        * ``m_T``: product of ``eta`` over the chunk.
        * ``n_T``: last column of ``n``, i.e. ``n[..., -1]``.
    """
    theta = theta.squeeze(-1)
    eta = eta.squeeze(-1)
    alpha = alpha.squeeze(-1)

    beta = torch.cumprod(1 - alpha, dim=-1)
    beta_T = beta[..., -1]

    m = torch.cumprod(eta, dim=-1)
    m_T = m[..., -1]

    n = cal_n(theta, eta, seq_len)
    n_T = n[..., -1]

    f = cal_f(beta, seq_len, m)
    f_T = f[..., -1]

    G = cal_G(beta, n, seq_len)
    # Last row of G: weight of each step's update in the end-of-chunk memory.
    # `...` indexing works for any number of leading axes. It is identical to
    # the former G[:, :, -1, :] for 4-D gates, which broke for other ranks.
    g = G[..., -1, :]

    return beta, beta_T, f, f_T, g, G, m_T, n_T
|
|
|
|
def titans_linear(
    q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
    """Sequential (one step at a time) reference implementation of Titans Linear.

    Update rules::

        M_t = (1 - alpha_t) * M_{t-1} + S_t
        S_t = eta_t * S_{t-1} - theta_t * nabla_l(M'; x_t)

    The gradient is evaluated against ``M'``, the memory at the start of the
    current chunk. ``M'`` is refreshed only every ``chunk_size`` steps. This
    mirrors ``chunk_titans_linear``, which computes ``k_i @ M_prev`` once per
    chunk.

    Args:
        q: Query tensor of shape (B, H, T, D).
        k: Key tensor; assumed (B, H, T, D) like ``q``.
        v: Value tensor; assumed (B, H, T, D) like ``q``.
        w: LayerNorm weight with H * D elements (reshaped to (H, 1, D)).
        b: LayerNorm bias with H * D elements (reshaped to (H, 1, D)).
        theta: Learning-rate gate, indexed as ``theta[:, :, t:t+1, :]``;
            presumably (B, H, T, 1).
        alpha: Forgetting gate that decays the memory ``M`` (same layout as ``theta``).
        eta: Momentum decay that scales the previous ``S`` (same layout as ``theta``).
        eps: Epsilon added to the variance in both LayerNorms.
        chunk_size: Number of steps between refreshes of the gradient state ``M'``.
        initial_state: Initial memory M_0, expected (B, H, D, D) like the zero
            default, or None.
        output_final_state: Whether to return the final memory state.

    Returns:
        Tuple ``(output, final_state)``.

        * ``output`` has shape (B, H, T, D). It is empty along T when T == 0.
        * ``final_state`` is the memory after the last step, or None when
          ``output_final_state`` is False.
    """
    B, H, T, D = q.shape
    device = q.device
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    if initial_state is None:
        M_prev = torch.zeros(B, H, D, D, device=device)
    else:
        M_prev = initial_state
    # Memory the gradient is taken against; refreshed once per chunk.
    M_prev_nabla = M_prev.clone()
    S_prev = torch.zeros_like(M_prev)
    outputs = []

    for t in range(T):
        q_t = q[:, :, t : t + 1, :]
        k_t = k[:, :, t : t + 1, :]
        v_t = v[:, :, t : t + 1, :]
        theta_t = theta[:, :, t : t + 1, :]
        alpha_t = alpha[:, :, t : t + 1, :]
        eta_t = eta[:, :, t : t + 1, :]

        # Inner LayerNorm over the last axis. Despite its name, `rstd` holds the
        # standard deviation sqrt(var + eps), and km_hat divides by it.
        km = k_t @ M_prev_nabla
        reconstruction_target = v_t - k_t
        mean = km.mean(-1, keepdim=True)
        var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        km_hat = (km - mean) / rstd

        # Residual of the normalized reconstruction, scaled by w. This is
        # presumably the gradient of a squared loss w.r.t. km_hat, up to the
        # factor 2 applied in the S_t update below.
        grad = w * km_hat + b - reconstruction_target
        grad = grad * w

        # Back-propagate through the LayerNorm.
        # NOTE(review): by operator precedence only the sum and projection terms
        # are divided by (rstd * D); the D * grad term is not. The textbook
        # LayerNorm backward divides all three. chunk_titans_linear uses the same
        # expression, so the two stay consistent -- confirm intent before changing.
        v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D)
        proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D)
        v_new = v_new - proj_term

        # Momentum update: k_t^T @ v_new is the outer-product gradient w.r.t. M.
        S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new
        # Memory update with forgetting gate alpha_t.
        M_t = (1 - alpha_t) * M_prev + S_t

        # Read out with the updated memory, then apply a residual LayerNorm.
        output_t = q_t @ M_t
        mean = output_t.mean(dim=-1, keepdim=True)
        var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        output_t = output_t + (output_t - mean) / rstd * w + b
        outputs.append(output_t)

        if (t + 1) % chunk_size == 0:
            M_prev_nabla = M_t.clone()
        M_prev = M_t
        S_prev = S_t

    if outputs:
        # Each step produced a (B, H, 1, D) slice; join them along time.
        output = torch.cat(outputs, dim=-2)
    else:
        # T == 0: torch.cat/torch.stack reject an empty list, so build the
        # empty (B, H, 0, D) result directly.
        output = q.new_empty((B, H, 0, D))

    return output, (M_prev if output_final_state else None)
|
|
|
|
def chunk_titans_linear(
    q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state
):
    """Chunked implementation of Titans Linear (same recurrence as ``titans_linear``).

    The time axis is split into ``T // chunk_size`` chunks. Inside a chunk every
    gradient is taken against the chunk-start memory ``M_prev``. The whole chunk
    is therefore evaluated with dense matmuls, using per-step coefficients from
    ``combine_params_log``. The state ``(M, S)`` then advances once per chunk.

    Args:
        q: Query tensor of shape (B, H, T, D).
        k: Key tensor, reshaped like ``q``.
        v: Value tensor, reshaped like ``q``.
            T must be a multiple of ``chunk_size``; the reshape into chunks
            fails otherwise.
        w: LayerNorm weight with H * D elements (reshaped to (H, 1, D)).
        b: LayerNorm bias with H * D elements (reshaped to (H, 1, D)).
        theta: Learning-rate gate with B * H * T elements, reshaped to
            (B, H, n_chunks, chunk_size, 1); presumably (B, H, T, 1).
        alpha: Forgetting gate, same layout as ``theta``.
        eta: Momentum decay gate, same layout as ``theta``.
        eps: Epsilon added to the variance in both LayerNorms.
        chunk_size: Number of time steps per chunk.
        initial_state: Initial memory M_0, or None for float32 zeros of shape
            (B, H, D, D).
        output_final_state: Whether to return the final memory state.

    Returns:
        Tuple ``(o, final_state)``.

        * ``o`` has shape (B, H, T, D).
        * ``final_state`` is the memory after the last chunk, or None when
          ``output_final_state`` is False.
    """
    B, H, T, D = q.shape
    n_chunks = T // chunk_size

    def to_chunks(x, width):
        # (B, H, T, width) -> (n_chunks, B, H, chunk_size, width), chunk axis first.
        return x.reshape(B, H, n_chunks, chunk_size, width).permute(2, 0, 1, 3, 4)

    q_chunks, k_chunks, v_chunks = (to_chunks(x, D) for x in (q, k, v))
    eta_chunks, theta_chunks, alpha_chunks = (
        to_chunks(x, 1) for x in (eta, theta, alpha)
    )

    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    if initial_state is not None:
        M_prev = initial_state
    else:
        M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=v.dtype).to(
            torch.float32
        )
    S_prev = torch.zeros_like(M_prev)

    o = torch.empty_like(v_chunks)

    for c in range(n_chunks):
        q_i, k_i, v_i = q_chunks[c], k_chunks[c], v_chunks[c]
        eta_i, theta_i, alpha_i = eta_chunks[c], theta_chunks[c], alpha_chunks[c]

        # NOTE(review): assumed to return the same tuple as combine_params in
        # this file, computed in log space -- confirm against
        # fla.ops.titans.log_impl.
        beta, beta_T, f, f_T, g, G, m_T, n_T = combine_params_log(
            theta_i, alpha_i, eta_i, chunk_size
        )

        # Whole-chunk scalars get trailing (1, 1) axes to broadcast over (D, D) states.
        m_T, beta_T, f_T = (x.unsqueeze(-1).unsqueeze(-1) for x in (m_T, beta_T, f_T))
        # Per-step coefficients as diagonal matrices: diag(c) @ X scales row i by c_i.
        g_diag, n_diag, beta_diag, f_diag = (
            torch.diag_embed(x).to(q_i.dtype) for x in (g, n_T, beta, f)
        )

        # Inner LayerNorm of k @ M_prev. `km_std` is sqrt(var + eps).
        km = k_i @ M_prev
        target = v_i - k_i
        km_mean = km.mean(-1, True)
        km_var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        km_std = torch.sqrt(km_var + eps).to(torch.float32)
        km_hat = (km - km_mean) / km_std

        # Residual scaled by w, then back-propagated through the LayerNorm with
        # the same expression (and precedence) as titans_linear.
        grad = w * km_hat + b - target
        grad = grad * w
        v_new = D * grad - grad.sum(-1, keepdim=True) / (km_std * D)
        proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (km_std * D)
        v_new = v_new - proj_term

        # Causal intra-chunk attention, weighted by the update coefficients G.
        attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G

        # Per-step output terms:
        #   decayed chunk-start memory, plus carried momentum,
        #   minus the intra-chunk updates.
        out = beta_diag @ q_i @ M_prev + f_diag @ q_i @ S_prev - 2 * attn @ v_new

        # End-of-chunk states.
        M_t = (
            beta_T * M_prev
            + f_T * S_prev
            - 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new
        )
        S_t = m_T * S_prev - 2 * (n_diag @ k_i).transpose(-1, -2) @ v_new

        # Residual output LayerNorm.
        out_mean = out.mean(dim=-1, keepdim=True)
        out_var = out.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        out_std = torch.sqrt(out_var + eps).to(torch.float32)
        o[c] = out + (out - out_mean) / out_std * w + b

        M_prev, S_prev = M_t, S_t

    # (n_chunks, B, H, chunk_size, D) -> (B, H, T, D)
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    return o, (M_prev if output_final_state else None)
|
|
|
|
| |
def chunk_titans_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    theta: torch.Tensor,
    alpha: torch.Tensor,
    eta: torch.Tensor,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = True,
    use_chunk: bool = True,
):
    """Reference entry point for Titans Linear.

    Normalizes the layout, pads the sequence to a multiple of ``chunk_size``,
    runs either ``chunk_titans_linear`` or ``titans_linear``, and trims the
    output back to the original length.

    Args:
        q: Query tensor, (B, H, T, D) when ``head_first``, else (B, T, H, D).
        k: Key tensor, same layout as ``q``.
        v: Value tensor, same layout as ``q``.
        w: LayerNorm weight with H * D elements.
        b: LayerNorm bias with H * D elements.
        theta: Learning-rate gate; 4-D, presumably (B, H, T, 1), or
            (B, T, H, 1) when not ``head_first``.
        alpha: Forgetting gate, same layout as ``theta``.
        eta: Momentum decay gate, same layout as ``theta``.
        eps: Epsilon for the LayerNorms.
        chunk_size: Chunk length used for padding and by the implementation.
        initial_state: Optional initial memory, passed through unchanged
            (it is not transposed).
        output_final_state: Whether to return the final memory state.
        head_first: Layout flag for q/k/v/theta/alpha/eta.
        use_chunk: True selects ``chunk_titans_linear``; False selects
            ``titans_linear``.

    Returns:
        Tuple ``(o, final_state)``. ``o`` is in the input layout with the
        original sequence length. q, k, v, w and b are cast to float32
        before the call.

    Raises:
        AssertionError: If q/k/v dtypes differ or the key and value head
            dims differ.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."

    if not head_first:
        # (B, T, H, ...) -> (B, H, T, ...)
        q, k, v, eta, alpha, theta = (
            x.transpose(1, 2) for x in (q, k, v, eta, alpha, theta)
        )

    seq_len = q.shape[-2]
    # Number of steps needed to reach the next multiple of chunk_size.
    pad_len = -seq_len % chunk_size
    if pad_len > 0:
        time_pad = (0, 0, 0, pad_len)  # zero-pad dim -2 (time) on the right
        q, k, v, theta, alpha, eta = (
            F.pad(x, time_pad) for x in (q, k, v, theta, alpha, eta)
        )
        # The final padded step reuses the gates of the last real step.
        # NOTE(review): padded k/v are zero, so real-position outputs are
        # unaffected in titans_linear's causal recurrence. However, this last
        # padded step applies an extra (1 - alpha) decay / eta carry to the
        # memory, so the returned final_state can differ from the state after
        # seq_len steps -- confirm this is intended.
        last_real = -(pad_len + 1)
        for gate in (theta, alpha, eta):
            gate[:, :, -1, :] = gate[:, :, last_real, :]
    assert q.shape[-2] % chunk_size == 0, "Sequence length should be a multiple of BT."

    q, k, v, w, b = (x.to(torch.float32) for x in (q, k, v, w, b))

    impl = chunk_titans_linear if use_chunk else titans_linear
    o, final_state = impl(
        q,
        k,
        v,
        w,
        b,
        theta,
        alpha,
        eta,
        eps,
        chunk_size,
        initial_state,
        output_final_state,
    )

    # Drop padded steps and restore the caller's layout.
    o = o[:, :, :seq_len, :]
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
|
|