# unified-model-instruct / modeling_unified.py
# Uploaded by KitsuVp ("Upload 2 files"), commit b9eef49 (verified)
# ====================================================================
# modeling_unified.py
# ====================================================================
"""
Unified Language Model with GPAS + LNS Integration + xIELU Activation + CoLA (Linear Only) + LaX + Weight Tying + Canon Layers (A+C Only)
MIGRATED TO HUGGINGFACE TRANSFORMERS - FINAL VERSION WITH ALL FIXES + CORRECTED LaX IMPLEMENTATION
UPDATED: Standard Transformer with advanced variance control, parameter efficiency, Canon horizontal information flow, and WORKING LaX Inter-Layer
Combines advanced Transformer architecture with CORRECTED variance control mechanisms,
advanced variance control via GPAS and LNS, xIELU activation function, FIXED LaX integration, and Canon Layers (A+C only)
Based on LLaMA 3 architecture with 30M parameters
MIGRATION TO HUGGINGFACE - FINAL FIXED VERSION + LaX CORRECTION:
==============================================================
1. **HUGGINGFACE INTEGRATION**: Migrated from PyTorch Lightning to Transformers v4.53.3
2. **UPDATED API**: processing_class instead of tokenizer (deprecated)
3. **UPDATED COMPUTE_LOSS**: Updated method with the num_items_in_batch parameter
4. **FIXED LOGGING**: Corrected self.log() syntax per the official HF documentation
5. **RESTORED PAD HANDLING**: pad_token_id → -100 conversion for CrossEntropyLoss (from original code)
6. **NATIVE TORCH COMPILE**: Moved to TrainingArguments (torch_compile=True)
7. **FIXED WEIGHT TYING**: Corrected _tied_weights_keys as class attribute (HF standard)
8. **VALIDATION DIAGNOSTIC**: Added simple method to diagnose validation loss issues
9. **CUSTOM CONFIGURATION**: Custom PretrainedConfig with all parameters
10. **PRETRAINED MODEL**: Inherits from PreTrainedModel for full compatibility
11. **MAINTAINED OPTIMIZERS**: Hybrid Muon + AdamW preserved
12. **MAINTAINED PRECISION**: bf16-true preserved
13. **MAINTAINED TRAINING**: Custom Trainer with all metrics and logging
14. **MAINTAINED ARCHITECTURE**: All custom architecture preserved
15. **AUTO TOKENIZER**: Full integration with dynamic AutoTokenizer
16. **AUTOCLASS SUPPORT**: Full registration for AutoConfig and AutoModel
17. **✅ FIXED LaX**: Correct Inter-Layer implementation with a working Linear Gate
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModel,
AutoModelForCausalLM,
PreTrainedModel,
)
import math
import os
from typing import Optional, Tuple, Dict, Any, cast, List
from flash_attn import flash_attn_func
import numpy as np
# ✅ ABSOLUTE IMPORT - No relative imports for Hub compatibility
from configuration_unified import UnifiedModelConfig
# Import-time process settings:
#  * disable HF tokenizers parallelism to avoid its fork warnings;
#  * allow reduced-precision fast paths (e.g. TF32) for float32 matmuls.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
torch.set_float32_matmul_precision('high')
def init_cola_components(A: nn.Linear, B: nn.Linear):
    """Initialise the two factors of a CoLA low-rank projection in place.

    Args:
        A: Down-projection (input -> latent); Kaiming-normal, fan-in, ReLU gain.
        B: Up-projection (latent -> output); Xavier-normal with gain 0.8.
            Its bias, if present, is zeroed.
    """
    nn.init.kaiming_normal_(A.weight, mode='fan_in', nonlinearity='relu')
    nn.init.xavier_normal_(B.weight, gain=0.8)
    if B.bias is None:
        return
    nn.init.zeros_(B.bias)
def init_embedding(embedding: nn.Embedding):
    """Re-draw ``embedding.weight`` in place from a zero-mean normal with std 0.02."""
    nn.init.normal_(embedding.weight, std=0.02)
class CanonLayer(nn.Module):
    """Residual depthwise causal 1-D convolution along the sequence axis.

    ``forward(h)`` returns ``h + conv(h)``. Each channel has its own
    ``kernel_size``-tap filter, and output position t only sees inputs
    t-kernel_size+1 .. t. Both the filter and the bias start at zero, so a
    freshly built layer is the identity.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # groups == channels -> depthwise; padding is applied manually (left only).
        self.causal_conv1d = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size,
            padding=0,
            groups=hidden_dim,
            bias=True,
        )
        for tensor in (self.causal_conv1d.weight, self.causal_conv1d.bias):
            nn.init.zeros_(tensor)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution to ``h`` (batch, seq, hidden) and add it back."""
        channels_first = h.transpose(1, 2)  # Conv1d wants (batch, channels, seq)
        # Pad only on the left so no position can see the future.
        left_padded = F.pad(channels_first, (self.kernel_size - 1, 0))
        mixed = self.causal_conv1d(left_padded).transpose(1, 2)
        return h + mixed
class CoLA_Linear(nn.Module):
    """Low-rank "CoLA" replacement for ``nn.Linear``: ``x -> B(activation(A(x)))``.

    Args:
        in_features: Input width.
        out_features: Output width.
        rank: Width of the latent bottleneck. Defaults to ``in_features // 4``,
            floored at 1 so that narrow layers do not get a zero-width
            bottleneck (which would make the output independent of ``x``).
        activation: Non-linearity applied to the latent (default ``F.gelu``).
        bias: Whether the up-projection ``B`` carries a bias.
    """

    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu, bias: bool = True):
        super().__init__()
        if rank is None:
            rank = max(1, in_features // 4)
        self.rank = rank
        self.activation = activation
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=bias)
        init_cola_components(self.A, self.B)

    def forward(self, x: torch.Tensor, prev_latent: Optional[torch.Tensor] = None, lax_beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Project ``x`` through the low-rank bottleneck, optionally LaX-gated.

        Args:
            x: Input tensor whose last dimension is ``in_features``.
            prev_latent: Activated latent produced by the same kind of
                projection in the previous layer (LaX Inter-Layer).
            lax_beta: Learnable scalar gate for LaX.

        Returns:
            ``(output, latent)``. ``latent`` is the activated (and possibly
            gated) bottleneck, which the caller can hand to the next layer.
        """
        latent = self.activation(self.A(x))
        # LaX Linear Gate h_i <- h_i + beta * h_{i-1}. The residual is skipped
        # silently when there is no previous latent, no gate, or shapes differ.
        if prev_latent is not None and lax_beta is not None and prev_latent.shape == latent.shape:
            latent = latent + lax_beta * prev_latent
        output = self.B(latent)
        return output, latent
