Rihong committed on
Commit
b84a0b5
·
verified ·
1 Parent(s): 4e5ef76

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. Qformer.py +1 -1
  2. basis_functions.py +266 -0
  3. long_term_attention_gibbs.py +315 -0
Qformer.py CHANGED
@@ -41,7 +41,7 @@ from transformers.utils import logging
41
  from transformers.models.bert.configuration_bert import BertConfig
42
 
43
  from functools import partial
44
- from .ltm.long_term_attention_gibbs import LongTermAttention
45
 
46
  logger = logging.get_logger(__name__)
47
 
 
41
  from transformers.models.bert.configuration_bert import BertConfig
42
 
43
  from functools import partial
44
+ from .long_term_attention_gibbs import LongTermAttention
45
 
46
  logger = logging.get_logger(__name__)
47
 
basis_functions.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+
4
+
5
class BasisFunctions(object):
    """Abstract interface for a family of basis functions psi(t).

    Subclasses provide pointwise evaluation and closed-form integrals of
    psi(t), t*psi(t) and t**2*psi(t) over an interval [a, b].

    Fix: the abstract methods previously used `pass`, so a missing override
    silently returned None (and `len()` raised a confusing TypeError).
    They now raise NotImplementedError, making the contract explicit.
    """

    def __init__(self):
        pass

    def __len__(self):
        """Number of basis functions."""
        raise NotImplementedError

    def evaluate(self, t):
        """Evaluate every basis function at time(s) t."""
        raise NotImplementedError

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        raise NotImplementedError

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        raise NotImplementedError

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        raise NotImplementedError
27
+
28
+
29
class PowerBasisFunctions(BasisFunctions):
    """Function phi(t) = t**degree."""

    def __init__(self, degree):
        # Keep degrees as a (1, N) row so arithmetic broadcasts over time points.
        self.degree = degree.unsqueeze(0)

    def __len__(self):
        """Number of basis functions."""
        return self.degree.size(1)

    def evaluate(self, t):
        """Evaluate t**degree for every degree in the family."""
        return t ** self.degree

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        exponent = self.degree + 3
        return (b ** exponent - a ** exponent) / exponent

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        exponent = self.degree + 2
        return (b ** exponent - a ** exponent) / exponent

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        exponent = self.degree + 1
        return (b ** exponent - a ** exponent) / exponent

    def __repr__(self):
        return f"PowerBasisFunction(degree={self.degree})"
55
+
56
+
57
class SineBasisFunctions(BasisFunctions):
    """Function phi(t) = sin(omega*t)."""

    def __init__(self, omega):
        # Keep frequencies as a (1, N) row so arithmetic broadcasts over time points.
        self.omega = omega.unsqueeze(0)

    def __repr__(self):
        return f"SineBasisFunction(omega={self.omega})"

    def __len__(self):
        """Number of basis functions."""
        return self.omega.size(1)

    def evaluate(self, t):
        """Evaluate sin(omega*t) for every frequency."""
        return torch.sin(self.omega * t)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t).

        Antiderivative of (t**2)*sin(omega*t):
        ((2 - (omega*t)**2)*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3
        """
        def antiderivative(t):
            wt = self.omega * t
            return (2 - wt ** 2) * torch.cos(wt) + 2 * wt * torch.sin(wt)
        return (antiderivative(b) - antiderivative(a)) / (self.omega ** 3)

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t).

        Antiderivative of t*sin(omega*t):
        (sin(omega*t) - omega*t*cos(omega*t)) / omega**2
        """
        def antiderivative(t):
            wt = self.omega * t
            return torch.sin(wt) - wt * torch.cos(wt)
        return (antiderivative(b) - antiderivative(a)) / (self.omega ** 2)

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t); antiderivative is -cos(omega*t)/omega."""
        return (torch.cos(self.omega * a) - torch.cos(self.omega * b)) / self.omega
94
+
95
+
96
class CosineBasisFunctions(BasisFunctions):
    """Function phi(t) = cos(omega*t)."""

    def __init__(self, omega):
        # Keep frequencies as a (1, N) row so arithmetic broadcasts over time points.
        self.omega = omega.unsqueeze(0)

    def __repr__(self):
        return f"CosineBasisFunction(omega={self.omega})"

    def __len__(self):
        """Number of basis functions."""
        return self.omega.size(1)

    def evaluate(self, t):
        """Evaluate cos(omega*t) for every frequency."""
        return torch.cos(self.omega * t)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t).

        Antiderivative of (t**2)*cos(omega*t):
        (((omega*t)**2 - 2)*sin(omega*t) + 2*omega*t*cos(omega*t)) / omega**3
        (the original comment said cos where sin belongs; the code was correct).
        """
        def antiderivative(t):
            wt = self.omega * t
            return (wt ** 2 - 2) * torch.sin(wt) + 2 * wt * torch.cos(wt)
        return (antiderivative(b) - antiderivative(a)) / (self.omega ** 3)

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t).

        Antiderivative of t*cos(omega*t):
        (cos(omega*t) + omega*t*sin(omega*t)) / omega**2
        """
        def antiderivative(t):
            wt = self.omega * t
            return torch.cos(wt) + wt * torch.sin(wt)
        return (antiderivative(b) - antiderivative(a)) / (self.omega ** 2)

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t); antiderivative is sin(omega*t)/omega."""
        return (torch.sin(self.omega * b) - torch.sin(self.omega * a)) / self.omega
