| import torch |
|
|
|
|
def cal_n_log(log_theta, log_eta, seq_len):
    """
    calculate n_{i,j} in log space
    log(n_{i,j}) = log(θ_j) + sum_{k=j+1}^i log(η_k)

    Parameters:
    - log_theta: log(θ); its last dimension is indexed by position j
    - log_eta: log(η); its last dimension is indexed by position k
    - seq_len: sequence length

    Returns:
    - log_n of shape (*log_theta.shape, seq_len) and dtype log_eta.dtype,
      laid out as log_n[..., j, i] = log(n_{i,j}) for j <= i.
      Entries with j > i are not part of the formula and are left at 0.
    """
    log_n = torch.zeros(
        *log_theta.shape, seq_len, dtype=log_eta.dtype, device=log_eta.device
    )
    for j in range(seq_len):
        # Row j holds n_{i,j} for i = j .. seq_len-1. The diagonal is θ_j alone.
        log_n[..., j, j] = log_theta[..., j]
        # A running cumsum of log(η_{j+1}), ..., log(η_i) yields the whole rest
        # of the row in one call. Summing forward (rather than subtracting prefix
        # sums) keeps a -inf from η = 0 exact instead of producing NaN.
        log_n[..., j, j + 1 : seq_len] = log_theta[..., j : j + 1] + torch.cumsum(
            log_eta[..., j + 1 : seq_len], dim=-1
        )
    return log_n
|
|
|
|
def cal_f_log(log_beta, seq_len, log_m):
    """
    cal_f_log(log_beta, seq_len, log_m) -> f

    Compute f_t for each position t < seq_len in log space and exponentiate
    only at the very end:

        log(f_t) = log(sum_{i=1}^t exp(sum_{k=i+1}^t log(1-α_k) + sum_{k=1}^i log(η_k)))

    Parameters:
    - log_beta: running sums of log(1-α) along the last dimension
      (as built by _combine_params_log via cumsum)
    - seq_len: number of positions to fill
    - log_m: running sums of log(η) along the last dimension

    Returns:
    - f, same shape and dtype as log_beta. Any positions at or beyond
      seq_len are left at exp(0) = 1.
    """
    log_f = torch.zeros_like(log_beta)
    for step in range(seq_len):
        # log(β_t / β_i) for every i <= t, i.e. sum_{k=i+1}^t log(1-α_k)
        log_decay = log_beta[..., step : step + 1] - log_beta[..., : step + 1]
        log_f[..., step] = torch.logsumexp(log_decay + log_m[..., : step + 1], dim=-1)
    return log_f.exp()
|
|
|
|
def cal_G_log(log_beta, log_n, seq_len):
    """
    calculate G_{i,j}
    log(G_{i,j}) = log(sum_{k=j}^i exp(log(β_i/β_k) + log(n_{k,j})))

    Parameters:
    - log_beta: log(β) along the last dimension
    - log_n: log(n) laid out as log_n[..., j, k] = log(n_{k,j}); entries
      with k < j are ignored (masked out)
    - seq_len: sequence length

    Returns:
    - G of shape (*log_beta.shape[:-1], seq_len, seq_len) with
      G[..., i, j] = G_{i,j} for j <= i and 0 above the diagonal (j > i).
      Its dtype is the promotion of log_beta's, log_n's and torch's default
      dtype, so float64 inputs are no longer truncated to float32.
    """
    device = log_beta.device
    # Promoting with the default dtype preserves the previous float32 result
    # for float32/half inputs while keeping full precision for float64.
    out_dtype = torch.promote_types(
        torch.promote_types(log_beta.dtype, log_n.dtype), torch.get_default_dtype()
    )
    log_G = torch.full(
        (*log_beta.shape[:-1], seq_len, seq_len),
        float("-inf"),
        dtype=out_dtype,
        device=device,
    )

    # below_diag[j, k] is True when k < j: those n_{k,j} are not in the sum.
    positions = torch.arange(seq_len, device=device)
    below_diag = positions.unsqueeze(-1) > positions

    for i in range(seq_len):
        # log(β_i / β_k) for k = 0..i, shaped (..., 1, i+1) to broadcast over j.
        log_ratio = (log_beta[..., i : i + 1] - log_beta[..., : i + 1]).unsqueeze(-2)
        # terms[..., j, k] = log(β_i/β_k) + log(n_{k,j}) for j, k in 0..i
        terms = log_ratio + log_n[..., : i + 1, : i + 1]
        terms = terms.masked_fill(below_diag[: i + 1, : i + 1], float("-inf"))
        # Row i of G for every j <= i in a single reduction over k.
        log_G[..., i, : i + 1] = torch.logsumexp(terms, dim=-1)

    return torch.exp(log_G)