class LayerNormScaling(nn.Module):
    """Static depth-dependent scaling of a normalised activation (LNS).

    The output is ``input * 1/sqrt(layer_depth)``, so deeper layers contribute
    proportionally smaller updates.

    Raises:
        ValueError: if ``layer_depth`` is below 1.
    """

    def __init__(self, layer_depth: int):
        super().__init__()
        if layer_depth < 1:
            raise ValueError(f"layer_depth debe ser ≥ 1, got {layer_depth}")
        self.layer_depth = layer_depth
        root = math.sqrt(float(layer_depth))
        self.scaling_factor = 1.0 / root

    def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
        """Multiply ``normalized_input`` by the precomputed depth factor."""
        factor = self.scaling_factor
        return normalized_input * factor
class GPAS(nn.Module):
    """Gradient-preserving activation scaling.

    Forward value: ``x * (1 - silu(alpha))``. The subtracted term is built from
    a detached copy of ``x``, so the gradient with respect to ``x`` is left
    unscaled, while ``alpha`` still receives a gradient. With the zero init of
    ``alpha`` the module starts as the identity.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x - silu(alpha) * stop_gradient(x)``."""
        gate = F.silu(self.alpha)
        return x - gate * x.detach()
class RotaryEmbedding(nn.Module):
    """Produce RoPE cosine/sine tables for positions ``0 .. seq_len-1``.

    The tables are recomputed on every call for the requested length.
    ``max_position_embeddings`` is stored but does not bound ``seq_len``.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each shaped (seq_len, dim) and cast to ``x.dtype``.

        ``x`` supplies the device and dtype, and its second-to-last dimension is
        used as the length when ``seq_len`` is not given.
        """
        length = x.shape[-2] if seq_len is None else seq_len
        positions = torch.arange(length, device=x.device, dtype=self.inv_freq.dtype)
        angles = positions[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles, angles), dim=-1)
        return table.cos().to(x.dtype), table.sin().to(x.dtype)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Rotate ``q`` and ``k`` with precomputed RoPE tables.

    ``cos``/``sin`` must broadcast against ``q``/``k``. The last dimension is
    split into two halves ``(a, b)``, and the rotation term is ``(-b, a)``.

    NOTE(review): ``position_ids`` is accepted but ignored; the tables are
    applied by plain broadcasting (i.e. positions 0..seq_len-1).

    Returns:
        ``(q_rotated, k_rotated)``.
    """

    def _rotate_half(t):
        half = t.shape[-1] // 2
        first, second = t[..., :half], t[..., half:]
        return torch.cat((-second, first), dim=-1)

    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed
class XIELU(nn.Module):
    """xIELU activation with learnable positive/negative curvature.

        f(x) = alpha_p * x**2 + beta * x                        for x > 0
        f(x) = alpha_n * (expm1(min(x, eps)) - x) + beta * x    for x <= 0

    where ``alpha_p = softplus(p)`` and ``alpha_n = beta + softplus(n)``, and
    ``eps`` is a small negative constant (-1e-6).

    Args:
        alpha_p_init: Initial alpha_p; must be >= 0.
        alpha_n_init: Initial alpha_n; must be >= beta.
        beta: Fixed linear coefficient.

    Raises:
        ValueError: if the init values cannot be represented through softplus
            (they would produce NaN parameters).
    """

    def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
        super().__init__()
        if alpha_p_init < 0:
            raise ValueError(f"alpha_p_init must be >= 0, got {alpha_p_init}")
        if alpha_n_init < beta:
            raise ValueError(f"alpha_n_init must be >= beta ({beta}), got {alpha_n_init}")
        self.beta = beta
        # Parameters hold inverse-softplus values so softplus() recovers the init.
        self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
        self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))
        self.register_buffer('eps', torch.tensor(-1e-6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = F.softplus(self.alpha_p)
        alpha_n = self.beta + F.softplus(self.alpha_n)
        # min(x, eps), i.e. clamp from ABOVE. This keeps the exponential term for
        # negative inputs. It also bounds expm1 for large positive inputs: both
        # torch.where branches are evaluated, and an inf there would turn the
        # (zero) gradient of the unselected branch into NaN.
        # NOTE: an earlier version clamped from below (max(x, eps)); checkpoints
        # trained with it saw a different negative branch.
        negative_exp = torch.expm1(torch.clamp(x, max=self.eps))
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            alpha_n * negative_exp - alpha_n * x + self.beta * x,
        )
class StandardMLP(nn.Module):
    """
    Feed-forward block: CoLA up-projection -> xIELU -> dropout -> CoLA down-projection.

    When ``config.lax_enabled`` is truthy, each projection owns a learnable
    scalar gate (initialised to 0.2). The gate mixes in the latent that the same
    projection produced in the previous layer, read from ``lax_buffer`` under
    the keys ``('mlp_up', idx)`` / ``('mlp_down', idx)``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, config=None, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.config = config
        self.layer_idx = layer_idx
        self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)
        if config is None:
            xielu_kwargs = {"alpha_p_init": 0.8, "alpha_n_init": 0.8, "beta": 0.5}
        else:
            xielu_kwargs = {
                "alpha_p_init": config.xielu_alpha_p_init,
                "alpha_n_init": config.xielu_alpha_n_init,
                "beta": config.xielu_beta,
            }
        self.activation = XIELU(**xielu_kwargs)
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        lax_on = config is not None and config.lax_enabled
        # One learnable beta per projection; None disables LaX for this block.
        self.lax_beta_up = nn.Parameter(torch.full((1,), 0.2)) if lax_on else None
        self.lax_beta_down = nn.Parameter(torch.full((1,), 0.2)) if lax_on else None

    def forward(self, x: torch.Tensor, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        """Run the MLP. When given, ``lax_buffer`` is read for the previous layer and written for this one."""
        has_buffer = lax_buffer is not None
        prev_up = prev_down = None
        if has_buffer and self.lax_beta_up is not None:
            prev_idx = self.layer_idx - 1
            prev_up = lax_buffer.get(('mlp_up', prev_idx))
            prev_down = lax_buffer.get(('mlp_down', prev_idx))

        hidden, up_latent = self.up_proj(x, prev_up, self.lax_beta_up)
        if has_buffer:
            lax_buffer[('mlp_up', self.layer_idx)] = up_latent.clone()

        hidden = self.dropout(self.activation(hidden))

        out, down_latent = self.down_proj(hidden, prev_down, self.lax_beta_down)
        if has_buffer:
            lax_buffer[('mlp_down', self.layer_idx)] = down_latent.clone()
        return out
class GroupedQueryAttention(nn.Module):
    """
    Causal grouped-query self-attention used by DecoderLayer.

    Pipeline:
        1. FANFormer feature map.
        2. CoLA Q/K/V projections, optionally LaX-gated with the previous
           layer's latents.
        3. Per-head RMSNorm on Q, K and V.
        4. RoPE on Q/K.
        5. FlashAttention (causal).
        6. CoLA output projection.

    Args:
        config: Reads hidden_size, num_attention_heads, num_key_value_heads,
            rms_norm_eps, max_position_embeddings, rope_theta,
            attention_dropout, lax_enabled and, optionally, fanformer_p
            (default 0.15).
        layer_idx: Index of the owning decoder layer; used to key ``lax_buffer``.

    Raises:
        ValueError: if ``fanformer_p`` leaves no room for the standard
            (non-periodic) features.
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # FANFormer: hidden = 2 * d_periodic (cos + sin features) + d_standard.
        self.fanformer_p = getattr(config, 'fanformer_p', 0.15)
        self.d_periodic = int(self.hidden_size * self.fanformer_p)
        self.d_standard = self.hidden_size - 2 * self.d_periodic
        # A real exception (not `assert`) so the check survives `python -O`.
        if self.d_standard <= 0:
            raise ValueError(
                f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"
            )
        # Submodule creation order is kept stable: it fixes both the RNG draws
        # made during construction and the state_dict layout.
        self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
        self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)
        self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta
        )
        # LaX Linear Gate scalars for Q/K/V only (o_proj is deliberately excluded).
        if config.lax_enabled:
            self.lax_beta_q = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_k = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_v = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_q = None
            self.lax_beta_k = None
            self.lax_beta_v = None

    def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
        """FANFormer map: concat(cos(W_p x), sin(W_p x), W_p_bar x) along the last dim."""
        periodic_proj, _ = self.fan_w_p(x)
        standard_proj, _ = self.fan_w_p_bar(x)
        return torch.cat([torch.cos(periodic_proj), torch.sin(periodic_proj), standard_proj], dim=-1)

    def _compute_flash_attention(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply RoPE to Q/K and run causal FlashAttention.

        Inputs and output use the (batch, seq, heads, head_dim) layout that
        ``flash_attn_func`` expects.
        """
        cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
        # RoPE runs in (batch, heads, seq, head_dim) so the (seq, head_dim)
        # tables broadcast over batch and heads.
        q_rope, k_rope = apply_rotary_pos_emb(
            query_states.transpose(1, 2), key_states.transpose(1, 2), cos, sin, position_ids
        )
        return flash_attn_func(
            q_rope.transpose(1, 2),
            k_rope.transpose(1, 2),
            value_states,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            causal=True,
        )

    def forward(self, hidden_states, position_ids=None, attention_mask=None, lax_buffer: Optional[Dict] = None):
        """
        Attend over ``hidden_states`` (batch, seq, hidden).

        ``attention_mask`` is accepted for interface compatibility but not used;
        masking is purely causal. NOTE(review): left-padded batches are therefore
        not masked.
        When ``lax_buffer`` is given, this layer's Q/K/V latents are stored in it
        under ``('attn_q'|'attn_k'|'attn_v', layer_idx)``.
        """
        batch_size, seq_len, _ = hidden_states.shape
        enhanced_input = self._fan_layer_prime(hidden_states)

        prev_q_latent = prev_k_latent = prev_v_latent = None
        if lax_buffer is not None and self.lax_beta_q is not None:
            prev_idx = self.layer_idx - 1
            prev_q_latent = lax_buffer.get(('attn_q', prev_idx))
            prev_k_latent = lax_buffer.get(('attn_k', prev_idx))
            prev_v_latent = lax_buffer.get(('attn_v', prev_idx))

        query_states, q_latent = self.q_proj(enhanced_input, prev_q_latent, self.lax_beta_q)
        key_states, k_latent = self.k_proj(enhanced_input, prev_k_latent, self.lax_beta_k)
        value_states, v_latent = self.v_proj(enhanced_input, prev_v_latent, self.lax_beta_v)

        if lax_buffer is not None:
            lax_buffer[('attn_q', self.layer_idx)] = q_latent.clone()
            lax_buffer[('attn_k', self.layer_idx)] = k_latent.clone()
            lax_buffer[('attn_v', self.layer_idx)] = v_latent.clone()

        # nn.RMSNorm normalises over the last dim only, so it applies per head
        # directly on the 4-D view (no flatten/unflatten needed).
        query_states = self.q_norm(query_states.view(batch_size, seq_len, self.num_heads, self.head_dim))
        key_states = self.k_norm(key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim))
        value_states = self.v_norm(value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim))

        attn_output = self._compute_flash_attention(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            seq_len=seq_len,
            position_ids=position_ids
        )
        # Flatten heads to exactly what o_proj consumes. This equals hidden_size
        # only when hidden_size is divisible by num_heads.
        attn_output = attn_output.reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        output, _ = self.o_proj(attn_output)
        return output
