KitsuVp commited on
Commit
05f0438
verified
1 Parent(s): 34c6d09

Initial FanFormer checkpoint with architecture and README

Browse files
Files changed (2) hide show
  1. model.safetensors +2 -2
  2. model_architecture.py +89 -74
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:607ebeab78df2738e7039a379d2e0b022cdf42f7ff9b675f38be06d91f72a160
3
- size 331514552
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39fb924d47a28f1b2456828b984b26bc5c05382f02b7fc878e067b729edfd88a
3
+ size 331511816
model_architecture.py CHANGED
@@ -353,6 +353,8 @@ class FANformerMultiheadAttention(nn.Module):
353
  Implementaci贸n de la atenci贸n multi-cabeza con FANformer.
354
  Aplica normalizaci贸n a Q, K, V individualmente y utiliza unpadding para mejorar el rendimiento.
355
  Incorpora modelado de periodicidad a trav茅s de proyecciones CoLA_FAN.
 
 
356
  """
357
  def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
358
  layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
@@ -366,21 +368,26 @@ class FANformerMultiheadAttention(nn.Module):
366
  self.layer_index = layer_index
367
  self.use_pre_norm = use_pre_norm
368
  self.p = p # Proporci贸n para periodicidad
369
-
370
  if embed_dim % num_heads != 0:
371
  raise ValueError("embed_dim debe ser divisible por num_heads")
372
-
373
  self.head_dim = embed_dim // num_heads
374
  self.use_rope = use_rope
375
-
376
  if num_gqa_groups is None:
377
  num_gqa_groups = num_heads
 
 
 
 
378
 
379
  try:
380
  from flash_attn import flash_attn_func, flash_attn_varlen_func
381
  self.flash_attn_func = flash_attn_func
382
  self.flash_attn_varlen_func = flash_attn_varlen_func
383
  except ImportError as e:
 
384
  raise ImportError(f"Error al inicializar FlashAttention: {e}")
385
 
386
  # Para el unpadding
@@ -389,20 +396,21 @@ class FANformerMultiheadAttention(nn.Module):
389
  self.unpad_input = unpad_input
390
  self.pad_input = pad_input
391
  except ImportError as e:
 
392
  raise ImportError(f"Error al importar funciones de padding: {e}")
393
 
394
- # Inicializaci贸n de par谩metros de escala
395
- self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
396
- nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
397
- self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))
398
-
399
  # Capas de normalizaci贸n para la entrada (Pre-Norm en primer bloque o QKV-Norm para los dem谩s)
400
  self.norm = nn.RMSNorm(embed_dim, eps=1e-5)
401
-
402
  # Capas de dropout (simplificadas)
403
- self.attention_dropout = progressive_dropout(dropout, depth=1)
404
  # Eliminado: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
405
- self.output_dropout = progressive_dropout(dropout, depth=1)
406
 
407
  # Proyecciones para Q, K, V usando GQAFANLinear (implementaci贸n FANformer)
408
  self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
@@ -413,165 +421,172 @@ class FANformerMultiheadAttention(nn.Module):
413
  self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)
414
 
415
  def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
416
- attention_mask: Optional[torch.Tensor] = None,
417
  is_causal: bool = False) -> torch.Tensor:
418
  B, H, S, D = q.shape # batch, heads, sequence length, head dimension
419
-
 
420
  if attention_mask is None:
421
  # Si no hay m谩scara de atenci贸n, usamos la versi贸n regular
422
  return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)
423
-
424
  # Convertir las tensiones a [B, S, H, D] para unpad_input
425
  q_unpad = q.permute(0, 2, 1, 3) # [B, S, H, D]
426
  k_unpad = k.permute(0, 2, 1, 3) # [B, S, H, D]
427
  v_unpad = v.permute(0, 2, 1, 3) # [B, S, H, D]
428
-
429
  # Preparar m谩scara: convertir a bool si es necesario
 
430
  if attention_mask.dtype != torch.bool:
431
  attention_mask = attention_mask.bool()
432
-
433
  # Hacer unpadding de los tensores
 
434
  q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
435
  k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
436
  v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)
437
-
438
  # Reacomodar para flash_attn_varlen_func: [Total, H, D]
439
  q_unpadded = q_unpadded.reshape(-1, H, D)
440
  k_unpadded = k_unpadded.reshape(-1, H, D)
441
  v_unpadded = v_unpadded.reshape(-1, H, D)
442
-
443
  # Normalizar vectores Q y K para mejorar estabilidad num茅rica
444
- q_norm = F.normalize(q_unpadded, p=2, dim=-1).to(torch.bfloat16)
445
- k_norm = F.normalize(k_unpadded, p=2, dim=-1).to(torch.bfloat16)
446
-
447
- # Ajustar q con factor de escala
448
- s = self.ssmax_scale.view(1, H, 1)
449
- q_adjusted = q_norm * (self.seq_scale * s)
450
-
451
- # Factor de escala para softmax
452
- softmax_scale = 1.0 / math.sqrt(D)
453
-
454
  try:
455
- # Usar flash attention sin padding
456
  output_unpadded = self.flash_attn_varlen_func(
457
- q_adjusted, k_norm, v_unpadded,
458
  cu_seqlens_q, cu_seqlens_k,
459
  max_seqlen_q, max_seqlen_k,
460
  dropout_p=self.attention_dropout.p, # Aplicamos dropout aqu铆
461
- softmax_scale=softmax_scale,
462
  causal=is_causal
463
  )
464
-
465
  # Volver a aplicar padding
466
  output_padded = self.pad_input(output_unpadded, indices_q, B, S)
467
-
468
  # Reorganizar a [B, H, S, D]
469
  output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)
470
-
471
  return output
472
-
473
  except Exception as e:
 
474
  raise RuntimeError(f"Error en flash_attn_varlen_func: {e}")
475
 
476
  def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
477
- mask: Optional[torch.Tensor] = None,
478
  is_causal: bool = False) -> torch.Tensor:
479
  # Normalizar vectores Q y K para mejorar estabilidad num茅rica
480
- q_norm = F.normalize(q, p=2, dim=-1).to(torch.bfloat16)
481
- k_norm = F.normalize(k, p=2, dim=-1).to(torch.bfloat16)
482
-
483
- # Ajustar q con factor de escala
484
- s = self.ssmax_scale.view(-1, 1, 1)
485
- q_adjusted = q_norm * (self.seq_scale * s)
486
-
487
  # Preparar tensores para Flash Attention (requiere shape [B, S, H, D])
488
- q_trans = q_adjusted.permute(0, 2, 1, 3)
489
  k_trans = k_norm.permute(0, 2, 1, 3)
490
  v_trans = v.permute(0, 2, 1, 3)
491
-
492
- # Verificar dimensiones
493
  if q_trans.size(-1) != k_trans.size(-1):
494
  raise ValueError(f"Las dimensiones de head no coinciden: q={q_trans.size(-1)}, k={k_trans.size(-1)}")
495
-
496
- # Factor de escala para softmax
497
- softmax_scale = 1.0 / math.sqrt(q_trans.size(-1))
498
-
499
  try:
500
- # Aplicar Flash Attention
501
  output = self.flash_attn_func(
502
  q_trans, k_trans, v_trans,
503
  dropout_p=self.attention_dropout.p, # Aplicamos dropout aqu铆
504
- softmax_scale=softmax_scale,
505
  causal=is_causal
 
506
  )
507
-
 
508
  if output is None:
509
  raise ValueError("flash_attn_func devolvi贸 None. Verifica las dimensiones y tipos de los tensores de entrada.")
510
-
511
  # Volver a la forma original
512
  output = output.permute(0, 2, 1, 3)
513
  return output
514
-
515
  except Exception as e:
 
516
  raise RuntimeError(f"Error en flash_attn_func: {e}")
517
 
518
  def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
519
  B, T, _ = X.shape
520
-
 
521
  # Implementaci贸n de HybridNorm*
522
  if self.use_pre_norm:
523
  # Primer bloque: Pre-Norm en atenci贸n
524
- X_norm = self.norm(X)
 
525
  # Proyecciones para Q, K, V con FANformer
526
  Q = self.Wq(X_norm) # [B, T, num_heads, head_dim]
527
  K = self.Wk(X_norm) # [B, T, num_heads, head_dim]
528
  V = self.Wv(X_norm) # [B, T, num_heads, head_dim]
529
  else:
530
  # Otros bloques: QKV-Norm
531
- Q = self.Wq(self.norm(X)) # [B, T, num_heads, head_dim]
532
- K = self.Wk(self.norm(X)) # [B, T, num_heads, head_dim]
533
- V = self.Wv(self.norm(X)) # [B, T, num_heads, head_dim]
534
-
 
535
  # Permutar a formato [B, num_heads, T, head_dim]
536
  Q = Q.permute(0, 2, 1, 3)
537
  K = K.permute(0, 2, 1, 3)
538
  V = V.permute(0, 2, 1, 3)
539
-
540
  # Aplicar RoPE si est谩 activado
541
  if self.use_rope:
542
  Q = apply_rope_vectorized(Q)
543
  K = apply_rope_vectorized(K)
544
-
545
- # Convertir a bfloat16 para flash attention
546
  Q = Q.to(torch.bfloat16)
547
  K = K.to(torch.bfloat16)
548
  V = V.to(torch.bfloat16)
549
-
550
  # Procesar la secuencia utilizando unpadding si hay m谩scara de atenci贸n
 
551
  if attention_mask is not None:
552
  attn_output = self.scaled_dot_product_attention_flash_unpadded(
553
- Q, K, V,
554
- attention_mask=attention_mask,
555
  is_causal=causal
556
  )
557
  else:
558
  # Si no hay m谩scara, usar la versi贸n regular
559
  attn_output = self.scaled_dot_product_attention_flash(
560
- Q, K, V,
561
- mask=None,
562
  is_causal=causal
563
  )
564
-
565
- # Eliminada la aplicaci贸n redundante de dropout:
566
  # attn_output = self.attention_dropout(attn_output)
567
-
568
  # Reorganizar la salida y aplicar proyecci贸n final
569
  out = attn_output.permute(0, 2, 1, 3).contiguous()
570
  out = out.reshape(B, T, self.embed_dim)
571
  out = self.output_dropout(self.out_proj(out))
572
-
573
- return out
574
 
 
575
  ############################################
576
  # NUEVO M脫DULO: SWIGLU CON COLA (MLP)
577
  ############################################
 
353
  Implementaci贸n de la atenci贸n multi-cabeza con FANformer.
354
  Aplica normalizaci贸n a Q, K, V individualmente y utiliza unpadding para mejorar el rendimiento.
355
  Incorpora modelado de periodicidad a trav茅s de proyecciones CoLA_FAN.
356
+ [MODIFICADO] Se elimin贸 el escalado ssmax_scale y seq_scale de Q.
357
+ [MODIFICADO] Se aplica conversi贸n expl铆cita a bfloat16 *despu茅s* de las operaciones de normalizaci贸n.
358
  """