|
|
|
|
def _combine_params_log(log_theta, log_alpha_complement, log_eta, seq_len):
    """
    Update rule for Titans in log space

    Parameters:
    - log_theta: log(θ)
    - log_alpha_complement: log(1-α)
    - log_eta: log(η)
    - seq_len: sequence length

    Returns (only log_beta is in log space; every other value has already
    been exponentiated):
    - log_beta: cumulative sum of log(1-α) over the last dimension
    - beta_T: exp(log_beta[..., -1])
    - f: f_t for every position, as returned by cal_f_log
    - f_T: f[..., -1]
    - g: G[..., -1, :], the last row of G
    - G: G_{i,j}, as returned by cal_G_log
    - m_T: exp of the sum of log(η) over the last dimension
    - n_T: exp(log_n[..., -1]), i.e. n_{T,j} for every j
    """
    # log(β_t) = sum_{k<=t} log(1-α_k)
    log_beta = torch.cumsum(log_alpha_complement, dim=-1)
    beta_T = torch.exp(log_beta[..., -1])

    # log(m_t) = sum_{k<=t} log(η_k)
    log_m = torch.cumsum(log_eta, dim=-1)
    m_T = torch.exp(log_m[..., -1])

    # cal_n_log stores log_n[..., j, i] = log(n_{i,j}); the last index along
    # the final dimension therefore gives n_{T,j} for every j.
    log_n = cal_n_log(log_theta, log_eta, seq_len)
    n_T = torch.exp(log_n[..., -1])

    f = cal_f_log(log_beta, seq_len, log_m)
    f_T = f[..., -1]

    G = cal_G_log(log_beta, log_n, seq_len)
    g = G[..., -1, :]

    return log_beta, beta_T, f, f_T, g, G, m_T, n_T
|
|
|
|
def combine_params_log(theta, alpha, eta, seq_len):
    """
    log space Titans

    Move θ, α and η to log space and combine them with _combine_params_log.
    Each input is first passed through ``squeeze(-1)``, which drops a
    trailing size-1 dimension if there is one.

    Parameters:
    - theta: θ
    - alpha: α
    - eta: η
    - seq_len: sequence length

    Returns:
    - beta, beta_T, f, f_T, g, G, m_T, n_T, all in linear space
      (beta is exp(log_beta) from _combine_params_log)
    """
    log_theta = torch.log(theta.squeeze(-1))
    # log1p(-α) equals log(1 - α) mathematically but skips forming 1 - α,
    # so small α are not rounded away (in float32, 1 - 1e-8 == 1).
    log_alpha_complement = torch.log1p(-alpha.squeeze(-1))
    log_eta = torch.log(eta.squeeze(-1))

    log_beta, beta_T, f, f_T, g, G, m_T, n_T = _combine_params_log(
        log_theta, log_alpha_complement, log_eta, seq_len
    )

    # Only beta comes back in log space; convert it for the caller.
    beta = torch.exp(log_beta)
    return beta, beta_T, f, f_T, g, G, m_T, n_T
|
|