Update encoder.py

encoder.py  +27 -59
@@ -203,10 +203,8 @@ class MultiHeadAttention(nn.Module, ABC):
         return self.linear_out(x)
 
 
+
 class RelPositionMultiHeadAttention(MultiHeadAttention):
-    """
-    Relative Position Multi-Head Attention module.
-    """
 
     def __init__(self, n_head: int, n_feat: int):
         super().__init__(n_head, n_feat)
@@ -214,19 +212,20 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
         self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
         self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
 
-
+    @staticmethod
+    def rel_shift(x: Tensor) -> Tensor:
         b, h, qlen, pos_len = x.size()
         x = torch.nn.functional.pad(x, pad=(1, 0))
         x = x.view(b, h, -1, qlen)
         return x[:, :, 1:].view(b, h, qlen, pos_len)
 
     def forward(
-
-
-
-
-
-
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: Tensor,
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)
@@ -243,17 +242,14 @@ class RelPositionMultiHeadAttention(MultiHeadAttention):
 
 
 class RotaryPositionMultiHeadAttention(MultiHeadAttention):
-    """
-    Rotary Position Multi-Head Attention module.
-    """
 
     def forward(
-
-
-
-
-
-
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        pos_emb: List[Tensor],
+        mask: Optional[Tensor] = None,
     ) -> Tensor:
         b, t, _ = value.size()
         query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
@@ -269,25 +265,13 @@ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
             value.view(t, b, self.h * self.d_k).transpose(0, 1),
         )
 
-
-        scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
         out = self.forward_attention(v, scores, mask)
-        # else:
-        #     if mask is None:
-        #         scores = flash_attn_func(q, k, v)
-        #     else:
-        #         scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
-
-        #     scores = scores.view(b, -1, self.h * self.d_k)
-        #     out = self.linear_out(scores)
 
         return out
 
 
 class PositionalEncoding(nn.Module, ABC):
-    """
-    Base class of Positional Encodings.
-    """
 
     def __init__(self, dim: int, base: int):
         super().__init__()
@@ -295,14 +279,11 @@ class PositionalEncoding(nn.Module, ABC):
         self.base = base
 
     @abstractmethod
-    def create_pe(self, length: int
+    def create_pe(self, length: int) -> Optional[Tensor]:
         pass
 
-    def extend_pe(self, length: int
-        """
-        Extends the positional encoding buffer to process longer sequences.
-        """
-        pe = self.create_pe(length, device)
+    def extend_pe(self, length: int):
+        pe = self.create_pe(length)
         if pe is None:
             return
         if hasattr(self, "pe"):
@@ -312,17 +293,10 @@ class PositionalEncoding(nn.Module, ABC):
 
 
 class RelPositionalEmbedding(PositionalEncoding):
-    """
-    Relative Positional Embedding module.
-    """
-
-    def create_pe(self, length: int
-        """
-        Creates the relative positional encoding matrix.
-        """
+    def create_pe(self, length: int) -> Optional[Tensor]:
         if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
             return None
-        positions = torch.arange(length - 1, -length, -1
+        positions = torch.arange(length - 1, -length, -1).unsqueeze(1)
         pos_length = positions.size(0)
         pe = torch.zeros(pos_length, self.dim, device=positions.device)
         div_term = torch.exp(
@@ -342,29 +316,23 @@ class RelPositionalEmbedding(PositionalEncoding):
 
 
 class RotaryPositionalEmbedding(PositionalEncoding):
-    """
-    Rotary Positional Embedding module.
-    """
 
-    def create_pe(self, length: int
-        """
-        Creates or extends the rotary positional encoding matrix.
-        """
+    def create_pe(self, length: int) -> Optional[Tensor]:
         if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
             return None
-        positions = torch.arange(0, length, dtype=torch.float32
+        positions = torch.arange(0, length, dtype=torch.float32)
         inv_freq = 1.0 / (
-
+            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
-        t = torch.arange(length, device=positions.device
+        t = torch.arange(length, device=positions.device).type_as(inv_freq)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
         return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
 
     def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
-        cos_emb = self.pe[0
+        cos_emb = self.pe[0: x.shape[1]]
         half_pe = self.pe.shape[0] // 2
-        sin_emb = self.pe[half_pe
+        sin_emb = self.pe[half_pe: half_pe + x.shape[1]]
         return x, [cos_emb, sin_emb]
 
 
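The `rel_shift` body is unchanged by this commit; the hunk only turns it into a typed `@staticmethod`. As a quick illustration of what the pad/reshape/slice sequence does, here is a minimal standalone sketch; the tensor shapes are made up for the example and are not taken from encoder.py.

import torch
from torch import Tensor


def rel_shift(x: Tensor) -> Tensor:
    # Same pad -> reshape -> drop-first-row -> reshape trick as in the diff:
    # successive query rows end up offset by one column, which aligns the
    # (2 * qlen - 1)-wide relative-position axis per query.
    b, h, qlen, pos_len = x.size()
    x = torch.nn.functional.pad(x, pad=(1, 0))
    x = x.view(b, h, -1, qlen)
    return x[:, :, 1:].view(b, h, qlen, pos_len)


# Illustrative shapes only: batch 2, 4 heads, 3 queries, 2 * 3 - 1 = 5 relative offsets.
scores = torch.randn(2, 4, 3, 5)
print(rel_shift(scores).shape)  # torch.Size([2, 4, 3, 5])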
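The updated `RotaryPositionMultiHeadAttention.forward` signature takes `pos_emb: List[Tensor]`, i.e. the `[cos_emb, sin_emb]` pair returned by `RotaryPositionalEmbedding.forward`, but the lines that actually consume that pair fall outside these hunks. A common way such a pair is applied is the rotate-half formulation sketched below; the helpers `rotate_half` and `apply_rotary` are illustrative names, not functions confirmed to exist in encoder.py.

import torch
from torch import Tensor


def rotate_half(x: Tensor) -> Tensor:
    # Swap the two halves of the last dimension with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor):
    # q, k: (seq, batch, heads, d_k); cos, sin broadcast from (seq, 1, 1, d_k),
    # matching the shapes produced by the rotary create_pe in the diff.
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot


# Illustrative shapes: seq 6, batch 2, 4 heads, head dim 8.
q = torch.randn(6, 2, 4, 8)
k = torch.randn(6, 2, 4, 8)
cos = torch.randn(6, 1, 1, 8)
sin = torch.randn(6, 1, 1, 8)
q_rot, k_rot = apply_rotary(q, k, cos, sin)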
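The new cache check `self.pe.size(0) >= 2 * length` and the slicing added in `RotaryPositionalEmbedding.forward` both rely on the buffer layout that `create_pe` builds: cosine rows stacked on top of sine rows along dim 0, each shaped for broadcasting. The standalone re-derivation below follows the same formula as the added lines; `dim`, `base`, and `length` are made-up example values.

import torch

dim, base, length = 8, 10000, 6

# Mirrors the rotary create_pe in the diff: per-pair inverse frequencies,
# outer product with positions, then cos and sin stacked along dim 0.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(length).type_as(inv_freq)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
pe = torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])

print(pe.shape)  # torch.Size([12, 1, 1, 8]); pe.size(0) == 2 * length
cos_emb = pe[0:length]           # top half, as pe[0: x.shape[1]] in forward
sin_emb = pe[length:2 * length]  # bottom half, as pe[half_pe: half_pe + x.shape[1]]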