intexcp committed
Commit 708a987 · verified · 1 Parent(s): 6ed8e3e

Update encoder.py

Files changed (1)
  1. encoder.py +27 -59
encoder.py CHANGED
@@ -203,10 +203,8 @@ class MultiHeadAttention(nn.Module, ABC):
         return self.linear_out(x)
 
 
+
 class RelPositionMultiHeadAttention(MultiHeadAttention):
-    """
-    Relative Position Multi-Head Attention module.
-    """
 
     def __init__(self, n_head: int, n_feat: int):
         super().__init__(n_head, n_feat)
@@ -214,19 +212,20 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
         self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
         self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
 
-    def rel_shift(self, x: Tensor) -> Tensor:
+    @staticmethod
+    def rel_shift(x: Tensor) -> Tensor:
         b, h, qlen, pos_len = x.size()
         x = torch.nn.functional.pad(x, pad=(1, 0))
         x = x.view(b, h, -1, qlen)
         return x[:, :, 1:].view(b, h, qlen, pos_len)
 
     def forward(
-        self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        pos_emb: Tensor,
-        mask: Optional[Tensor] = None,
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)
@@ -243,17 +242,14 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
 
 
 class RotaryPositionMultiHeadAttention(MultiHeadAttention):
-    """
-    Rotary Position Multi-Head Attention module.
-    """
 
     def forward(
-        self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        pos_emb: List[Tensor],
-        mask: Optional[Tensor] = None,
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: List[Tensor],
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         b, t, _ = value.size()
         query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
@@ -269,25 +265,13 @@ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
             value.view(t, b, self.h * self.d_k).transpose(0, 1),
         )
 
-        # if not self.flash_attn:
-        scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
         out = self.forward_attention(v, scores, mask)
-        # else:
-        #     if mask is None:
-        #         scores = flash_attn_func(q, k, v)
-        #     else:
-        #         scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
-
-        #     scores = scores.view(b, -1, self.h * self.d_k)
-        #     out = self.linear_out(scores)
 
         return out
 
 
 class PositionalEncoding(nn.Module, ABC):
-    """
-    Base class of Positional Encodings.
-    """
 
     def __init__(self, dim: int, base: int):
         super().__init__()
@@ -295,14 +279,11 @@ class PositionalEncoding(nn.Module, ABC):
         self.base = base
 
     @abstractmethod
-    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
+    def create_pe(self, length: int) -> Optional[Tensor]:
         pass
 
-    def extend_pe(self, length: int, device: torch.device):
-        """
-        Extends the positional encoding buffer to process longer sequences.
-        """
-        pe = self.create_pe(length, device)
+    def extend_pe(self, length: int):
+        pe = self.create_pe(length)
         if pe is None:
             return
         if hasattr(self, "pe"):
@@ -312,17 +293,10 @@ class PositionalEncoding(nn.Module, ABC):
 
 
 class RelPositionalEmbedding(PositionalEncoding):
-    """
-    Relative Positional Embedding module.
-    """
-
-    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
-        """
-        Creates the relative positional encoding matrix.
-        """
+    def create_pe(self, length: int) -> Optional[Tensor]:
         if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
             return None
-        positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
+        positions = torch.arange(length - 1, -length, -1).unsqueeze(1)
         pos_length = positions.size(0)
         pe = torch.zeros(pos_length, self.dim, device=positions.device)
         div_term = torch.exp(
@@ -342,29 +316,23 @@ class RelPositionalEmbedding(PositionalEncoding):
 
 
 class RotaryPositionalEmbedding(PositionalEncoding):
-    """
-    Rotary Positional Embedding module.
-    """
 
-    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
-        """
-        Creates or extends the rotary positional encoding matrix.
-        """
+    def create_pe(self, length: int) -> Optional[Tensor]:
         if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
             return None
-        positions = torch.arange(0, length, dtype=torch.float32, device=device)
+        positions = torch.arange(0, length, dtype=torch.float32)
         inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2, device=positions.device).float() / self.dim)
+            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
-        t = torch.arange(length, device=positions.device, dtype=inv_freq.dtype)
+        t = torch.arange(length, device=positions.device).type_as(inv_freq)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
         return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
 
     def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
-        cos_emb = self.pe[0 : x.shape[1]]
+        cos_emb = self.pe[0: x.shape[1]]
         half_pe = self.pe.shape[0] // 2
-        sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
+        sin_emb = self.pe[half_pe: half_pe + x.shape[1]]
         return x, [cos_emb, sin_emb]
 
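
The pad-and-reshape in rel_shift (now a @staticmethod in this commit) is the usual Transformer-XL style relative shift. A minimal standalone sketch, with toy sizes and not part of encoder.py, showing what the reshape does to a score matrix:

import torch
import torch.nn.functional as F

def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # Same pad-and-reshape trick as the rel_shift shown in the diff above.
    b, h, qlen, pos_len = x.size()
    x = F.pad(x, pad=(1, 0))            # (b, h, qlen, pos_len + 1)
    x = x.view(b, h, -1, qlen)          # (b, h, pos_len + 1, qlen)
    return x[:, :, 1:].view(b, h, qlen, pos_len)

# Toy example: 3 queries, 2 * 3 - 1 = 5 relative positions.
# Every row is filled with its column index so the shift is easy to see.
scores = torch.arange(5, dtype=torch.float32).repeat(1, 1, 3, 1)
print(rel_shift(scores)[0, 0])
# tensor([[2., 3., 4., 0., 0.],
#         [1., 2., 3., 4., 0.],
#         [0., 1., 2., 3., 4.]])
# Each query row ends up shifted one step relative to the previous one, so
# every query is aligned with its own window of relative-position scores.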
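
RotaryPositionalEmbedding.forward returns the input together with [cos_emb, sin_emb], sliced from the buffer that create_pe builds with shape (2 * length, 1, 1, dim). How RotaryPositionMultiHeadAttention consumes that pair is not part of this diff; the sketch below shows the common rotate-half formulation, assuming dim corresponds to the per-head size d_k and the (t, b, h, d_k) layout used in its forward. The sizes and helper names are hypothetical.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (t, b, h, d_k); cos/sin: (t, 1, 1, d_k), broadcast over batch and heads.
    return x * cos + rotate_half(x) * sin

# Rebuild the same cos/sin table as create_pe in the diff (hypothetical sizes).
t_len, b, h, d_k, base = 6, 2, 4, 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, d_k, 2).float() / d_k))
t = torch.arange(t_len).type_as(inv_freq)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]

q = torch.randn(t_len, b, h, d_k)
q_rot = apply_rotary(q, cos, sin)
print(q_rot.shape)  # torch.Size([6, 2, 4, 8])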