359
  def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.12, use_rope: bool = True,
360
  layer_index: int = 1, max_seq_len: int = 512, p: float = 0.15,
 
368
  self.layer_index = layer_index
369
  self.use_pre_norm = use_pre_norm
370
  self.p = p # Proporci贸n para periodicidad
371
+
372
  if embed_dim % num_heads != 0:
373
  raise ValueError("embed_dim debe ser divisible por num_heads")
374
+
375
  self.head_dim = embed_dim // num_heads
376
  self.use_rope = use_rope
377
+
378
  if num_gqa_groups is None:
379
  num_gqa_groups = num_heads
380
+ # A帽adido chequeo de divisibilidad para GQA
381
+ elif num_heads % num_gqa_groups != 0:
382
+ raise ValueError("num_heads debe ser divisible por num_gqa_groups")
383
+
384
 
385
  try:
386
  from flash_attn import flash_attn_func, flash_attn_varlen_func
387
  self.flash_attn_func = flash_attn_func
388
  self.flash_attn_varlen_func = flash_attn_varlen_func
389
  except ImportError as e:
390
+ # Mantener el comportamiento original de lanzar error si no se encuentra
391
  raise ImportError(f"Error al inicializar FlashAttention: {e}")
392
 
393
  # Para el unpadding
 
