Initial FanFormer checkpoint with architecture and README
Browse files- model.safetensors +2 -2
- model_architecture.py +89 -74
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39fb924d47a28f1b2456828b984b26bc5c05382f02b7fc878e067b729edfd88a
|
| 3 |
+
size 331511816
|
model_architecture.py
CHANGED
|
@@ -353,6 +353,8 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 353 |
Implementación de la atención multi-cabeza con FANformer.
|
| 354 |
Aplica normalización a Q, K, V individualmente y utiliza unpadding para mejorar el rendimiento.
|
| 355 |
Incorpora modelado de periodicidad a través de proyecciones CoLA_FAN.
|
|
|
|
|
|
|
| 356 |
"""
|
| 357 |
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
|
| 358 |
layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
|
|
@@ -366,21 +368,26 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 366 |
self.layer_index = layer_index
|
| 367 |
self.use_pre_norm = use_pre_norm
|
| 368 |
self.p = p # Proporción para periodicidad
|
| 369 |
-
|
| 370 |
if embed_dim % num_heads != 0:
|
| 371 |
raise ValueError("embed_dim debe ser divisible por num_heads")
|
| 372 |
-
|
| 373 |
self.head_dim = embed_dim // num_heads
|
| 374 |
self.use_rope = use_rope
|
| 375 |
-
|
| 376 |
if num_gqa_groups is None:
|
| 377 |
num_gqa_groups = num_heads
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
try:
|
| 380 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 381 |
self.flash_attn_func = flash_attn_func
|
| 382 |
self.flash_attn_varlen_func = flash_attn_varlen_func
|
| 383 |
except ImportError as e:
|
|
|
|
| 384 |
raise ImportError(f"Error al inicializar FlashAttention: {e}")
|
| 385 |
|
| 386 |
# Para el unpadding
|
|
@@ -389,20 +396,21 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 389 |
self.unpad_input = unpad_input
|
| 390 |
self.pad_input = pad_input
|
| 391 |
except ImportError as e:
|
|
|
|
| 392 |
raise ImportError(f"Error al importar funciones de padding: {e}")
|
| 393 |
|
| 394 |
-
#
|
| 395 |
-
self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
|
| 396 |
-
nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
|
| 397 |
-
self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))
|
| 398 |
-
|
| 399 |
# Capas de normalización para la entrada (Pre-Norm en primer bloque o QKV-Norm para los demás)
|
| 400 |
self.norm = nn.RMSNorm(embed_dim, eps=1e-5)
|
| 401 |
-
|
| 402 |
# Capas de dropout (simplificadas)
|
| 403 |
-
self.attention_dropout = progressive_dropout(dropout, depth=
|
| 404 |
# Eliminado: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
|
| 405 |
-
self.output_dropout = progressive_dropout(dropout, depth=
|
| 406 |
|
| 407 |
# Proyecciones para Q, K, V usando GQAFANLinear (implementaci贸n FANformer)
|
| 408 |
self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
|
|
@@ -413,165 +421,172 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 413 |
self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)
|
| 414 |
|
| 415 |
def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 416 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 417 |
is_causal: bool = False) -> torch.Tensor:
|
| 418 |
B, H, S, D = q.shape # batch, heads, sequence length, head dimension
|
| 419 |
-
|
|
|
|
| 420 |
if attention_mask is None:
|
| 421 |
# Si no hay máscara de atención, usamos la versión regular
|
| 422 |
return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)
|
| 423 |
-
|
| 424 |
# Convertir los tensores a [B, S, H, D] para unpad_input
|
| 425 |
q_unpad = q.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 426 |
k_unpad = k.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 427 |
v_unpad = v.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 428 |
-
|
| 429 |
# Preparar máscara: convertir a bool si es necesario
|
|
|
|
| 430 |
if attention_mask.dtype != torch.bool:
|
| 431 |
attention_mask = attention_mask.bool()
|
| 432 |
-
|
| 433 |
# Hacer unpadding de los tensores
|
|
|
|
| 434 |
q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
|
| 435 |
k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
|
| 436 |
v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)
|
| 437 |
-
|
| 438 |
# Reacomodar para flash_attn_varlen_func: [Total, H, D]
|
| 439 |
q_unpadded = q_unpadded.reshape(-1, H, D)
|
| 440 |
k_unpadded = k_unpadded.reshape(-1, H, D)
|
| 441 |
v_unpadded = v_unpadded.reshape(-1, H, D)
|
| 442 |
-
|
| 443 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 444 |
-
q_norm = F.normalize(q_unpadded, p=2, dim=-1)
|
| 445 |
-
k_norm = F.normalize(k_unpadded, p=2, dim=-1)
|
| 446 |
-
|
| 447 |
-
#
|
| 448 |
-
s = self.ssmax_scale.view(1, H, 1)
|
| 449 |
-
q_adjusted = q_norm * (self.seq_scale * s)
|
| 450 |
-
|
| 451 |
-
# Factor de escala para softmax
|
| 452 |
-
|
| 453 |
-
|
| 454 |
try:
|
| 455 |
-
# Usar flash attention sin padding
|
| 456 |
output_unpadded = self.flash_attn_varlen_func(
|
| 457 |
-
|
| 458 |
cu_seqlens_q, cu_seqlens_k,
|
| 459 |
max_seqlen_q, max_seqlen_k,
|
| 460 |
dropout_p=self.attention_dropout.p, # Aplicamos dropout aquí
|
| 461 |
-
softmax_scale=
|
| 462 |
causal=is_causal
|
| 463 |
)
|
| 464 |
-
|
| 465 |
# Volver a aplicar padding
|
| 466 |
output_padded = self.pad_input(output_unpadded, indices_q, B, S)
|
| 467 |
-
|
| 468 |
# Reorganizar a [B, H, S, D]
|
| 469 |
output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)
|
| 470 |
-
|
| 471 |
return output
|
| 472 |
-
|
| 473 |
except Exception as e:
|
|
|
|
| 474 |
raise RuntimeError(f"Error en flash_attn_varlen_func: {e}")
|
| 475 |
|
| 476 |
def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 477 |
-
mask: Optional[torch.Tensor] = None,
|
| 478 |
is_causal: bool = False) -> torch.Tensor:
|
| 479 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 480 |
-
q_norm = F.normalize(q, p=2, dim=-1)
|
| 481 |
-
k_norm = F.normalize(k, p=2, dim=-1)
|
| 482 |
-
|
| 483 |
-
#
|
| 484 |
-
s = self.ssmax_scale.view(-1, 1, 1)
|
| 485 |
-
q_adjusted = q_norm * (self.seq_scale * s)
|
| 486 |
-
|
| 487 |
# Preparar tensores para Flash Attention (requiere shape [B, S, H, D])
|
| 488 |
-
q_trans =
|
| 489 |
k_trans = k_norm.permute(0, 2, 1, 3)
|
| 490 |
v_trans = v.permute(0, 2, 1, 3)
|
| 491 |
-
|
| 492 |
-
#
|
| 493 |
if q_trans.size(-1) != k_trans.size(-1):
|
| 494 |
raise ValueError(f"Las dimensiones de head no coinciden: q={q_trans.size(-1)}, k={k_trans.size(-1)}")
|
| 495 |
-
|
| 496 |
-
# Factor de escala para softmax
|
| 497 |
-
softmax_scale = 1.0 / math.sqrt(q_trans.size(-1))
|
| 498 |
-
|
| 499 |
try:
|
| 500 |
-
# Aplicar Flash Attention
|
| 501 |
output = self.flash_attn_func(
|
| 502 |
q_trans, k_trans, v_trans,
|
| 503 |
dropout_p=self.attention_dropout.p, # Aplicamos dropout aquí
|
| 504 |
-
softmax_scale=
|
| 505 |
causal=is_causal
|
|
|
|
| 506 |
)
|
| 507 |
-
|
|
|
|
| 508 |
if output is None:
|
| 509 |
raise ValueError("flash_attn_func devolvi贸 None. Verifica las dimensiones y tipos de los tensores de entrada.")
|
| 510 |
-
|
| 511 |
# Volver a la forma original
|
| 512 |
output = output.permute(0, 2, 1, 3)
|
| 513 |
return output
|
| 514 |
-
|
| 515 |
except Exception as e:
|
|
|
|
| 516 |
raise RuntimeError(f"Error en flash_attn_func: {e}")
|
| 517 |
|
| 518 |
def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
|
| 519 |
B, T, _ = X.shape
|
| 520 |
-
|
|
|
|
| 521 |
# Implementación de HybridNorm*
|
| 522 |
if self.use_pre_norm:
|
| 523 |
# Primer bloque: Pre-Norm en atenci贸n
|
| 524 |
-
|
|
|
|
| 525 |
# Proyecciones para Q, K, V con FANformer
|
| 526 |
Q = self.Wq(X_norm) # [B, T, num_heads, head_dim]
|
| 527 |
K = self.Wk(X_norm) # [B, T, num_heads, head_dim]
|
| 528 |
V = self.Wv(X_norm) # [B, T, num_heads, head_dim]
|
| 529 |
else:
|
| 530 |
# Otros bloques: QKV-Norm
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
|
|
|
| 535 |
# Permutar a formato [B, num_heads, T, head_dim]
|
| 536 |
Q = Q.permute(0, 2, 1, 3)
|
| 537 |
K = K.permute(0, 2, 1, 3)
|
| 538 |
V = V.permute(0, 2, 1, 3)
|
| 539 |
-
|
| 540 |
# Aplicar RoPE si está activado
|
| 541 |
if self.use_rope:
|
| 542 |
Q = apply_rope_vectorized(Q)
|
| 543 |
K = apply_rope_vectorized(K)
|
| 544 |
-
|
| 545 |
-
# Convertir a bfloat16 para flash attention
|
| 546 |
Q = Q.to(torch.bfloat16)
|
| 547 |
K = K.to(torch.bfloat16)
|
| 548 |
V = V.to(torch.bfloat16)
|
| 549 |
-
|
| 550 |
# Procesar la secuencia utilizando unpadding si hay máscara de atención
|
|
|
|
| 551 |
if attention_mask is not None:
|
| 552 |
attn_output = self.scaled_dot_product_attention_flash_unpadded(
|
| 553 |
-
Q, K, V,
|
| 554 |
-
attention_mask=attention_mask,
|
| 555 |
is_causal=causal
|
| 556 |
)
|
| 557 |
else:
|
| 558 |
# Si no hay máscara, usar la versión regular
|
| 559 |
attn_output = self.scaled_dot_product_attention_flash(
|
| 560 |
-
Q, K, V,
|
| 561 |
-
mask=None,
|
| 562 |
is_causal=causal
|
| 563 |
)
|
| 564 |
-
|
| 565 |
-
# Eliminada la aplicación redundante de dropout
|
| 566 |
# attn_output = self.attention_dropout(attn_output)
|
| 567 |
-
|
| 568 |
# Reorganizar la salida y aplicar proyección final
|
| 569 |
out = attn_output.permute(0, 2, 1, 3).contiguous()
|
| 570 |
out = out.reshape(B, T, self.embed_dim)
|
| 571 |
out = self.output_dropout(self.out_proj(out))
|
| 572 |
-
|
| 573 |
-
return out
|
| 574 |
|
|
|
|
| 575 |
############################################
|
| 576 |
# NUEVO MÓDULO: SWIGLU CON COLA (MLP)
|
| 577 |
############################################
|
|
|
|
| 353 |
Implementación de la atención multi-cabeza con FANformer.
|
| 354 |
Aplica normalización a Q, K, V individualmente y utiliza unpadding para mejorar el rendimiento.
|
| 355 |
Incorpora modelado de periodicidad a través de proyecciones CoLA_FAN.
|
| 356 |
+
[MODIFICADO] Se eliminó el escalado ssmax_scale y seq_scale de Q.
|
| 357 |
+
[MODIFICADO] Se aplica conversión explícita a bfloat16 *después* de las operaciones de normalización.
|
| 358 |
"""
|
| 359 |
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
|
| 360 |
layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
|
|
|
|
| 368 |
self.layer_index = layer_index
|
| 369 |
self.use_pre_norm = use_pre_norm
|
| 370 |
self.p = p # Proporción para periodicidad
|
| 371 |
+
|
| 372 |
if embed_dim % num_heads != 0:
|
| 373 |
raise ValueError("embed_dim debe ser divisible por num_heads")
|
| 374 |
+
|
| 375 |
self.head_dim = embed_dim // num_heads
|
| 376 |
self.use_rope = use_rope
|
| 377 |
+
|
| 378 |
if num_gqa_groups is None:
|
| 379 |
num_gqa_groups = num_heads
|
| 380 |
+
# Añadido chequeo de divisibilidad para GQA
|
| 381 |
+
elif num_heads % num_gqa_groups != 0:
|
| 382 |
+
raise ValueError("num_heads debe ser divisible por num_gqa_groups")
|
| 383 |
+
|
| 384 |
|
| 385 |
try:
|
| 386 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 387 |
self.flash_attn_func = flash_attn_func
|
| 388 |
self.flash_attn_varlen_func = flash_attn_varlen_func
|
| 389 |
except ImportError as e:
|
| 390 |
+
# Mantener el comportamiento original de lanzar error si no se encuentra
|
| 391 |
raise ImportError(f"Error al inicializar FlashAttention: {e}")
|
| 392 |
|
| 393 |
# Para el unpadding
|
|
|
|
| 396 |
self.unpad_input = unpad_input
|
| 397 |
self.pad_input = pad_input
|
| 398 |
except ImportError as e:
|
| 399 |
+
# Mantener el comportamiento original de lanzar error si no se encuentra
|
| 400 |
raise ImportError(f"Error al importar funciones de padding: {e}")
|
| 401 |
|
| 402 |
+
# Eliminada la inicialización de parámetros de escala ssmax_scale y seq_scale
|
| 403 |
+
# self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
|
| 404 |
+
# nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
|
| 405 |
+
# self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))
|
| 406 |
+
|
| 407 |
# Capas de normalización para la entrada (Pre-Norm en primer bloque o QKV-Norm para los demás)
|
| 408 |
self.norm = nn.RMSNorm(embed_dim, eps=1e-5)
|
| 409 |
+
|
| 410 |
# Capas de dropout (simplificadas)
|
| 411 |
+
self.attention_dropout = progressive_dropout(dropout, depth=layer_index) # Usar layer_index
|
| 412 |
# Eliminado: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
|
| 413 |
+
self.output_dropout = progressive_dropout(dropout, depth=layer_index) # Usar layer_index
|
| 414 |
|
| 415 |
# Proyecciones para Q, K, V usando GQAFANLinear (implementaci贸n FANformer)
|
| 416 |
self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
|
|
|
|
| 421 |
self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)
|
| 422 |
|
| 423 |
def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 424 |
+
attention_mask: Optional[torch.Tensor] = None, # Revertido a Optional
|
| 425 |
is_causal: bool = False) -> torch.Tensor:
|
| 426 |
B, H, S, D = q.shape # batch, heads, sequence length, head dimension
|
| 427 |
+
|
| 428 |
+
# Mantener la lógica original de manejo de máscara opcional
|
| 429 |
if attention_mask is None:
|
| 430 |
# Si no hay máscara de atención, usamos la versión regular
|
| 431 |
return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)
|
| 432 |
+
|
| 433 |
# Convertir los tensores a [B, S, H, D] para unpad_input
|
| 434 |
q_unpad = q.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 435 |
k_unpad = k.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 436 |
v_unpad = v.permute(0, 2, 1, 3) # [B, S, H, D]
|
| 437 |
+
|
| 438 |
# Preparar máscara: convertir a bool si es necesario
|
| 439 |
+
# Mantener la lógica original
|
| 440 |
if attention_mask.dtype != torch.bool:
|
| 441 |
attention_mask = attention_mask.bool()
|
| 442 |
+
|
| 443 |
# Hacer unpadding de los tensores
|
| 444 |
+
# Se mantienen las salidas originales, incluyendo el quinto elemento descartado
|
| 445 |
q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
|
| 446 |
k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
|
| 447 |
v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)
|
| 448 |
+
|
| 449 |
# Reacomodar para flash_attn_varlen_func: [Total, H, D]
|
| 450 |
q_unpadded = q_unpadded.reshape(-1, H, D)
|
| 451 |
k_unpadded = k_unpadded.reshape(-1, H, D)
|
| 452 |
v_unpadded = v_unpadded.reshape(-1, H, D)
|
| 453 |
+
|
| 454 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 455 |
+
q_norm = F.normalize(q_unpadded, p=2, dim=-1)
|
| 456 |
+
k_norm = F.normalize(k_unpadded, p=2, dim=-1)
|
| 457 |
+
|
| 458 |
+
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 459 |
+
# s = self.ssmax_scale.view(1, H, 1)
|
| 460 |
+
# q_adjusted = q_norm * (self.seq_scale * s)
|
| 461 |
+
|
| 462 |
+
# Factor de escala estándar para softmax
|
| 463 |
+
|
|
|
|
| 464 |
try:
|
| 465 |
+
# Usar flash attention sin padding, pasando q_norm
|
| 466 |
output_unpadded = self.flash_attn_varlen_func(
|
| 467 |
+
q_norm, k_norm, v_unpadded, # Usar q_norm directamente
|
| 468 |
cu_seqlens_q, cu_seqlens_k,
|
| 469 |
max_seqlen_q, max_seqlen_k,
|
| 470 |
dropout_p=self.attention_dropout.p, # Aplicamos dropout aquí
|
| 471 |
+
softmax_scale=None, # Escala estándar
|
| 472 |
causal=is_causal
|
| 473 |
)
|
| 474 |
+
|
| 475 |
# Volver a aplicar padding
|
| 476 |
output_padded = self.pad_input(output_unpadded, indices_q, B, S)
|
| 477 |
+
|
| 478 |
# Reorganizar a [B, H, S, D]
|
| 479 |
output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)
|
| 480 |
+
|
| 481 |
return output
|
| 482 |
+
|
| 483 |
except Exception as e:
|
| 484 |
+
# Mantener el manejo de errores original
|
| 485 |
raise RuntimeError(f"Error en flash_attn_varlen_func: {e}")
|
| 486 |
|
| 487 |
def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 488 |
+
mask: Optional[torch.Tensor] = None, # Mantener mask opcional
|
| 489 |
is_causal: bool = False) -> torch.Tensor:
|
| 490 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 491 |
+
q_norm = F.normalize(q, p=2, dim=-1)
|
| 492 |
+
k_norm = F.normalize(k, p=2, dim=-1)
|
| 493 |
+
|
| 494 |
+
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 495 |
+
# s = self.ssmax_scale.view(-1, 1, 1)
|
| 496 |
+
# q_adjusted = q_norm * (self.seq_scale * s)
|
| 497 |
+
|
| 498 |
# Preparar tensores para Flash Attention (requiere shape [B, S, H, D])
|
| 499 |
+
q_trans = q_norm.permute(0, 2, 1, 3) # Usar q_norm directamente
|
| 500 |
k_trans = k_norm.permute(0, 2, 1, 3)
|
| 501 |
v_trans = v.permute(0, 2, 1, 3)
|
| 502 |
+
|
| 503 |
+
# Mantener la verificación de dimensiones original
|
| 504 |
if q_trans.size(-1) != k_trans.size(-1):
|
| 505 |
raise ValueError(f"Las dimensiones de head no coinciden: q={q_trans.size(-1)}, k={k_trans.size(-1)}")
|
| 506 |
+
|
| 507 |
+
# Factor de escala estándar para softmax
|
|
|
|
|
|
|
| 508 |
try:
|
| 509 |
+
# Aplicar Flash Attention, pasando q_trans
|
| 510 |
output = self.flash_attn_func(
|
| 511 |
q_trans, k_trans, v_trans,
|
| 512 |
dropout_p=self.attention_dropout.p, # Aplicamos dropout aquí
|
| 513 |
+
softmax_scale=None, # Escala estándar
|
| 514 |
causal=is_causal
|
| 515 |
+
# mask no se usa aquí
|
| 516 |
)
|
| 517 |
+
|
| 518 |
+
# Mantener la verificación de salida None original
|
| 519 |
if output is None:
|
| 520 |
raise ValueError("flash_attn_func devolvi贸 None. Verifica las dimensiones y tipos de los tensores de entrada.")
|
| 521 |
+
|
| 522 |
# Volver a la forma original
|
| 523 |
output = output.permute(0, 2, 1, 3)
|
| 524 |
return output
|
| 525 |
+
|
| 526 |
except Exception as e:
|
| 527 |
+
# Mantener el manejo de errores original
|
| 528 |
raise RuntimeError(f"Error en flash_attn_func: {e}")
|
| 529 |
|
| 530 |
def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
|
| 531 |
B, T, _ = X.shape
|
| 532 |
+
norm_func = self.norm # Referencia a la capa de normalización
|
| 533 |
+
|
| 534 |
# Implementación de HybridNorm*
|
| 535 |
if self.use_pre_norm:
|
| 536 |
# Primer bloque: Pre-Norm en atenci贸n
|
| 537 |
+
# Aplicar norm y luego convertir explícitamente a bfloat16
|
| 538 |
+
X_norm = norm_func(X).to(torch.bfloat16)
|
| 539 |
# Proyecciones para Q, K, V con FANformer
|
| 540 |
Q = self.Wq(X_norm) # [B, T, num_heads, head_dim]
|
| 541 |
K = self.Wk(X_norm) # [B, T, num_heads, head_dim]
|
| 542 |
V = self.Wv(X_norm) # [B, T, num_heads, head_dim]
|
| 543 |
else:
|
| 544 |
# Otros bloques: QKV-Norm
|
| 545 |
+
# Aplicar norm y convertir explícitamente a bfloat16 antes de cada proyección
|
| 546 |
+
Q = self.Wq(norm_func(X).to(torch.bfloat16))
|
| 547 |
+
K = self.Wk(norm_func(X).to(torch.bfloat16))
|
| 548 |
+
V = self.Wv(norm_func(X).to(torch.bfloat16))
|
| 549 |
+
|
| 550 |
# Permutar a formato [B, num_heads, T, head_dim]
|
| 551 |
Q = Q.permute(0, 2, 1, 3)
|
| 552 |
K = K.permute(0, 2, 1, 3)
|
| 553 |
V = V.permute(0, 2, 1, 3)
|
| 554 |
+
|
| 555 |
# Aplicar RoPE si está activado
|
| 556 |
if self.use_rope:
|
| 557 |
Q = apply_rope_vectorized(Q)
|
| 558 |
K = apply_rope_vectorized(K)
|
| 559 |
+
|
| 560 |
+
# Convertir a bfloat16 para flash attention (mantener esta conversión explícita)
|
| 561 |
Q = Q.to(torch.bfloat16)
|
| 562 |
K = K.to(torch.bfloat16)
|
| 563 |
V = V.to(torch.bfloat16)
|
| 564 |
+
|
| 565 |
# Procesar la secuencia utilizando unpadding si hay máscara de atención
|
| 566 |
+
# Mantener la lógica original para decidir la ruta
|
| 567 |
if attention_mask is not None:
|
| 568 |
attn_output = self.scaled_dot_product_attention_flash_unpadded(
|
| 569 |
+
Q, K, V,
|
| 570 |
+
attention_mask=attention_mask,
|
| 571 |
is_causal=causal
|
| 572 |
)
|
| 573 |
else:
|
| 574 |
# Si no hay máscara, usar la versión regular
|
| 575 |
attn_output = self.scaled_dot_product_attention_flash(
|
| 576 |
+
Q, K, V,
|
| 577 |
+
mask=None,
|
| 578 |
is_causal=causal
|
| 579 |
)
|
| 580 |
+
|
| 581 |
+
# Eliminada la aplicación redundante de dropout (ya estaba eliminada)
|
| 582 |
# attn_output = self.attention_dropout(attn_output)
|
| 583 |
+
|
| 584 |
# Reorganizar la salida y aplicar proyección final
|
| 585 |
out = attn_output.permute(0, 2, 1, 3).contiguous()
|
| 586 |
out = out.reshape(B, T, self.embed_dim)
|
| 587 |
out = self.output_dropout(self.out_proj(out))
|
|
|
|
|
|
|
| 588 |
|
| 589 |
+
return out
|
| 590 |
############################################
|
| 591 |
# NUEVO MÓDULO: SWIGLU CON COLA (MLP)
|
| 592 |
############################################
|