Upload folder using huggingface_hub
Browse files- Qformer.py +1 -1
- basis_functions.py +266 -0
- long_term_attention_gibbs.py +315 -0
Qformer.py
CHANGED
|
@@ -41,7 +41,7 @@ from transformers.utils import logging
|
|
| 41 |
from transformers.models.bert.configuration_bert import BertConfig
|
| 42 |
|
| 43 |
from functools import partial
|
| 44 |
-
from .
|
| 45 |
|
| 46 |
logger = logging.get_logger(__name__)
|
| 47 |
|
|
|
|
| 41 |
from transformers.models.bert.configuration_bert import BertConfig
|
| 42 |
|
| 43 |
from functools import partial
|
| 44 |
+
from .long_term_attention_gibbs import LongTermAttention
|
| 45 |
|
| 46 |
logger = logging.get_logger(__name__)
|
| 47 |
|
basis_functions.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BasisFunctions(object):
    """Abstract interface for a family of basis functions psi_1..psi_N.

    Subclasses must implement ``__len__``, ``evaluate`` and the three
    closed-form integrals used by continuous attention. The base class
    raises ``NotImplementedError`` instead of silently returning ``None``
    (the original ``pass`` bodies) so a missing override fails loudly at
    the call site rather than producing ``None`` downstream.
    """

    def __len__(self):
        """Number of basis functions."""
        raise NotImplementedError

    def evaluate(self, t):
        """Evaluate every basis function at position(s) t."""
        raise NotImplementedError

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        raise NotImplementedError

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        raise NotImplementedError

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        raise NotImplementedError
class PowerBasisFunctions(BasisFunctions):
|
| 30 |
+
"""Function phi(t) = t**degree."""
|
| 31 |
+
def __init__(self, degree):
|
| 32 |
+
self.degree = degree.unsqueeze(0)
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
"""Number of basis functions."""
|
| 36 |
+
return self.degree.size(1)
|
| 37 |
+
|
| 38 |
+
def evaluate(self, t):
|
| 39 |
+
return t**self.degree
|
| 40 |
+
|
| 41 |
+
def integrate_t2_times_psi(self, a, b):
|
| 42 |
+
"""Compute integral int_a^b (t**2) * psi(t)."""
|
| 43 |
+
return (b**(self.degree + 3) - a**(self.degree + 3)) / (self.degree + 3)
|
| 44 |
+
|
| 45 |
+
def integrate_t_times_psi(self, a, b):
|
| 46 |
+
"""Compute integral int_a^b t * psi(t)."""
|
| 47 |
+
return (b**(self.degree + 2) - a**(self.degree + 2)) / (self.degree + 2)
|
| 48 |
+
|
| 49 |
+
def integrate_psi(self, a, b):
|
| 50 |
+
"""Compute integral int_a^b psi(t)."""
|
| 51 |
+
return (b**(self.degree + 1) - a**(self.degree + 1)) / (self.degree + 1)
|
| 52 |
+
|
| 53 |
+
def __repr__(self):
|
| 54 |
+
return f"PowerBasisFunction(degree={self.degree})"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SineBasisFunctions(BasisFunctions):
|
| 58 |
+
"""Function phi(t) = sin(omega*t)."""
|
| 59 |
+
def __init__(self, omega):
|
| 60 |
+
self.omega = omega.unsqueeze(0)
|
| 61 |
+
|
| 62 |
+
def __repr__(self):
|
| 63 |
+
return f"SineBasisFunction(omega={self.omega})"
|
| 64 |
+
|
| 65 |
+
def __len__(self):
|
| 66 |
+
"""Number of basis functions."""
|
| 67 |
+
return self.omega.size(1)
|
| 68 |
+
|
| 69 |
+
def evaluate(self, t):
|
| 70 |
+
return torch.sin(self.omega*t)
|
| 71 |
+
|
| 72 |
+
def integrate_t2_times_psi(self, a, b):
|
| 73 |
+
"""Compute integral int_a^b (t**2) * psi(t)."""
|
| 74 |
+
# The antiderivative of (t**2)*sin(omega*t) is
|
| 75 |
+
# ((2-(t**2)*(omega**2))*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3. # noqa
|
| 76 |
+
return ((2-(b**2)*(self.omega**2))*torch.cos(self.omega*b)
|
| 77 |
+
+ 2*self.omega*b*torch.sin(self.omega*b)
|
| 78 |
+
- (2-(a**2)*(self.omega**2))*torch.cos(self.omega*a)
|
| 79 |
+
- 2*self.omega*a*torch.sin(self.omega*a)
|
| 80 |
+
) / (self.omega**3)
|
| 81 |
+
|
| 82 |
+
def integrate_t_times_psi(self, a, b):
|
| 83 |
+
"""Compute integral int_a^b t * psi(t)."""
|
| 84 |
+
# The antiderivative of t*sin(omega*t) is
|
| 85 |
+
# (sin(omega*t) - omega*t*cos(omega*t)) / omega**2.
|
| 86 |
+
return (torch.sin(self.omega*b) - self.omega*b*torch.cos(self.omega*b)
|
| 87 |
+
- torch.sin(self.omega*a) + self.omega*a*torch.cos(self.omega*a)
|
| 88 |
+
) / (self.omega**2)
|
| 89 |
+
|
| 90 |
+
def integrate_psi(self, a, b):
|
| 91 |
+
"""Compute integral int_a^b psi(t)."""
|
| 92 |
+
# The antiderivative of sin(omega*t) is -cos(omega*t)/omega.
|
| 93 |
+
return (-torch.cos(self.omega*b) + torch.cos(self.omega*a)) / self.omega
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class CosineBasisFunctions(BasisFunctions):
    """Function phi(t) = cos(omega*t)."""
    def __init__(self, omega):
        # Shape (1, num_basis): one cosine per frequency; the leading
        # singleton dimension lets evaluations broadcast over t.
        self.omega = omega.unsqueeze(0)

    def __repr__(self):
        return f"CosineBasisFunction(omega={self.omega})"

    def __len__(self):
        """Number of basis functions."""
        return self.omega.size(1)

    def evaluate(self, t):
        """Evaluate cos(omega*t) for every frequency (broadcasts over t)."""
        return torch.cos(self.omega*t)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        # The antiderivative of (t**2)*cos(omega*t) is
        # (((t**2)*(omega**2)-2)*sin(omega*t) + 2*omega*t*cos(omega*t)) / omega**3, # noqa
        # which the expression below evaluates at b and at a.
        # (Comment fixed: an earlier note wrote cos for the first term,
        # but the code correctly uses sin.)
        return (((b**2)*(self.omega**2)-2)*torch.sin(self.omega*b)
                + 2*self.omega*b*torch.cos(self.omega*b)
                - ((a**2)*(self.omega**2)-2)*torch.sin(self.omega*a)
                - 2*self.omega*a*torch.cos(self.omega*a)
                ) / (self.omega**3)

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        # The antiderivative of t*cos(omega*t) is
        # (cos(omega*t) + omega*t*sin(omega*t)) / omega**2.
        return (torch.cos(self.omega*b) + self.omega*b*torch.sin(self.omega*b)
                - torch.cos(self.omega*a) - self.omega*a*torch.sin(self.omega*a)
                ) / (self.omega**2)

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        # The antiderivative of cos(omega*t) is sin(omega*t)/omega.
        return (torch.sin(self.omega*b) - torch.sin(self.omega*a)) / self.omega