396
  self.unpad_input = unpad_input
397
  self.pad_input = pad_input
398
  except ImportError as e:
399
+ # Mantener el comportamiento original de lanzar error si no se encuentra
400
  raise ImportError(f"Error al importar funciones de padding: {e}")
401
 
402
+ # Eliminada la inicializaci贸n de par谩metros de escala ssmax_scale y seq_scale
403
+ # self.ssmax_scale = nn.Parameter(torch.ones(num_heads, dtype=torch.bfloat16) * 0.168)
404
+ # nn.init.uniform_(self.ssmax_scale, a=0.166, b=0.170)
405
+ # self.register_buffer('seq_scale', torch.log(torch.tensor(max_seq_len, dtype=torch.bfloat16)))
406
+
407
  # Capas de normalizaci贸n para la entrada (Pre-Norm en primer bloque o QKV-Norm para los dem谩s)
408
  self.norm = nn.RMSNorm(embed_dim, eps=1e-5)
409
+
410
  # Capas de dropout (simplificadas)
411
+ self.attention_dropout = progressive_dropout(dropout, depth=layer_index) # Usar layer_index
412
  # Eliminado: self.projection_dropout = progressive_dropout(dropout * 1.1, depth=1)
413
+ self.output_dropout = progressive_dropout(dropout, depth=layer_index) # Usar layer_index
414
 
415
  # Proyecciones para Q, K, V usando GQAFANLinear (implementaci贸n FANformer)