class DecoderLayer(nn.Module):
    """
    One pre-norm decoder block with LNS, GPAS and optional Canon-A / Canon-C.

    Each of the two sub-blocks (attention, then MLP) computes:
        skip = h
        u = Canon(h) if enabled else h      # Canon output only feeds the sub-block input
        h = Dropout(GPAS(skip + SubBlock(LNS(RMSNorm(u)))))

    Raises:
        ValueError: if ``layer_idx`` is negative.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx < 0:
            raise ValueError(f"layer_idx debe ser >= 0, got {layer_idx}")
        # Keep this creation order: it fixes the RNG draws of the randomly
        # initialised submodules and the state_dict key order.
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config, layer_idx)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = StandardMLP(config.hidden_size, config.intermediate_size, config.mlp_dropout, config, layer_idx)
        self.dropout_output = nn.Dropout(0.01)
        depth = layer_idx + 1
        self.lns_attention = LayerNormScaling(layer_depth=depth)
        self.lns_mlp = LayerNormScaling(layer_depth=depth)
        self.gpas_attention = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)
        canon_on = config.canon_enabled
        # Canon-A sits before attention, Canon-C before the MLP.
        self.canon_a = (
            CanonLayer(config.hidden_size, config.canon_kernel_size)
            if (canon_on and config.canon_a_enabled) else None
        )
        self.canon_c = (
            CanonLayer(config.hidden_size, config.canon_kernel_size)
            if (canon_on and config.canon_c_enabled) else None
        )

    def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        """Apply the attention sub-block, then the MLP sub-block, to ``hidden_states``."""
        # Attention sub-block
        skip = hidden_states
        attn_in = hidden_states if self.canon_a is None else self.canon_a(hidden_states)
        attn_in = self.lns_attention(self.input_layernorm(attn_in))
        attn_out = self.self_attn(attn_in, position_ids, attention_mask, lax_buffer)
        hidden_states = self.dropout_output(self.gpas_attention(skip + attn_out))

        # MLP sub-block
        skip = hidden_states
        mlp_in = hidden_states if self.canon_c is None else self.canon_c(hidden_states)
        mlp_in = self.lns_mlp(self.post_attention_layernorm(mlp_in))
        mlp_out = self.mlp(mlp_in, lax_buffer)
        return self.dropout_output(self.gpas_mlp(skip + mlp_out))
class UnifiedModel(PreTrainedModel):
    """
    Decoder-only causal language model built from ``DecoderLayer`` blocks and
    exposed as a HuggingFace ``PreTrainedModel`` with AutoClass support.

    ``forward`` returns a ``CausalLMOutputWithPast``. Its ``loss`` is the shifted
    next-token cross-entropy when ``labels`` are given. No KV cache is kept
    (``past_key_values`` is always ``None``).
    """
    config_class = UnifiedModelConfig
    # Class attribute (HF convention): parameters HF treats as tied to the input embeddings.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: UnifiedModelConfig):
        super().__init__(config)
        self.config = config
        if config.vocab_size is None:
            raise ValueError("config.vocab_size must be set from tokenizer before model initialization")
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        self.output_dropout = nn.Dropout(0.05)
        # Output projection; shares its weight with embed_tokens when tying is enabled.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.layers.append(DecoderLayer(config, i))
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # HF hook: runs _init_weights over the submodules (and ties weights).
        self.post_init()
        self._print_configuration()

    def tie_weights(self):
        """
        Share one Parameter between ``lm_head.weight`` and ``embed_tokens.weight``.

        HF calls this hook itself; the tie is only applied when
        ``config.tie_word_embeddings`` is True.
        """
        if self.config.tie_word_embeddings:
            print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
            self.lm_head.weight = self.embed_tokens.weight
            print("✅ Weight tying successful: Parameters are properly shared")

    def _init_weights(self, module):
        """Initialize one module; HF applies this to every submodule."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
        elif isinstance(module, CoLA_Linear):
            # The A/B factors are plain nn.Linear children. When they are
            # visited (before their parent, as in nn.Module.apply), the branch
            # above overwrites the CoLA init, so restore it here. Marking the
            # children as initialized also covers a parent-first traversal.
            init_cola_components(module.A, module.B)
            module.A._is_hf_initialized = True
            module.B._is_hf_initialized = True

    def _print_configuration(self):
        """Print a human-readable summary of the architecture and parameter budget."""
        # `parameters()` de-duplicates shared Parameters, so a tied
        # embedding/lm_head weight is already counted once there. Count with
        # duplicates explicitly to get the "naive" figure.
        total_params_naive = sum(p.numel() for _, p in self.named_parameters(remove_duplicate=False))
        total_params_actual = sum(p.numel() for p in self.parameters())
        # Non-zero only when a Parameter is registered under several names (weight tying).
        tied_savings = total_params_naive - total_params_actual

        total_linear_params = 0
        total_cola_params = 0
        canon_params = 0
        lax_params = 0
        for _, module in self.named_modules():
            if isinstance(module, CoLA_Linear):
                in_features = module.A.in_features
                out_features = module.B.out_features
                rank = module.rank
                total_linear_params += in_features * out_features
                total_cola_params += (in_features * rank) + (rank * out_features)
            elif isinstance(module, CanonLayer):
                # Depthwise conv1d weights + bias.
                canon_params += module.hidden_dim * module.kernel_size + module.hidden_dim
            elif hasattr(module, 'lax_beta_q') and module.lax_beta_q is not None:
                lax_params += 3  # q, k, v
            elif hasattr(module, 'lax_beta_up') and module.lax_beta_up is not None:
                lax_params += 2  # up, down
        cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
        canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
        lax_overhead = (lax_params / total_params_actual) * 100 if total_params_actual > 0 else 0

        print(f"\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")
        if self.config.tie_word_embeddings and tied_savings > 0:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
            print(f"📊 Parameter Breakdown:")
            print(f" • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
            print(f" • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
            print(f" • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
        else:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")
        print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
        print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")
        if self.config.tie_word_embeddings:
            if tied_savings > 0:
                print(f"💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
            else:
                print(f"💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")
        print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
        print(f"🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
        print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
        print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} ✅ WORKING Inter-Layer (Linear Gate)")
        if self.config.lax_enabled:
            print(f" • LaX Method: Inter-Layer with Linear Gate (β scalars)")
            print(f" • LaX Applied to: q_proj, k_proj, v_proj, up_proj, down_proj (NOT o_proj)")
            print(f" • LaX Parameters: {lax_params} β scalars ({lax_overhead:.6f}% overhead)")
            # The gates are created with torch.full((1,), 0.2) in the attention and MLP blocks.
            print(f" • LaX Initialization: β=0.2 (learnable scalar gate)")
        print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
        if self.config.canon_enabled:
            print(f" • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
            print(f" • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
            print(f" • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
            print(f" • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
            print(f" • Canon Kernel Size: {self.config.canon_kernel_size}")
            print(f" • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
        print(f"⚡ GPAS Enabled: ALWAYS (Dynamic variance control)")
        print(f"📏 LNS Enabled: ALWAYS (Static depth scaling)")
        print(f"🔧 Variance Control: Triple-level (LNS + GPAS + Canon A+C) ALWAYS")
        print(f"🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
        print(f"🧹 CLEAN: Standard transformer architecture - CrossEntropyLoss manages PAD naturally")
        print(f"⚡ FlashAttention: Scaled Dot-Product Attention with GQA + automatic causal masking")
        print(f"🎯 TOKENIZER AGNOSTIC: Dynamic vocab_size and pad_token_id")
        print(f"🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = Better performance & less overhead")
        print(f"🔗 ✅ FIXED Weight Tying: _tied_weights_keys as class attribute (HF standard)")
        print(f"🎼 Canon A+C BENEFITS: Strategic horizontal information flow with minimal parameters")
        print(f"🔀 ✅ FIXED LaX: Functional Inter-Layer with ephemeral buffer (no broken reset)")
        print(f"🤗 HUGGINGFACE COMPATIBLE: Full PreTrainedModel integration v4.53.3")
        print(f"⚡ ✅ NATIVE TORCH COMPILE: Will be handled by TrainingArguments")
        print(f"🚀 ✅ AUTOCLASS SUPPORT: Compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Run the decoder and optionally compute the shifted next-token loss.

        Args:
            input_ids: Token ids, (batch, seq).
            attention_mask: Forwarded to every decoder layer.
            position_ids: Forwarded to every decoder layer.
            labels: Optional targets aligned with ``input_ids``; shifted by one
                internally. Positions equal to ``config.pad_token_id`` (and -100)
                are ignored. The caller's tensor is never modified.
            **kwargs: Accepted and ignored.

        Returns:
            CausalLMOutputWithPast with ``loss`` (or None) and ``logits``.
        """
        # Fresh LaX buffer per forward pass: each layer reads the previous
        # layer's latents and writes its own; dropped when this call returns.
        lax_buffer = {} if self.config.lax_enabled else None
        hidden_states = self.embedding_dropout(self.embed_tokens(input_ids))
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask, lax_buffer=lax_buffer)
        hidden_states = self.output_dropout(self.norm(hidden_states))
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            if self.config.pad_token_id is not None:
                # Out-of-place on purpose. With batch size 1 the slice above is
                # already contiguous, so `.contiguous()`/`.view()` return views of
                # the caller's `labels`, and an in-place write would corrupt them.
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)
            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @staticmethod
    def _sample_next_token(logits: torch.Tensor, temperature: float, top_p: float, top_k: Optional[int]) -> torch.Tensor:
        """Sample one token id per row from (batch, vocab) logits; returns (batch, 1)."""
        if temperature != 1.0:
            logits = logits / temperature
        # top_k of None or <= 0 disables the filter; values above vocab are capped.
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_best, -float('inf'))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, -float('inf'))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Each step re-runs the full forward pass over the whole sequence (no KV
        cache). Sampling applies temperature, then top-k, then top-p;
        ``do_sample=False`` uses greedy argmax.

        A row is finished once it emits an EOS id (an int or a list of ints).
        Its later positions are filled with ``pad_token_id``, or with the first
        EOS id when no pad id is known. Generation stops early once every row is
        finished.

        Returns:
            Tensor of shape (batch, prompt_len + generated_len).

        Raises:
            ValueError: if ``do_sample`` is True and ``temperature`` <= 0.
        """
        if pad_token_id is None:
            pad_token_id = self.config.pad_token_id
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id
        if do_sample and temperature <= 0:
            raise ValueError(f"temperature must be > 0 when do_sample=True, got {temperature}")

        device = input_ids.device
        eos_ids = None
        fill_token = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(eos_token_id, device=device).flatten().to(input_ids.dtype)
            if eos_ids.numel() == 0:
                eos_ids = None
            else:
                fill_token = pad_token_id if pad_token_id is not None else int(eos_ids[0])

        unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool, device=device)
        generated = input_ids.clone()
        for _ in range(max_new_tokens):
            next_token_logits = self.forward(generated).logits[:, -1, :]
            if do_sample:
                next_token = self._sample_next_token(next_token_logits, temperature, top_p, top_k)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            if eos_ids is not None:
                # Rows that already finished only receive the fill token.
                next_token = torch.where(unfinished.unsqueeze(-1), next_token, torch.full_like(next_token, fill_token))
                unfinished = unfinished & ~torch.isin(next_token.squeeze(-1), eos_ids)
            generated = torch.cat([generated, next_token], dim=1)
            if eos_ids is not None and not bool(unfinished.any()):
                break
        return generated
# AutoClass registration (needed for Hub loading): maps the "unified_model"
# model type to the classes in this file for AutoConfig, AutoModel and
# AutoModelForCausalLM, then reports what was registered.
AutoConfig.register("unified_model", UnifiedModelConfig)
AutoModel.register(UnifiedModelConfig, UnifiedModel)
AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)
print(
    "🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:\n"
    " • AutoConfig.register('unified_model', UnifiedModelConfig)\n"
    " • AutoModel.register(UnifiedModelConfig, UnifiedModel)\n"
    " • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)\n"
    " • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)"
)