KitsuVp commited on
Commit
acb91a9
·
verified ·
1 Parent(s): 05f0438

Initial FanFormer checkpoint with architecture and README

Browse files
Files changed (2) hide show
  1. model.safetensors +1 -1
  2. model_architecture.py +4 -4
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:39fb924d47a28f1b2456828b984b26bc5c05382f02b7fc878e067b729edfd88a
3
  size 331511816
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4f2f4437bfc76ca3cb6196b6e0f32087eca273cb2a54656973e1727095f7c4f
3
  size 331511816
model_architecture.py CHANGED
@@ -452,8 +452,8 @@ class FANformerMultiheadAttention(nn.Module):
452
  v_unpadded = v_unpadded.reshape(-1, H, D)
453
 
454
  # Normalizar vectores Q y K para mejorar estabilidad numérica
455
- q_norm = F.normalize(q_unpadded, p=2, dim=-1)
456
- k_norm = F.normalize(k_unpadded, p=2, dim=-1)
457
 
458
  # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
459
  # s = self.ssmax_scale.view(1, H, 1)
@@ -488,8 +488,8 @@ class FANformerMultiheadAttention(nn.Module):
488
  mask: Optional[torch.Tensor] = None, # Mantener mask opcional
489
  is_causal: bool = False) -> torch.Tensor:
490
  # Normalizar vectores Q y K para mejorar estabilidad numérica
491
- q_norm = F.normalize(q, p=2, dim=-1)
492
- k_norm = F.normalize(k, p=2, dim=-1)
493
 
494
  # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
495
  # s = self.ssmax_scale.view(-1, 1, 1)
 
452
  v_unpadded = v_unpadded.reshape(-1, H, D)
453
 
454
  # Normalizar vectores Q y K para mejorar estabilidad numérica
455
+ q_norm = F.normalize(q_unpadded, p=2, dim=-1).to(torch.bfloat16)
456
+ k_norm = F.normalize(k_unpadded, p=2, dim=-1).to(torch.bfloat16)
457
 
458
  # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
459
  # s = self.ssmax_scale.view(1, H, 1)
 
488
  mask: Optional[torch.Tensor] = None, # Mantener mask opcional
489
  is_causal: bool = False) -> torch.Tensor:
490
  # Normalizar vectores Q y K para mejorar estabilidad numérica
491
+ q_norm = F.normalize(q, p=2, dim=-1).to(torch.bfloat16)
492
+ k_norm = F.normalize(k, p=2, dim=-1).to(torch.bfloat16)
493
 
494
  # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
495
  # s = self.ssmax_scale.view(-1, 1, 1)