416
  self.Wq = GQAFANLinear(embed_dim, embed_dim, num_heads, num_gqa_groups, p=p)
 
421
  self.out_proj = CoLA_Linear(embed_dim, embed_dim, rank=embed_dim // 4)
422
 
423
  def scaled_dot_product_attention_flash_unpadded(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
424
+ attention_mask: Optional[torch.Tensor] = None, # Revertido a Optional
425
  is_causal: bool = False) -> torch.Tensor:
426
  B, H, S, D = q.shape # batch, heads, sequence length, head dimension
427
+
428
+ # Mantener la l贸gica original de manejo de m谩scara opcional
429
  if attention_mask is None:
430
  # Si no hay m谩scara de atenci贸n, usamos la versi贸n regular
431
  return self.scaled_dot_product_attention_flash(q, k, v, mask=None, is_causal=is_causal)
432
+
433
  # Convertir las tensiones a [B, S, H, D] para unpad_input
434
  q_unpad = q.permute(0, 2, 1, 3) # [B, S, H, D]
435
  k_unpad = k.permute(0, 2, 1, 3) # [B, S, H, D]
436
  v_unpad = v.permute(0, 2, 1, 3) # [B, S, H, D]
437
+
438
  # Preparar m谩scara: convertir a bool si es necesario
439
+ # Mantener la l贸gica original
440
  if attention_mask.dtype != torch.bool:
441
  attention_mask = attention_mask.bool()
442
+
443
  # Hacer unpadding de los tensores
444
+ # Se mantienen las salidas originales, incluyendo el quinto elemento descartado
445
  q_unpadded, indices_q, cu_seqlens_q, max_seqlen_q, _ = self.unpad_input(q_unpad, attention_mask)
446
  k_unpadded, indices_k, cu_seqlens_k, max_seqlen_k, _ = self.unpad_input(k_unpad, attention_mask)
447
  v_unpadded, _, _, _, _ = self.unpad_input(v_unpad, attention_mask)
448
+
449
  # Reacomodar para flash_attn_varlen_func: [Total, H, D]
450
  q_unpadded = q_unpadded.reshape(-1, H, D)
451
  k_unpadded = k_unpadded.reshape(-1, H, D)
452
  v_unpadded = v_unpadded.reshape(-1, H, D)
453
+
454
  # Normalizar vectores Q y K para mejorar estabilidad num茅rica
455
+ q_norm = F.normalize(q_unpadded, p=2, dim=-1)
456
+ k_norm = F.normalize(k_unpadded, p=2, dim=-1)
457
+
458
+ # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
459
+ # s = self.ssmax_scale.view(1, H, 1)
460
+ # q_adjusted = q_norm * (self.seq_scale * s)
461
+
462
+ # Factor de escala est谩ndar para softmax
463
+
 
464
  try:
465
+ # Usar flash attention sin padding, pasando q_norm
466
  output_unpadded = self.flash_attn_varlen_func(
467
+ q_norm, k_norm, v_unpadded, # Usar q_norm directamente
468
  cu_seqlens_q, cu_seqlens_k,
469
  max_seqlen_q, max_seqlen_k,
470
  dropout_p=self.attention_dropout.p, # Aplicamos dropout aqu铆
471
+ softmax_scale=None, # Escala est谩ndar
472
  causal=is_causal
473
  )
474
+
475
  # Volver a aplicar padding
476
  output_padded = self.pad_input(output_unpadded, indices_q, B, S)
477
+
478
  # Reorganizar a [B, H, S, D]
479
  output = output_padded.reshape(B, S, H, D).permute(0, 2, 1, 3)
480
+
481
  return output
482
+
483
  except Exception as e:
484
+ # Mantener el manejo de errores original
485
  raise RuntimeError(f"Error en flash_attn_varlen_func: {e}")
486
 
487
  def scaled_dot_product_attention_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
488
+ mask: Optional[torch.Tensor] = None, # Mantener mask opcional
489
  is_causal: bool = False) -> torch.Tensor:
490
  # Normalizar vectores Q y K para mejorar estabilidad num茅rica
491
+ q_norm = F.normalize(q, p=2, dim=-1)
492
+ k_norm = F.normalize(k, p=2, dim=-1)
493
+
494
+ # Eliminado el ajuste de q con factor de escala ssmax_scale y seq_scale
495
+ # s = self.ssmax_scale.view(-1, 1, 1)
496
+ # q_adjusted = q_norm * (self.seq_scale * s)
497
+
498
  # Preparar tensores para Flash Attention (requiere shape [B, S, H, D])