class GaussianBasisFunctions(BasisFunctions):
|
| 136 |
+
"""Function phi(t) = Gaussian(t; mu, sigma_sq)."""
|
| 137 |
+
def __init__(self, mu, sigma):
|
| 138 |
+
self.mu = mu.unsqueeze(0)
|
| 139 |
+
self.sigma = sigma.unsqueeze(0)
|
| 140 |
+
|
| 141 |
+
def __repr__(self):
|
| 142 |
+
return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"
|
| 143 |
+
|
| 144 |
+
def __len__(self):
|
| 145 |
+
"""Number of basis functions."""
|
| 146 |
+
return self.mu.size(1)
|
| 147 |
+
|
| 148 |
+
def _phi(self, t):
|
| 149 |
+
return 1. / math.sqrt(2 * math.pi) * torch.exp(-.5 * t**2)
|
| 150 |
+
|
| 151 |
+
def _Phi(self, t):
|
| 152 |
+
return .5 * (1 + torch.erf(t / math.sqrt(2)))
|
| 153 |
+
|
| 154 |
+
def _integrate_product_of_gaussians(self, mu, sigma_sq):
|
| 155 |
+
sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
|
| 156 |
+
return self._phi((mu - self.mu) / sigma) / sigma
|
| 157 |
+
|
| 158 |
+
def evaluate(self, t):
|
| 159 |
+
return self._phi((t - self.mu) / self.sigma) / self.sigma
|
| 160 |
+
|
| 161 |
+
def batch_evaluate(self, t):
|
| 162 |
+
t_ = t.repeat(self.mu.size(0),1) - self.mu.repeat(t.size(0),1).transpose(1,0)
|
| 163 |
+
t_ = t_ / self.sigma.repeat((t.size(0),1)).transpose(1,0)
|
| 164 |
+
return (self._phi(t_) / self.sigma.repeat((t.size(0),1)).transpose(1,0)).transpose(0,1)
|
| 165 |
+
|
| 166 |
+
def integrate_t2_times_psi(self, a, b):
|
| 167 |
+
"""Compute integral int_a^b (t**2) * psi(t)."""
|
| 168 |
+
return (self.mu**2 + self.sigma**2) * (
|
| 169 |
+
self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
|
| 170 |
+
) - (
|
| 171 |
+
self.sigma * (b + self.mu) * self._phi((b - self.mu) / self.sigma)
|
| 172 |
+
) + (
|
| 173 |
+
self.sigma * (a + self.mu) * self._phi((a - self.mu) / self.sigma)
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def integrate_t_times_psi(self, a, b):
|
| 177 |
+
"""Compute integral int_a^b t * psi(t)."""
|
| 178 |
+
return self.mu * (
|
| 179 |
+
self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
|
| 180 |
+
) - self.sigma * (
|
| 181 |
+
self._phi((b - self.mu) / self.sigma) - self._phi((a - self.mu) / self.sigma)
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def integrate_psi(self, a, b):
|
| 185 |
+
"""Compute integral int_a^b psi(t)."""
|
| 186 |
+
return self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
|
| 187 |
+
|
| 188 |
+
def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
|
| 189 |
+
"""Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
|
| 190 |
+
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
|
| 191 |
+
mu_tilde = (
|
| 192 |
+
self.mu * sigma_sq + mu * self.sigma ** 2
|
| 193 |
+
) / (
|
| 194 |
+
self.sigma ** 2 + sigma_sq
|
| 195 |
+
)
|
| 196 |
+
sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / (self.sigma ** 2 + sigma_sq)
|
| 197 |
+
return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)
|
| 198 |
+
|
| 199 |
+
def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
|
| 200 |
+
"""Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
|
| 201 |
+
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
|
| 202 |
+
mu_tilde = (
|
| 203 |
+
self.mu * sigma_sq + mu * self.sigma ** 2
|
| 204 |
+
) / (
|
| 205 |
+
self.sigma ** 2 + sigma_sq
|
| 206 |
+
)
|
| 207 |
+
return S_tilde * mu_tilde
|
| 208 |
+
|
| 209 |
+
def integrate_psi_gaussian(self, mu, sigma_sq):
|
| 210 |
+
"""Compute integral int N(t; mu, sigma_sq) * psi(t)."""
|
| 211 |
+
return self._integrate_product_of_gaussians(mu, sigma_sq)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class RetangularBasisFunctions(BasisFunctions):
|
| 215 |
+
"""Function phi(t) = Gaussian(t; mu, sigma_sq)."""
|
| 216 |
+
def __init__(self, mu, sigma):
|
| 217 |
+
self.mu = mu.unsqueeze(0)
|
| 218 |
+
self.width = sigma.unsqueeze(0)
|
| 219 |
+
|
| 220 |
+
def __repr__(self):
|
| 221 |
+
return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"
|
| 222 |
+
|
| 223 |
+
def __len__(self):
|
| 224 |
+
"""Number of basis functions."""
|
| 225 |
+
return self.mu.size(1)
|
| 226 |
+
|
| 227 |
+
def batch_evaluate(self, t):
|
| 228 |
+
"""
|
| 229 |
+
Evaluate multiple time points against all rectangular basis functions.
|
| 230 |
+
Args:
|
| 231 |
+
t: Tensor of time values to evaluate, shape (num_points,).
|
| 232 |
+
Returns:
|
| 233 |
+
Tensor of evaluations, shape (num_basis, num_points).
|
| 234 |
+
"""
|
| 235 |
+
t = t.repeat(self.mu.size(0),1) # Shape: (1, num_points)
|
| 236 |
+
mu = self.mu.repeat(t.size(0),1).transpose(1,0) # Shape: (num_basis, 1)
|
| 237 |
+
width = self.width.repeat(t.size(0),1).transpose(1,0) # Shape: (num_basis, 1)
|
| 238 |
+
return ((t >= (mu - width / 2)) & (t < (mu + width / 2))).float().transpose(0,1)
|
| 239 |
+
|
| 240 |
+
def _Phi(self, t):
|
| 241 |
+
"""
|
| 242 |
+
Compute the step function for a single value of t.
|
| 243 |
+
Args:
|
| 244 |
+
t: A scalar or tensor of time values.
|
| 245 |
+
Returns:
|
| 246 |
+
Tensor of values indicating presence in each basis function's range.
|
| 247 |
+
"""
|
| 248 |
+
lower_bounds = self.mu - self.width / 2
|
| 249 |
+
upper_bounds = self.mu + self.width / 2
|
| 250 |
+
return ((t >= lower_bounds) & (t < upper_bounds)).float()
|
| 251 |
+
|
| 252 |
+
def evaluate(self, t):
|
| 253 |
+
"""
|
| 254 |
+
Evaluate the rectangular basis functions at a single point or array of points.
|
| 255 |
+
Args:
|
| 256 |
+
t: A scalar or 1D tensor of time values.
|
| 257 |
+
Returns:
|
| 258 |
+
Tensor of shape (num_basis,) for scalar input, or (num_basis, num_points) for tensor input.
|
| 259 |
+
"""
|
| 260 |
+
if t.ndim == 0: # Scalar input
|
| 261 |
+
return self._Phi(t)
|
| 262 |
+
else: # Tensor input
|
| 263 |
+
# Shape: (1, num_points)
|
| 264 |
+
lower_bounds = (self.mu - self.width / 2) # Shape: (num_basis, 1)
|
| 265 |
+
upper_bounds = (self.mu + self.width / 2) # Shape: (num_basis, 1)
|
| 266 |
+
return ((t >= lower_bounds) & (t < upper_bounds)).float()
|
long_term_attention_gibbs.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
"""
|
| 3 |
+
Attention modules
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.distributions as dist
|
| 8 |
+
|
| 9 |
+
from .basis_functions import (
|
| 10 |
+
PowerBasisFunctions,
|
| 11 |
+
SineBasisFunctions,
|
| 12 |
+
CosineBasisFunctions,
|
| 13 |
+
GaussianBasisFunctions,
|
| 14 |
+
RetangularBasisFunctions
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LongTermAttention(nn.Module):
    """Continuous long-term attention over a compressed memory.

    The incoming key sequence is approximated by a continuous function
    expressed in a basis of rectangular bins; attention is computed as an
    expectation of the value function under a density over [0, 1].
    Instance state (`B_past`, `x_past`, and the attributes set in
    `forward`) persists across calls so the memory can be extended
    ("infinite memory") by re-sampling the previous approximation.

    NOTE(review): several constructor arguments (attn_drop, n_layers,
    mask, mask_type) are accepted but never used by the visible code.
    """
    def __init__(self, head_size:int , length: int, target_len:int, attn_func: str, attn_num_basis: int,
                 continuous: bool, attn_drop: float, infinite_memory: bool, n_layers: int,
                 n_heads: int, affines: bool, mask: bool, mask_type: str, kl_regularizer: bool, proj_key, proj_value, sigma_0, mu_0, sticky_memories, sigmas, tau, **kwargs):

        super(LongTermAttention, self).__init__()

        # Hard-coded default device; forward() overwrites this with the
        # device of the incoming keys, so 'cuda' only matters before the
        # first forward call.
        self.device = 'cuda'
        self.length = length #memory length
        self.target_len = target_len #target length / transformer length
        self.head_size = head_size
        self.attn_num_basis = attn_num_basis
        self.continuous = continuous # whether attention over memory vectors is continuous
        self.attn_func = attn_func # normalizing function
        self.n_head = n_heads
        self.sigmas = sigmas
        self.kl_regularizer = kl_regularizer
        self.sigma_0 = sigma_0
        self.mu_0 = mu_0
        # Externally-constructed projection modules applied to the basis
        # coefficients B in forward().
        self.proj_key = proj_key
        self.proj_value = proj_value

        self.affines=affines # whether mu, sigma should be computed using affine transformations


        self.sticky_memories=sticky_memories

        self.mem_threshold=2048
        self.infinite_memory = infinite_memory # whether the memory is infinite

        self.nb_samples=512 # number of samples used for update
        self.tau = tau #compressing factor
        self.count = 0

        self.x_past=None # previous memory vectors
        self.B_past=None # previous coefficient matrix

        self.ridge_penalty=0.5 # ridge penalty
        self.padding = True

        self.spacing='linear'

    def get_basis(self, length, target_len):
        """Build the rectangular basis, the ridge-regression operators Gs
        (one per admissible memory length), and the fixed sample matrix
        used by the infinite-memory update."""
        def compute_G(l, psi, positions, padding=True):
            # Ridge-regression operator mapping l observed vectors to N
            # basis coefficients: G = F^T (F F^T + lambda I)^{-1}, with
            # F[n, p] = psi_n(position_p).

            F = torch.zeros(self.attn_num_basis, positions.size(0))

            basis_functions = psi
            F[:, :] = basis_functions.evaluate(positions.unsqueeze(1)).t()

            I = torch.eye(self.attn_num_basis)
            G = F.t().matmul((F.matmul(F.t()) + self.ridge_penalty * I).inverse())

            if padding:
                # Drop the rows corresponding to the padded positions
                # outside [0, 1], keeping the central l rows.
                if l % 2:
                    G = G[((l-1)//2):(-(l-1)//2), :]
                else:
                    G = G[(l//2):-(l//2), :]

            return G.to(self.device)
        padding = self.padding
        attn_num_basis = self.attn_num_basis
        if self.continuous:

            self.psi=[None]
            self.Gs=[None for _ in range(length+1)]
            lengths=[]
            for i in range(length):
                self.psi.append([])
                # Precompute a G for every multiple of target_len up to length.
                if (i+1)%target_len==0:
                    lengths.append(i+1)
            if length not in lengths:
                lengths.append(length)
            for l in lengths:
                # get positions for memory vectors
                self.add_retangular_basis_functions(self.psi[l], attn_num_basis, device=self.device)

                if self.spacing=='linear':
                    if padding:
                        # Positions extend beyond [0, 1] so the regression is
                        # well-conditioned at the boundaries; the extra rows
                        # are cut off again inside compute_G.
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)
                    else:
                        shift = 1 / float(2*l)
                        positions = torch.linspace(shift, 1-shift, l).to(self.device)
                elif self.spacing=='log':
                    if padding:
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)

                    # Log-spaced interior positions; NOTE(review): if padding
                    # is False this branch reads `positions` before assigning
                    # it — latent bug, unreachable with the default padding=True.
                    pos = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1
                    positions = torch.cat([positions[:int(l/2)],pos.to(self.device),positions[-int(l/2):]])

                else:
                    positions = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1

                # compute basis functions
                self.Gs[l]=compute_G(l, self.psi[l][0], positions, padding=padding) # [L,N]
                self.positions = positions[int(l/2):-int(l/2)]

                # compute samples for memory update
                if self.infinite_memory:
                    tm_tau = torch.arange(1,self.nb_samples+1).float()
                    tm_l = torch.arange(self.nb_samples+1,length+self.nb_samples+1).float()
                    tm_tau = tm_tau*self.tau/self.nb_samples # positions of old vectors
                    tm_l = self.tau + (1-self.tau)*(tm_l-self.nb_samples)/length # positions of new vectors
                    positions_inf = torch.cat([tm_tau, tm_l],0).to(self.device) # positions

                    if padding:
                        if l % 2:
                            shift = 1 / float(length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)-1).to(self.device)
                        else:
                            shift = 1 / float(2*length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)).to(self.device)
                        # Keep only the padding positions that fall outside
                        # [0, 1] and wrap them around the real positions.
                        positions_pad_ = torch.FloatTensor([i for i in positions_pad if i<0]).to(self.device)
                        positions_pad__ = torch.FloatTensor([i for i in positions_pad if i>1]).to(self.device)
                        positions_inf = torch.cat([positions_pad_,positions_inf,positions_pad__], dim=0)

                    # Evaluate the basis at the contracted old positions
                    # t/tau; rows are stacked into [nb_samples, N].
                    self.samples=None
                    for t in tm_tau:
                        if self.samples is None:
                            self.samples = self.psi[l][0].evaluate(t/self.tau)
                        else:
                            self.samples = torch.cat([self.samples,self.psi[l][0].evaluate(t/self.tau)], dim=0)

                    # compute G for the infinite case
                    self.G_inf = compute_G(self.nb_samples+length, self.psi[l][0], positions_inf, padding=padding) #[L+nb_samples,N]

                    if self.sticky_memories:
                        # Histogram bins over [0, 1] used to resample past
                        # positions proportionally to their attention mass.
                        self.bins = torch.linspace(0,1,129).to(device=self.device) #self.positions
                        self.nb_bins_cat=1
                        self.bins_cat = dist.Categorical(torch.ones(self.nb_bins_cat))

    def add_gaussian_basis_functions(self, psi, nb_basis, sigmas, device):
        """Append a GaussianBasisFunctions object (grid of mu x sigma) to psi."""
        # NOTE(review): requires nb_basis to be divisible by len(sigmas);
        # the assert below fails otherwise.
        mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)), torch.Tensor(sigmas))
        mu = mu.flatten().to(device)
        sigma = sigma.flatten().to(device)
        self.basis_mu=mu
        self.basis_sigma=sigma
        assert mu.size(0) == nb_basis
        psi.append(GaussianBasisFunctions(mu=mu, sigma=sigma))

    def add_retangular_basis_functions(self, psi, nb_basis, device):
        """Append a RetangularBasisFunctions object with nb_basis equal-width
        bins partitioning [0, 1] to psi."""
        width = torch.ones(nb_basis, device=device) / nb_basis

        # Compute the centers (midpoints) of each bin
        edges = torch.linspace(0, 1, nb_basis + 1, device=device)
        mu = (edges[:-1] + edges[1:]) / 2
        psi.append(RetangularBasisFunctions(mu=mu, sigma=width))

    def value_function(self, x, inf=False):
        """Project observed vectors x onto basis coefficients B via the
        precomputed ridge-regression operator (G_inf when extending the
        infinite memory, otherwise the operator matching x's length)."""
        if inf:
            G = self.G_inf # [nb_sample+L,N]
        else:
            G = self.Gs[x.size(-1)] # [L,N]
        B = torch.matmul(x, G) # [B,e,N]
        B = B.permute(0,2,1) # [B,N,e]

        return B

    def update_inf(self, x):
        """Extend the continuous memory: reconstruct sampled past vectors
        from B_past, concatenate with the new vectors x, and re-fit B."""
        if self.B_past is not None:
            if self.sticky_memories:
                # Sample past positions proportionally to attention mass
                # ("sticky memories") instead of uniformly.
                bins = self.bins.clone()
                # Nudge the outer edges so boundary values fall inside a bin.
                bins[0]=-.000001
                bins[-1]=1.000001
                prob_density = self.compute_probability(self.score, t=bins)
                cum_prob = torch.cumulative_trapezoid(prob_density, bins, dim=-1).to(self.device)
                p = (cum_prob[..., 1:] - cum_prob[..., :-1]).sum(dim=(1, 2))
                p = p / p.sum(-1, keepdim=True) # Normalize over the last dimension (bins)
                p = dist.Categorical(p)
                b = p.sample((self.nb_samples,))
                t = self.bins_cat.sample((self.nb_samples, 1)).to(device=self.device)
                # Map each sampled bin index to a concrete position in [0, 1].
                ts = (t*(self.bins[b+1]-self.bins[b])/self.nb_bins_cat +self.bins[b]).transpose(1,0)
                samples = self.psi[self.length][0].batch_evaluate(ts[0]).contiguous()

                xm_tau = self.B_past.transpose(-1,-2).matmul(samples.transpose(-1,-2)) # [B,e,nb_samples]
            else:
                # Reconstruct past vectors at the fixed contracted positions.
                xm_tau = self.B_past.transpose(-1,-2).matmul(self.samples.transpose(-1,-2)) # [B,e,nb_samples]


            x = torch.cat([xm_tau,x], dim=2) # [B,e,nb_samples+L]
            B = self.value_function(x, inf=True) # [B,N,e]
        else:
            B = self.value_function(x)

        # Detach so gradients do not flow through previous segments.
        self.B_past=B.detach()
        self.x_past=x
        return B

    def score(self, t):
        """Unnormalized attention score z(t) at continuous position(s) t.

        Reads self.psis / self.queries / self.keys / self.d_head, which are
        set by expected_value() and forward() before this is called.
        """
        psis = self.psis[0].batch_evaluate(t)
        query = self.queries/ (self.d_head ** 0.5) # divide by sqrt(d_head) [B,h,q,d]
        keys = self.keys.transpose(-1, -2)
        keys = torch.matmul(keys, psis.T) #[B,h,d,1]
        scores = torch.matmul(query, keys) #[B,h,q,1]
        return scores

    def compute_probability(self, score_fn, num_points=1000, t=None):
        """
        Compute probability distribution p(t).

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration
            t: Optional explicit integration points; a uniform grid over
               [0, 1] is used when omitted.

        Returns:
            Tensor of probabilities p(t) = exp(z(t)) / int exp(z(t)) dt,
            normalized by trapezoidal integration over t.
        """
        if t is None:
            # Create integration points
            t = torch.linspace(0, 1, num_points).to(self.device)

        scores = score_fn(t)
        prob = torch.exp(scores) / torch.trapz(torch.exp(scores), t, dim=-1).unsqueeze(-1)
        return prob

    def expected_value(self, score_fn, num_points=1000):
        """
        Compute expected value E_p[V(t)] using nested integration.

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration

        Returns:
            torch.Tensor: Expected value, shape [B, h, q, d]

        Side effects: rebuilds self.psis, which score() reads.
        """
        # Create integration points
        t = torch.linspace(0, 1, num_points).to(self.device)

        # Compute basis functions
        self.psis = []
        self.add_retangular_basis_functions(self.psis, self.attn_num_basis, self.device)
        psi = self.psis[0].batch_evaluate(t)
        # Compute probability distribution
        prob = self.compute_probability(score_fn, num_points)
        # Compute values at integration points
        values = self.values
        # Compute p(t) * psi(t)
        # Reshape psi for broadcasting to match the shape of prob
        psi_broadcasted = psi.unsqueeze(1).unsqueeze(2).unsqueeze(3)

        # Expand psi to match the dimensions of prob (num_points, batch_size, n_head, qlen, 256)
        psi_broadcasted = psi_broadcasted.expand(num_points, self.batch_size, self.n_head, self.qlen, self.attn_num_basis)
        # Outer-product p(t) with psi(t) pointwise, then reorder so the
        # integration axis (num_points) is last for trapz.
        integrand = torch.matmul(prob.permute(3,0,1,2).unsqueeze(-1).unsqueeze(-1), psi_broadcasted.unsqueeze(-2)).permute(1, 2, 3, 4, 5, 0).squeeze(-3)

        integral = torch.trapz(integrand, t, dim=-1)
        # Matrix multiply with values
        expected_value = torch.matmul(integral, values) # [B, h, q, d]

        return expected_value

    def forward(self, k, q, new_doc, layer_n):
        """Run continuous long-term attention.

        Args:
            k: Key/memory features; assumed layout is
               (batch, klen*14*14, 1024) — TODO confirm against the caller
               (looks like a 14x14 ViT patch grid per memory slot).
            q: Queries, shape (batch, qlen, n_head*d_head).
            new_doc: Truthy when a new document starts; clears past memory.
            layer_n: Layer index (unused in the visible code).
        """
        self.device = k.device
        if self.continuous:
            # Number of memory slots; hard-codes the 14x14 patch grid.
            klen = int(k.size(1)/(14*14))
            self.length = klen
            batch_size = k.size(0) #batch size
            qlen = q.size(1) #query length
            self.qlen = qlen
            self.batch_size = batch_size
            self.d_head = self.head_size #head size
            # NOTE(review): rebuilds the basis and G operators on every
            # forward call.
            self.get_basis(klen, klen)
            # clean memory if going through different document
            if new_doc:
                self.B_past=None
                self.x_past=None

            # Average each slot's 14x14 patch features; 1024 is the
            # hard-coded feature size — TODO confirm.
            k = k.reshape(batch_size, klen, 14, 14, 1024).mean(dim=(2, 3))
            k = k.transpose(1,2)
            # perform memory update
            if self.infinite_memory:
                B = self.update_inf(k)
            else: # compute input continuous approximation
                B = self.value_function(k) # [B,N,e]
            keys = self.proj_key(B)
            values = self.proj_value(B)
            query = q
            # Stash per-call tensors that score()/expected_value() read.
            self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
            self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
            self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B, h, q, N]
            context = self.expected_value(self.score) # Shape [1, 32, 768]

            # NOTE(review): the leading 1 hard-codes batch size — confirm
            # callers always pass batch_size == 1.
            return context.contiguous().transpose(1,2).reshape(1, qlen, -1)