133
+
134
+
135
class GaussianBasisFunctions(BasisFunctions):
    """Function phi(t) = Gaussian(t; mu, sigma_sq)."""

    def __init__(self, mu, sigma):
        # (1, N) rows so the family broadcasts over time points.
        self.mu = mu.unsqueeze(0)
        self.sigma = sigma.unsqueeze(0)

    def __repr__(self):
        return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"

    def __len__(self):
        """Number of basis functions."""
        return self.mu.size(1)

    def _phi(self, t):
        """Standard normal pdf."""
        return torch.exp(-0.5 * t ** 2) / math.sqrt(2 * math.pi)

    def _Phi(self, t):
        """Standard normal cdf."""
        return (1 + torch.erf(t / math.sqrt(2))) / 2

    def _integrate_product_of_gaussians(self, mu, sigma_sq):
        """Integral of N(t; mu, sigma_sq) * psi(t): a Gaussian in (mu - self.mu)."""
        combined_sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
        return self._phi((mu - self.mu) / combined_sigma) / combined_sigma

    def evaluate(self, t):
        """Evaluate every Gaussian density at time(s) t."""
        return self._phi((t - self.mu) / self.sigma) / self.sigma

    def batch_evaluate(self, t):
        """Evaluate all Gaussians at every value of 1-D tensor t.

        Returns a (num_points, num_basis) tensor.
        """
        num_points = t.size(0)
        rows = t.repeat(self.mu.size(0), 1)                          # (1, P)
        mu_cols = self.mu.repeat(num_points, 1).transpose(1, 0)      # (N, P)
        sigma_cols = self.sigma.repeat((num_points, 1)).transpose(1, 0)
        z = (rows - mu_cols) / sigma_cols
        return (self._phi(z) / sigma_cols).transpose(0, 1)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        za = (a - self.mu) / self.sigma
        zb = (b - self.mu) / self.sigma
        return ((self.mu ** 2 + self.sigma ** 2) * (self._Phi(zb) - self._Phi(za))
                - self.sigma * (b + self.mu) * self._phi(zb)
                + self.sigma * (a + self.mu) * self._phi(za))

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        za = (a - self.mu) / self.sigma
        zb = (b - self.mu) / self.sigma
        return (self.mu * (self._Phi(zb) - self._Phi(za))
                - self.sigma * (self._phi(zb) - self._phi(za)))

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        return (self._Phi((b - self.mu) / self.sigma)
                - self._Phi((a - self.mu) / self.sigma))

    def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
        denom = self.sigma ** 2 + sigma_sq
        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
        mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / denom
        sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / denom
        return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)

    def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
        mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / (self.sigma ** 2 + sigma_sq)
        return S_tilde * mu_tilde

    def integrate_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * psi(t)."""
        return self._integrate_product_of_gaussians(mu, sigma_sq)
212
+
213
+
214
class RetangularBasisFunctions(BasisFunctions):
    """Rectangular (boxcar) basis: phi(t) = 1 for t in [mu - width/2, mu + width/2), else 0.

    Fixes over the original:
    - class and __repr__ docs claimed this was a Gaussian basis;
    - __repr__ referenced the nonexistent ``self.sigma`` and raised
      AttributeError (the constructor stores it as ``self.width``);
    - batch_evaluate's docstring stated the transposed return shape.

    NOTE(review): the integrate_* methods of the base class are not
    overridden here — callers that need them must not use this basis.
    """

    def __init__(self, mu, sigma):
        # `sigma` is kept as the parameter name for interface compatibility
        # with GaussianBasisFunctions, but it is interpreted as the bin width.
        self.mu = mu.unsqueeze(0)        # (1, num_basis) bin centers
        self.width = sigma.unsqueeze(0)  # (1, num_basis) bin widths

    def __repr__(self):
        # Bug fix: previously `f"GaussianBasisFunction(mu=..., sigma={self.sigma})"`,
        # which crashed because self.sigma does not exist.
        return f"RetangularBasisFunction(mu={self.mu}, width={self.width})"

    def __len__(self):
        """Number of basis functions."""
        return self.mu.size(1)

    def batch_evaluate(self, t):
        """
        Evaluate multiple time points against all rectangular basis functions.
        Args:
            t: Tensor of time values to evaluate, shape (num_points,).
        Returns:
            Tensor of evaluations, shape (num_points, num_basis).
        """
        t = t.repeat(self.mu.size(0), 1)                          # (1, num_points)
        mu = self.mu.repeat(t.size(0), 1).transpose(1, 0)         # (num_basis, 1)
        width = self.width.repeat(t.size(0), 1).transpose(1, 0)   # (num_basis, 1)
        # Broadcast to (num_basis, num_points), then transpose.
        return ((t >= (mu - width / 2)) & (t < (mu + width / 2))).float().transpose(0, 1)

    def _Phi(self, t):
        """
        Compute the step function for a single value of t.
        Args:
            t: A scalar or tensor of time values.
        Returns:
            Tensor of values indicating presence in each basis function's range.
        """
        lower_bounds = self.mu - self.width / 2
        upper_bounds = self.mu + self.width / 2
        return ((t >= lower_bounds) & (t < upper_bounds)).float()

    def evaluate(self, t):
        """
        Evaluate the rectangular basis functions at a single point or array of points.
        Args:
            t: A scalar or 1D tensor of time values.
        Returns:
            Tensor of shape (num_basis,) for scalar input; for tensor input the
            result follows broadcasting of t against the (1, num_basis) bounds.
        """
        if t.ndim == 0:  # Scalar input
            return self._Phi(t)
        else:  # Tensor input
            lower_bounds = (self.mu - self.width / 2)  # (1, num_basis)
            upper_bounds = (self.mu + self.width / 2)  # (1, num_basis)
            return ((t >= lower_bounds) & (t < upper_bounds)).float()
long_term_attention_gibbs.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ """
3
+ Attention modules
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.distributions as dist
8
+
9
+ from .basis_functions import (
10
+ PowerBasisFunctions,
11
+ SineBasisFunctions,
12
+ CosineBasisFunctions,
13
+ GaussianBasisFunctions,
14
+ RetangularBasisFunctions
15
+ )
16
+
17
+ import numpy as np
18
+
19
+
20
+
21
class LongTermAttention(nn.Module):
    """Continuous long-term attention over a compressed ("infinite") memory.

    Incoming key vectors are approximated by a continuous function expressed in
    a rectangular basis (coefficients B); attention is then computed as an
    expectation of the values under a probability density over positions t in
    [0, 1] derived from query/key scores.

    NOTE(review): this block was reconstructed from a flattened diff; the exact
    nesting of the infinite-memory / sticky-memory sections in get_basis and of
    the `if self.continuous:` guard in forward is inferred — confirm against
    the original repository.
    """

    def __init__(self, head_size: int, length: int, target_len: int, attn_func: str, attn_num_basis: int,
                 continuous: bool, attn_drop: float, infinite_memory: bool, n_layers: int,
                 n_heads: int, affines: bool, mask: bool, mask_type: str, kl_regularizer: bool,
                 proj_key, proj_value, sigma_0, mu_0, sticky_memories, sigmas, tau, **kwargs):
        # NOTE(review): several parameters (attn_func, attn_drop, n_layers, mask,
        # mask_type, affines, kl_regularizer, sigma_0, mu_0, sigmas) are stored
        # but not used anywhere in this visible code.
        super(LongTermAttention, self).__init__()

        self.device = 'cuda'  # overwritten from the input tensor in forward()
        self.length = length  # memory length
        self.target_len = target_len  # target length / transformer length
        self.head_size = head_size
        self.attn_num_basis = attn_num_basis
        self.continuous = continuous  # whether attention over memory vectors is continuous
        self.attn_func = attn_func  # normalizing function
        self.n_head = n_heads
        self.sigmas = sigmas
        self.kl_regularizer = kl_regularizer
        self.sigma_0 = sigma_0
        self.mu_0 = mu_0
        self.proj_key = proj_key      # module projecting coefficients B to keys
        self.proj_value = proj_value  # module projecting coefficients B to values

        self.affines = affines  # whether mu, sigma should be computed using affine transformations

        self.sticky_memories = sticky_memories

        self.mem_threshold = 2048
        self.infinite_memory = infinite_memory  # whether the memory is infinite

        self.nb_samples = 512  # number of samples used for update
        self.tau = tau  # compressing factor
        self.count = 0

        self.x_past = None  # previous memory vectors
        self.B_past = None  # previous coefficient matrix

        self.ridge_penalty = 0.5  # ridge penalty
        self.padding = True

        self.spacing = 'linear'

    def get_basis(self, length, target_len):
        """Build rectangular basis functions and ridge-regression projection
        matrices G for every supported memory length (multiples of target_len
        up to `length`).

        Side effects: sets self.psi, self.Gs, self.positions, and (when
        infinite memory is on) self.samples and self.G_inf; when sticky
        memories are on, sets self.bins / self.bins_cat.
        """
        def compute_G(l, psi, positions, padding=True):
            # G = F^T (F F^T + lambda I)^{-1}, where F holds the basis
            # functions evaluated at `positions` (ridge regression fit).
            F = torch.zeros(self.attn_num_basis, positions.size(0))

            basis_functions = psi
            F[:, :] = basis_functions.evaluate(positions.unsqueeze(1)).t()

            I = torch.eye(self.attn_num_basis)
            G = F.t().matmul((F.matmul(F.t()) + self.ridge_penalty * I).inverse())

            if padding:
                # Drop the padded halves so G maps only the l real positions.
                if l % 2:
                    G = G[((l-1)//2):(-(l-1)//2), :]
                else:
                    G = G[(l//2):-(l//2), :]

            return G.to(self.device)

        padding = self.padding
        attn_num_basis = self.attn_num_basis
        if self.continuous:
            self.psi = [None]
            self.Gs = [None for _ in range(length+1)]
            lengths = []
            for i in range(length):
                self.psi.append([])
                if (i+1) % target_len == 0:
                    lengths.append(i+1)
            if length not in lengths:
                lengths.append(length)
            for l in lengths:
                # get positions for memory vectors
                self.add_retangular_basis_functions(self.psi[l], attn_num_basis, device=self.device)

                if self.spacing == 'linear':
                    if padding:
                        # Pad positions beyond [0, 1] so the fit is well
                        # conditioned at the boundaries.
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)
                    else:
                        shift = 1 / float(2*l)
                        positions = torch.linspace(shift, 1-shift, l).to(self.device)
                elif self.spacing == 'log':
                    if padding:
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)

                        # Log-spaced interior positions; keep the linear padding tails.
                        pos = np.e**(np.log(1+1)*torch.arange(1, length+1)/length)-1
                        positions = torch.cat([positions[:int(l/2)], pos.to(self.device), positions[-int(l/2):]])

                    else:
                        positions = np.e**(np.log(1+1)*torch.arange(1, length+1)/length)-1

                # compute basis functions
                self.Gs[l] = compute_G(l, self.psi[l][0], positions, padding=padding)  # [L,N]
                self.positions = positions[int(l/2):-int(l/2)]

                # compute samples for memory update
                if self.infinite_memory:
                    # Old memory is squeezed into [0, tau], new vectors into (tau, 1].
                    tm_tau = torch.arange(1, self.nb_samples+1).float()
                    tm_l = torch.arange(self.nb_samples+1, length+self.nb_samples+1).float()
                    tm_tau = tm_tau*self.tau/self.nb_samples  # positions of old vectors
                    tm_l = self.tau + (1-self.tau)*(tm_l-self.nb_samples)/length  # positions of new vectors
                    positions_inf = torch.cat([tm_tau, tm_l], 0).to(self.device)  # positions

                    if padding:
                        if l % 2:
                            shift = 1 / float(length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)-1).to(self.device)
                        else:
                            shift = 1 / float(2*length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)).to(self.device)
                        positions_pad_ = torch.FloatTensor([i for i in positions_pad if i < 0]).to(self.device)
                        positions_pad__ = torch.FloatTensor([i for i in positions_pad if i > 1]).to(self.device)
                        positions_inf = torch.cat([positions_pad_, positions_inf, positions_pad__], dim=0)

                    # Basis evaluated at the rescaled old positions t/tau; used
                    # to resample the past signal in update_inf.
                    self.samples = None
                    for t in tm_tau:
                        if self.samples is None:
                            self.samples = self.psi[l][0].evaluate(t/self.tau)
                        else:
                            self.samples = torch.cat([self.samples, self.psi[l][0].evaluate(t/self.tau)], dim=0)

                    # compute G for the infinite case
                    self.G_inf = compute_G(self.nb_samples+length, self.psi[l][0], positions_inf, padding=padding)  # [L+nb_samples,N]

            if self.sticky_memories:
                # Histogram bins over [0, 1] used to sample important positions.
                self.bins = torch.linspace(0, 1, 129).to(device=self.device)  # self.positions
                self.nb_bins_cat = 1
                self.bins_cat = dist.Categorical(torch.ones(self.nb_bins_cat))

    def add_gaussian_basis_functions(self, psi, nb_basis, sigmas, device):
        """Append a GaussianBasisFunctions family (grid of mu x sigmas) to psi.

        NOTE(review): not called from any visible code path (get_basis uses the
        rectangular variant).
        """
        mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)), torch.Tensor(sigmas))
        mu = mu.flatten().to(device)
        sigma = sigma.flatten().to(device)
        self.basis_mu = mu
        self.basis_sigma = sigma
        assert mu.size(0) == nb_basis
        psi.append(GaussianBasisFunctions(mu=mu, sigma=sigma))

    def add_retangular_basis_functions(self, psi, nb_basis, device):
        """Append a RetangularBasisFunctions family of nb_basis equal-width
        bins tiling [0, 1] to the list psi."""
        width = torch.ones(nb_basis, device=device) / nb_basis

        # Compute the centers (midpoints) of each bin
        edges = torch.linspace(0, 1, nb_basis + 1, device=device)
        mu = (edges[:-1] + edges[1:]) / 2
        psi.append(RetangularBasisFunctions(mu=mu, sigma=width))

    def value_function(self, x, inf=False):
        """Fit basis coefficients B for the signal x via the precomputed
        projection G (G_inf when fitting past samples + new vectors)."""
        if inf:
            G = self.G_inf  # [nb_sample+L,N]
        else:
            G = self.Gs[x.size(-1)]  # [L,N]
        B = torch.matmul(x, G)  # [B,e,N]
        B = B.permute(0, 2, 1)  # [B,N,e]

        return B

    def update_inf(self, x):
        """Compress the previous memory (B_past) together with the new vectors
        x into a fresh coefficient matrix B.

        With sticky memories, past positions are re-sampled proportionally to
        the attention probability mass they received; otherwise the fixed
        self.samples grid is used.
        """
        if self.B_past is not None:
            if self.sticky_memories:
                bins = self.bins.clone()
                # Widen the outer bin edges slightly so boundary values fall inside.
                bins[0] = -.000001
                bins[-1] = 1.000001
                prob_density = self.compute_probability(self.score, t=bins)
                cum_prob = torch.cumulative_trapezoid(prob_density, bins, dim=-1).to(self.device)
                p = (cum_prob[..., 1:] - cum_prob[..., :-1]).sum(dim=(1, 2))
                p = p / p.sum(-1, keepdim=True)  # Normalize over the last dimension (bins)
                p = dist.Categorical(p)
                b = p.sample((self.nb_samples,))
                t = self.bins_cat.sample((self.nb_samples, 1)).to(device=self.device)
                # Uniform offset inside each sampled bin.
                ts = (t*(self.bins[b+1]-self.bins[b])/self.nb_bins_cat + self.bins[b]).transpose(1, 0)
                samples = self.psi[self.length][0].batch_evaluate(ts[0]).contiguous()

                xm_tau = self.B_past.transpose(-1, -2).matmul(samples.transpose(-1, -2))  # [B,e,nb_samples]
            else:
                xm_tau = self.B_past.transpose(-1, -2).matmul(self.samples.transpose(-1, -2))  # [B,e,nb_samples]

            x = torch.cat([xm_tau, x], dim=2)  # [B,e,nb_samples+L]
            B = self.value_function(x, inf=True)  # [B,N,e]
        else:
            B = self.value_function(x)

        # Detach so gradients do not flow through the unbounded past.
        self.B_past = B.detach()
        self.x_past = x
        return B

    def score(self, t):
        """Unnormalized attention score z(t) at positions t, via the basis
        expansion of the keys. Reads self.psis, self.queries, self.keys."""
        psis = self.psis[0].batch_evaluate(t)
        query = self.queries / (self.d_head ** 0.5)  # divide by sqrt(d_head) [B,h,q,d]
        keys = self.keys.transpose(-1, -2)
        keys = torch.matmul(keys, psis.T)  # [B,h,d,1]
        scores = torch.matmul(query, keys)  # [B,h,q,1]
        return scores

    def compute_probability(self, score_fn, num_points=1000, t=None):
        """
        Compute probability distribution p(t).

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration
            t: Optional explicit integration points; defaults to a uniform
               grid on [0, 1].

        Returns:
            Probabilities p(t) = exp(z(t)) normalized by trapezoidal
            integration over t.
        """
        if t is None:
            # Create integration points
            t = torch.linspace(0, 1, num_points).to(self.device)

        scores = score_fn(t)
        prob = torch.exp(scores) / torch.trapz(torch.exp(scores), t, dim=-1).unsqueeze(-1)
        return prob

    def expected_value(self, score_fn, num_points=1000):
        """
        Compute expected value E_p[V(t)] using nested integration.

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration

        Returns:
            torch.Tensor: Expected value [B, h, q, d] (reads self.values,
            self.batch_size, self.n_head, self.qlen set by forward()).
        """
        # Create integration points
        t = torch.linspace(0, 1, num_points).to(self.device)

        # Compute basis functions
        self.psis = []
        self.add_retangular_basis_functions(self.psis, self.attn_num_basis, self.device)
        psi = self.psis[0].batch_evaluate(t)
        # Compute probability distribution
        prob = self.compute_probability(score_fn, num_points)
        # Compute values at integration points
        values = self.values
        # Compute p(t) * psi(t)
        # Reshape psi for broadcasting to match the shape of prob
        psi_broadcasted = psi.unsqueeze(1).unsqueeze(2).unsqueeze(3)

        # Expand psi to match the dimensions of prob (num_points, batch_size, n_head, qlen, 256)
        psi_broadcasted = psi_broadcasted.expand(num_points, self.batch_size, self.n_head, self.qlen, self.attn_num_basis)
        integrand = torch.matmul(prob.permute(3, 0, 1, 2).unsqueeze(-1).unsqueeze(-1), psi_broadcasted.unsqueeze(-2)).permute(1, 2, 3, 4, 5, 0).squeeze(-3)

        integral = torch.trapz(integrand, t, dim=-1)
        # Matrix multiply with values
        expected_value = torch.matmul(integral, values)  # [B, h, q, d]

        return expected_value

    def forward(self, k, q, new_doc, layer_n):
        """Attend queries q over the continuous memory built from keys k.

        Args:
            k: memory features; reshaped as (batch, klen, 14, 14, 1024), i.e.
               assumes 14x14 spatial patches of dim 1024 — TODO confirm.
            q: queries, (batch, qlen, n_head*d_head).
            new_doc: truthy to reset the memory (new document).
            layer_n: layer index (unused in this visible code).

        Returns:
            Context tensor reshaped to (1, qlen, -1).
        """
        self.device = k.device
        if self.continuous:
            # One memory slot per 14x14 patch grid.
            klen = int(k.size(1)/(14*14))
            self.length = klen
        batch_size = k.size(0)  # batch size
        qlen = q.size(1)  # query length
        self.qlen = qlen
        self.batch_size = batch_size
        self.d_head = self.head_size  # head size
        # NOTE(review): rebuilding the basis on every forward call is costly;
        # presumably cacheable when klen repeats — verify.
        self.get_basis(klen, klen)
        # clean memory if going through different document
        if new_doc:
            self.B_past = None
            self.x_past = None

        # Average-pool each 14x14 patch grid down to one vector per slot.
        k = k.reshape(batch_size, klen, 14, 14, 1024).mean(dim=(2, 3))
        k = k.transpose(1, 2)
        # perform memory update
        if self.infinite_memory:
            B = self.update_inf(k)
        else:  # compute input continuous approximation
            B = self.value_function(k)  # [B,N,e]
        keys = self.proj_key(B)
        values = self.proj_value(B)
        query = q
        self.queries = query.view(batch_size, qlen, self.n_head, self.d_head).transpose(1, 2)  # [B,h,q,d]
        self.keys = keys.view(batch_size, self.attn_num_basis, self.n_head, self.d_head).transpose(1, 2)  # [B,h,N,d]
        self.values = values.view(batch_size, self.attn_num_basis, self.n_head, self.d_head).transpose(1, 2)  # [B, h, q, N]
        context = self.expected_value(self.score)  # Shape [1, 32, 768]

        return context.contiguous().transpose(1, 2).reshape(1, qlen, -1)
315
+