Initial FanFormer checkpoint with architecture and README
Browse files
- model.safetensors +1 -1
- model_architecture.py +4 -4
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 331511816
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4f2f4437bfc76ca3cb6196b6e0f32087eca273cb2a54656973e1727095f7c4f
|
| 3 |
size 331511816
|
model_architecture.py
CHANGED
|
@@ -452,8 +452,8 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 452 |
v_unpadded = v_unpadded.reshape(-1, H, D)
|
| 453 |
|
| 454 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 455 |
-
q_norm = F.normalize(q_unpadded, p=2, dim=-1)
|
| 456 |
-
k_norm = F.normalize(k_unpadded, p=2, dim=-1)
|
| 457 |
|
| 458 |
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 459 |
# s = self.ssmax_scale.view(1, H, 1)
|
|
@@ -488,8 +488,8 @@ class FANformerMultiheadAttention(nn.Module):
|
|
| 488 |
mask: Optional[torch.Tensor] = None, # Mantener mask opcional
|
| 489 |
is_causal: bool = False) -> torch.Tensor:
|
| 490 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 491 |
-
q_norm = F.normalize(q, p=2, dim=-1)
|
| 492 |
-
k_norm = F.normalize(k, p=2, dim=-1)
|
| 493 |
|
| 494 |
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 495 |
# s = self.ssmax_scale.view(-1, 1, 1)
|
|
|
|
| 452 |
v_unpadded = v_unpadded.reshape(-1, H, D)
|
| 453 |
|
| 454 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 455 |
+
q_norm = F.normalize(q_unpadded, p=2, dim=-1).to(torch.bfloat16)
|
| 456 |
+
k_norm = F.normalize(k_unpadded, p=2, dim=-1).to(torch.bfloat16)
|
| 457 |
|
| 458 |
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 459 |
# s = self.ssmax_scale.view(1, H, 1)
|
|
|
|
| 488 |
mask: Optional[torch.Tensor] = None, # Mantener mask opcional
|
| 489 |
is_causal: bool = False) -> torch.Tensor:
|
| 490 |
# Normalizar vectores Q y K para mejorar estabilidad numérica
|
| 491 |
+
q_norm = F.normalize(q, p=2, dim=-1).to(torch.bfloat16)
|
| 492 |
+
k_norm = F.normalize(k, p=2, dim=-1).to(torch.bfloat16)
|
| 493 |
|
| 494 |
# Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
|
| 495 |
# s = self.ssmax_scale.view(-1, 1, 1)
|