| | import torch |
| |
|
| |
|
def cal_n_log(log_theta, log_eta, seq_len):
    """
    Calculate n_{i,j} in log space.

    log(n_{i,j}) = log(θ_j) + sum_{k=j+1}^i log(η_k)

    The value for the pair (i, j) with j <= i is stored at
    ``log_n[..., j, i]`` (j indexes the second-to-last dimension, i the last
    one). Entries with j > i are never written and stay at 0.

    Parameters:
    - log_theta: log(θ), indexed by j along its last dimension
    - log_eta: log(η), indexed by k along its last dimension
    - seq_len: number of positions i, i.e. the size of the new last dimension

    Returns:
    - log_n: tensor of shape (*log_theta.shape, seq_len), created with the
      dtype and device of log_eta
    """
    # Allocate directly on the target device instead of CPU + copy.
    log_n = torch.zeros(
        *log_theta.shape, seq_len, dtype=log_eta.dtype, device=log_eta.device
    )
    # Explicit zero offset for the diagonal (i == j), where the sum is empty.
    no_decay = log_eta.new_zeros((*log_eta.shape[:-1], 1))

    for j in range(seq_len):
        # One running sum per row gives sum_{k=j+1}^i log(η_k) for every
        # i > j at once. Unlike a difference of global prefix sums, this
        # never subtracts, so -inf entries (η = 0) do not turn into NaN.
        decay = torch.cumsum(log_eta[..., j + 1 : seq_len], dim=-1)
        offsets = torch.cat([no_decay, decay], dim=-1)
        log_n[..., j, j:] = log_theta[..., j : j + 1] + offsets

    return log_n
| |
|
| |
|
def cal_f_log(log_beta, seq_len, log_m):
    """
    Compute f_t through a log-sum-exp and return it in linear space.

    log(f_t) = logsumexp_{i <= t}(log_beta_t - log_beta_i + log_m_i)

    When log_beta is the cumulative sum of log(1-α) and log_m the cumulative
    sum of log(η) (which is how _combine_params_log calls this function),
    this equals
    log(sum_{i=1}^t exp(sum_{k=i+1}^t log(1-α_k) + sum_{k=1}^i log(η_k))).

    Parameters:
    - log_beta: log(β), indexed by position along its last dimension
    - seq_len: number of leading positions t to evaluate
    - log_m: log(m), indexed like log_beta

    Returns:
    - f: exp(log_f), same shape as log_beta. Positions t >= seq_len are never
      written, so they hold exp(0) = 1.
    """
    log_f = torch.zeros_like(log_beta)

    for step in range(seq_len):
        history = slice(None, step + 1)
        # Differences are formed before adding log_m so the exponents stay
        # small even when log_beta itself is strongly negative.
        exponents = (
            log_beta[..., step : step + 1]
            - log_beta[..., history]
            + log_m[..., history]
        )
        log_f[..., step] = exponents.logsumexp(dim=-1)

    return log_f.exp()
| |
|
| |
|
def cal_G_log(log_beta, log_n, seq_len):
    """
    Calculate G_{i,j} through a log-sum-exp and return it in linear space.

    log(G_{i,j}) = log(sum_{k=j}^i exp(log(β_i/β_k) + log(n_{k,j})))

    Parameters:
    - log_beta: log(β), indexed by position along its last dimension
    - log_n: log(n); log(n_{k,j}) is read from ``log_n[..., j, k]``, the
      layout produced by cal_n_log
    - seq_len: number of positions; G gets two trailing dimensions of this size

    Returns:
    - G: exp(log_G) of shape (*log_beta.shape[:-1], seq_len, seq_len), with
      the dtype and device of log_beta. Entries with j > i are exp(-inf) = 0.
    """
    # dtype follows log_beta; without it torch.full falls back to the default
    # dtype and silently downcasts float64 inputs.
    log_G = torch.full(
        (*log_beta.shape[:-1], seq_len, seq_len),
        float("-inf"),
        dtype=log_beta.dtype,
        device=log_beta.device,
    )

    for i in range(seq_len):
        for j in range(i + 1):
            # One exponent per intermediate position k = j..i.
            terms = (
                log_beta[..., i : i + 1]
                - log_beta[..., j : i + 1]
                + log_n[..., j, j : i + 1]
            )
            log_G[..., i, j] = torch.logsumexp(terms, dim=-1)

    return torch.exp(log_G)
| |
|
| |
|
def _combine_params_log(log_theta, log_alpha_complement, log_eta, seq_len):
    """
    Update rule for Titans, evaluated in log space.

    Parameters:
    - log_theta: log(θ)
    - log_alpha_complement: log(1-α)
    - log_eta: log(η)
    - seq_len: sequence length

    Returns (in this order):
    - log_beta: cumulative sum of log(1-α) along the last dimension
    - beta_T: exp of the last entry of log_beta
    - f: result of cal_f_log (linear space, not log)
    - f_T: last entry of f
    - g: last row of G, i.e. G[..., -1, :]
    - G: result of cal_G_log (linear space, not log)
    - m_T: exp of the last entry of the cumulative sum of log(η)
    - n_T: exp of log_n[..., -1], where log_n comes from cal_n_log
    """
    # Prefix sums in log space correspond to running products of (1-α) and η.
    log_beta = log_alpha_complement.cumsum(dim=-1)
    log_m = log_eta.cumsum(dim=-1)
    log_n = cal_n_log(log_theta, log_eta, seq_len)

    # Both helpers already exponentiate their result.
    f = cal_f_log(log_beta, seq_len, log_m)
    G = cal_G_log(log_beta, log_n, seq_len)

    # Final-step quantities, taken at the last index.
    beta_T = log_beta[..., -1].exp()
    m_T = log_m[..., -1].exp()
    n_T = log_n[..., -1].exp()
    f_T = f[..., -1]
    g = G[..., -1, :]

    return log_beta, beta_T, f, f_T, g, G, m_T, n_T
| |
|
| |
|
def combine_params_log(theta, alpha, eta, seq_len):
    """
    Log-space Titans parameter combination.

    Takes θ, α and η in linear space, moves them to log space, runs
    _combine_params_log and returns β in linear space again.

    Parameters:
    - theta: θ; a trailing dimension of size 1 is squeezed away
      (presumably shape (..., seq_len, 1) - confirm against the caller)
    - alpha: α, handled like theta
    - eta: η, handled like theta
    - seq_len: sequence length

    Returns:
    - beta, beta_T, f, f_T, g, G, m_T, n_T (all in linear space)
    """
    log_theta = torch.log(theta.squeeze(-1))
    # log1p(-α) rather than log(1 - α): for small α, 1 - α rounds to 1 and
    # the plain logarithm would collapse to exactly 0.
    log_alpha_complement = torch.log1p(-alpha.squeeze(-1))
    log_eta = torch.log(eta.squeeze(-1))

    log_beta, beta_T, f, f_T, g, G, m_T, n_T = _combine_params_log(
        log_theta, log_alpha_complement, log_eta, seq_len
    )

    # Only β is still in log space at this point.
    beta = torch.exp(log_beta)

    return beta, beta_T, f, f_T, g, G, m_T, n_T
| |
|