499
+ q_trans = q_norm.permute(0, 2, 1, 3) # Usar q_norm directamente
500
  k_trans = k_norm.permute(0, 2, 1, 3)
501
  v_trans = v.permute(0, 2, 1, 3)
502
+
503
+ # Mantener la verificaci贸n de dimensiones original
504
  if q_trans.size(-1) != k_trans.size(-1):
505
  raise ValueError(f"Las dimensiones de head no coinciden: q={q_trans.size(-1)}, k={k_trans.size(-1)}")
506
+
507
+ # Factor de escala est谩ndar para softmax
 
 
508
  try:
509
+ # Aplicar Flash Attention, pasando q_trans
510
  output = self.flash_attn_func(
511
  q_trans, k_trans, v_trans,
512
  dropout_p=self.attention_dropout.p, # Aplicamos dropout aqu铆
513
+ softmax_scale=None, # Escala est谩ndar
514
  causal=is_causal
515
+ # mask no se usa aqu铆
516
  )
517
+
518
+ # Mantener la verificaci贸n de salida None original
519
  if output is None:
520
  raise ValueError("flash_attn_func devolvi贸 None. Verifica las dimensiones y tipos de los tensores de entrada.")
521
+
522
  # Volver a la forma original
523
  output = output.permute(0, 2, 1, 3)
524
  return output
525
+
526
  except Exception as e:
527
+ # Mantener el manejo de errores original
528
  raise RuntimeError(f"Error en flash_attn_func: {e}")
529
 
530
  def forward(self, X: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, causal: bool = True) -> torch.Tensor:
531
  B, T, _ = X.shape
532
+ norm_func = self.norm # Referencia a la capa de normalizaci贸n
533
+
534
  # Implementaci贸n de HybridNorm*
535
  if self.use_pre_norm:
536
  # Primer bloque: Pre-Norm en atenci贸n
537
+ # Aplicar norm y luego convertir expl铆citamente a bfloat16
538
+ X_norm = norm_func(X).to(torch.bfloat16)
539
  # Proyecciones para Q, K, V con FANformer
540
  Q = self.Wq(X_norm) # [B, T, num_heads, head_dim]
541
  K = self.Wk(X_norm) # [B, T, num_heads, head_dim]
542
  V = self.Wv(X_norm) # [B, T, num_heads, head_dim]
543
  else:
544
  # Otros bloques: QKV-Norm
545
+ # Aplicar norm y convertir expl铆citamente a bfloat16 antes de cada proyecci贸n
546
+ Q = self.Wq(norm_func(X).to(torch.bfloat16))
547
+ K = self.Wk(norm_func(X).to(torch.bfloat16))
548
+ V = self.Wv(norm_func(X).to(torch.bfloat16))
549
+
550
  # Permutar a formato [B, num_heads, T, head_dim]
551
  Q = Q.permute(0, 2, 1, 3)
552
  K = K.permute(0, 2, 1, 3)
553
  V = V.permute(0, 2, 1, 3)
554
+
555
  # Aplicar RoPE si est谩 activado
556
  if self.use_rope:
557
  Q = apply_rope_vectorized(Q)
558
  K = apply_rope_vectorized(K)
559
+
560
+ # Convertir a bfloat16 para flash attention (mantener esta conversi贸n expl铆cita)
561
  Q = Q.to(torch.bfloat16)
562
  K = K.to(torch.bfloat16)
563
  V = V.to(torch.bfloat16)
564
+
565
  # Procesar la secuencia utilizando unpadding si hay m谩scara de atenci贸n
566
+ # Mantener la l贸gica original para decidir la ruta
567
  if attention_mask is not None:
568
  attn_output = self.scaled_dot_product_attention_flash_unpadded(
569
+ Q, K, V,
570
+ attention_mask=attention_mask,
571
  is_causal=causal
572
  )
573
  else:
574
  # Si no hay m谩scara, usar la versi贸n regular
575
  attn_output = self.scaled_dot_product_attention_flash(
576
+ Q, K, V,
577
+ mask=None,
578
  is_causal=causal
579
  )
580
+
581
+ # Eliminada la aplicaci贸n redundante de dropout (ya estaba eliminada)
582
  # attn_output = self.attention_dropout(attn_output)
583
+
584
  # Reorganizar la salida y aplicar proyecci贸n final
585
  out = attn_output.permute(0, 2, 1, 3).contiguous()
586
  out = out.reshape(B, T, self.embed_dim)
587
  out = self.output_dropout(self.out_proj(out))
 
 
588
 
589
+ return out
590
  ############################################
591
  # NUEVO M脫DULO: SWIGLU CON COLA (MLP)
592
  ############################################