# ====================================================================
# modeling_unified.py
# ====================================================================
"""Unified causal language model for Hugging Face Transformers.

Architecture summary
--------------------
A LLaMA-3-style decoder-only Transformer (about 30M parameters in the reference
configuration) that combines:

* variance control via LNS (post-norm scaling by ``1 / sqrt(depth)``) and GPAS
  (forward-pass scaling of the residual stream whose gradient path is left
  untouched);
* the xIELU activation inside the MLP;
* CoLA low-rank factorised projections for every internal linear layer;
* LaX inter-layer mixing of CoLA latents through learnable linear gates;
* Canon layers A (before attention) and C (before the MLP) only;
* FANFormer-style periodic (cos/sin) features feeding the attention projections;
* grouped-query attention computed with FlashAttention;
* tied input/output embeddings.

Migration notes (from the PyTorch Lightning version)
----------------------------------------------------
1. HUGGINGFACE INTEGRATION: migrated from PyTorch Lightning to Transformers v4.53.3.
2. UPDATED API: ``processing_class`` instead of the deprecated ``tokenizer``.
3. UPDATED COMPUTE_LOSS: updated method with the ``num_items_in_batch`` parameter.
4. FIXED LOGGING: corrected ``self.log()`` syntax per the official HF documentation.
5. RESTORED PAD HANDLING: ``pad_token_id -> -100`` conversion for CrossEntropyLoss.
6. NATIVE TORCH COMPILE: moved to ``TrainingArguments(torch_compile=True)``.
7. FIXED WEIGHT TYING: ``_tied_weights_keys`` as a class attribute (HF standard).
8. VALIDATION DIAGNOSTIC: simple method to diagnose validation-loss issues.
9. CUSTOM CONFIGURATION: custom ``PretrainedConfig`` carrying every parameter.
10. PRETRAINED MODEL: inherits from ``PreTrainedModel`` for full compatibility.
11. MAINTAINED OPTIMIZERS: hybrid Muon + AdamW preserved.
12. MAINTAINED PRECISION: bf16-true preserved.
13. MAINTAINED TRAINING: custom Trainer with all metrics and logging.
14. MAINTAINED ARCHITECTURE: the full custom architecture is preserved.
15. AUTO TOKENIZER: integration with a dynamic ``AutoTokenizer``.
16. AUTOCLASS SUPPORT: registration for ``AutoConfig`` and ``AutoModel``.
17. FIXED LaX: correct inter-layer implementation with a working linear gate.

Items 2-4, 6, 8 and 11-13 describe training-side code that is not implemented
in this module.
"""
import math
import os
from typing import Any, Dict, List, Optional, Tuple, cast  # noqa: F401

import numpy as np  # noqa: F401
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_func
from torch.utils.checkpoint import checkpoint  # noqa: F401
from transformers import (  # noqa: F401
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

# Absolute import (no relative imports) so the file also loads as Hub remote code.
from configuration_unified import UnifiedModelConfig

# Silence the Hugging Face tokenizers parallelism warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Allow faster reduced-precision float32 matmul kernels (e.g. TF32) where available.
torch.set_float32_matmul_precision('high')


def init_cola_components(A: nn.Linear, B: nn.Linear):
    """Initialise the two CoLA factors in place.

    ``A`` (input -> rank in CoLA_Linear) gets Kaiming-normal init (fan-in,
    ReLU gain); ``B`` (rank -> output) gets Xavier-normal init with gain 0.8,
    and its bias, if present, is zeroed.
    """
    nn.init.kaiming_normal_(A.weight, mode='fan_in', nonlinearity='relu')
    nn.init.xavier_normal_(B.weight, gain=0.8)
    if B.bias is not None:
        nn.init.zeros_(B.bias)


def init_embedding(embedding: nn.Embedding):
    """Initialise an embedding table in place with N(0, 0.02)."""
    nn.init.normal_(embedding.weight, mean=0.0, std=0.02)


class CanonLayer(nn.Module):
    """Depthwise causal 1D convolution added residually to its input (Canon layer).

    Position ``t`` receives a learned per-channel mix of positions
    ``t - kernel_size + 1 .. t``; padding is applied on the left only, so no
    future position leaks in. Weights and bias start at zero, which makes the
    layer an exact identity at initialisation.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        self.causal_conv1d = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            groups=hidden_dim,  # depthwise: one filter per channel
            padding=0,  # causal (left-only) padding is applied in forward()
            bias=True,
        )
        # Zero init: the convolution outputs 0, so forward() returns its input.
        nn.init.zeros_(self.causal_conv1d.weight)
        nn.init.zeros_(self.causal_conv1d.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return ``h + causal_conv(h)`` for ``h`` of shape (batch, seq_len, hidden_dim)."""
        # Conv1d expects (batch, channels, seq_len).
        h_permuted = h.permute(0, 2, 1)
        # Pad kernel_size - 1 steps on the left only to keep the conv causal.
        h_padded = F.pad(h_permuted, (self.kernel_size - 1, 0))
        conv_out = self.causal_conv1d(h_padded)
        return h + conv_out.permute(0, 2, 1)


class CoLA_Linear(nn.Module):
    """Low-rank factorised linear layer (CoLA): ``x -> B(activation(A(x)))``.

    ``A`` maps ``in_features`` to ``rank`` (default ``in_features // 4``) and
    ``B`` maps ``rank`` to ``out_features``. The activated latent may be mixed
    with the latent of the same projection in the previous layer (LaX).
    """

    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None,
                 activation=F.gelu, bias: bool = True):
        super().__init__()
        if rank is None:
            rank = in_features // 4
        self.rank = rank
        self.activation = activation
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=bias)
        init_cola_components(self.A, self.B)
        # Tag the factors so UnifiedModel._init_weights does not overwrite this
        # initialisation with its generic nn.Linear scheme.
        self.A._is_cola_factor = True
        self.B._is_cola_factor = True

    def forward(self, x: torch.Tensor, prev_latent: Optional[torch.Tensor] = None,
                lax_beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``x`` through the low-rank bottleneck, optionally applying LaX.

        Args:
            x: input tensor whose last dimension is ``in_features``.
            prev_latent: activated latent of the same projection in the previous
                layer (LaX); ignored if ``None`` or if its shape differs.
            lax_beta: learnable gate scalar for LaX; ignored if ``None``.

        Returns:
            ``(output, latent)``: the projected output and the (possibly mixed)
            activated latent, which the next layer can consume as ``prev_latent``.
        """
        latent_activated = self.activation(self.A(x))

        if prev_latent is not None and lax_beta is not None and prev_latent.shape == latent_activated.shape:
            # LaX linear gate: h_i = h_i + beta * h_{i-1}
            latent_activated = latent_activated + lax_beta * prev_latent

        return self.B(latent_activated), latent_activated


class LayerNormScaling(nn.Module):
    """Static LayerNorm scaling (LNS): multiply the input by ``1 / sqrt(layer_depth)``."""

    def __init__(self, layer_depth: int):
        super().__init__()
        if layer_depth < 1:
            raise ValueError(f"layer_depth debe ser ≥ 1, got {layer_depth}")
        self.layer_depth = layer_depth
        self.scaling_factor = 1.0 / math.sqrt(float(layer_depth))

    def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
        return normalized_input * self.scaling_factor


class GPAS(nn.Module):
    """Scale activations in the forward pass while leaving the gradient path intact.

    The output value is ``(1 - silu(alpha)) * x``, but the subtracted term uses
    ``x.detach()``, so the gradient w.r.t. ``x`` is the identity and only
    ``alpha`` learns from the scaling. ``alpha`` starts at 0 (``silu(0) = 0``),
    i.e. an identity map. ``d_model`` is stored but not used.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - F.silu(self.alpha) * x.detach()


class RotaryEmbedding(nn.Module):
    """Produce rotary-embedding cos/sin tables for positions ``0 .. seq_len - 1``."""

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each of shape (seq_len, dim), cast to ``x.dtype``.

        ``x`` only supplies the device, dtype and (by default) ``seq_len = x.shape[-2]``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotate-half RoPE to ``q`` and ``k`` using precomputed ``cos``/``sin``.

    ``position_ids`` is accepted for API compatibility but ignored: positions are
    implicitly ``0 .. seq_len - 1`` as encoded in the ``cos``/``sin`` tables.
    """
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class XIELU(nn.Module):
    """xIELU activation with learnable ``alpha_p`` / ``alpha_n``.

    For ``x > 0``: ``alpha_p * x^2 + beta * x``.
    For ``x <= 0``: ``alpha_n * (exp(min(x, eps)) - 1) - alpha_n * x + beta * x``.

    Both alphas are stored in inverse-softplus form so that their effective
    values at initialisation equal ``alpha_p_init`` and ``alpha_n_init``
    (``alpha_n = beta + softplus(raw)``).
    """

    def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
        super().__init__()
        self.beta = beta
        self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
        self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))
        self.register_buffer('eps', torch.tensor(-1e-6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = F.softplus(self.alpha_p)
        alpha_n = self.beta + F.softplus(self.alpha_n)
        # torch.where evaluates both branches for every element. Capping the
        # expm1 argument at eps (min(x, eps)) keeps the unused branch finite for
        # large positive x while leaving negative inputs untouched.
        # The previous code used clamp(min=eps), i.e. max(x, eps), which flattened
        # the exponential term for every negative input. Checkpoints trained with
        # that version will produce different outputs.
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            alpha_n * torch.expm1(torch.clamp(x, max=self.eps)) - alpha_n * x + self.beta * x,
        )


class StandardMLP(nn.Module):
    """MLP: CoLA up-projection -> xIELU -> dropout -> CoLA down-projection.

    When ``config.lax_enabled`` is true, each projection's CoLA latent is mixed
    with the previous layer's latent through a learnable gate
    (``lax_beta_up`` / ``lax_beta_down``, initialised to 0.2).
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0,
                 config=None, layer_idx: int = 0):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.config = config
        self.layer_idx = layer_idx

        self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)

        if config is not None:
            self.activation = XIELU(
                alpha_p_init=config.xielu_alpha_p_init,
                alpha_n_init=config.xielu_alpha_n_init,
                beta=config.xielu_beta,
            )
        else:
            self.activation = XIELU(alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # LaX linear-gate scalars (previously initialised at 0.0, now 0.2).
        if config is not None and config.lax_enabled:
            self.lax_beta_up = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_down = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_up = None
            self.lax_beta_down = None

    def forward(self, x: torch.Tensor, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        """Apply the MLP to ``x``.

        Args:
            x: input activations whose last dimension is ``hidden_size``.
            lax_buffer: per-forward dict keyed by ``(name, layer_idx)``. This
                layer reads the ``layer_idx - 1`` entries (if LaX is enabled)
                and writes its own latents for the next layer.
        """
        prev_up_latent = None
        prev_down_latent = None
        if lax_buffer is not None and self.lax_beta_up is not None:
            prev_up_latent = lax_buffer.get(('mlp_up', self.layer_idx - 1))
            prev_down_latent = lax_buffer.get(('mlp_down', self.layer_idx - 1))

        intermediate, up_latent = self.up_proj(x, prev_up_latent, self.lax_beta_up)
        if lax_buffer is not None:
            # No clone needed: the latent is never modified in place afterwards.
            lax_buffer[('mlp_up', self.layer_idx)] = up_latent

        activated = self.dropout(self.activation(intermediate))

        output, down_latent = self.down_proj(activated, prev_down_latent, self.lax_beta_down)
        if lax_buffer is not None:
            lax_buffer[('mlp_down', self.layer_idx)] = down_latent

        return output


class GroupedQueryAttention(nn.Module):
    """Causal grouped-query attention with FANFormer input features, QKV RMSNorm and RoPE.

    The input is first expanded into ``[cos(p), sin(p), s]`` features (periodic
    projection ``p`` plus standard projection ``s``, concatenated back to
    ``hidden_size``). Q/K/V are then projected with CoLA (with optional LaX),
    normalised per head, rotated with RoPE and fed to FlashAttention with
    ``causal=True``. ``attention_mask`` is accepted but not used, so padding is
    not masked inside attention (only the causal mask applies).
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # FANFormer: 2 * d_periodic dims of cos/sin features + d_standard plain dims.
        self.fanformer_p = getattr(config, 'fanformer_p', 0.15)
        self.d_periodic = int(self.hidden_size * self.fanformer_p)
        self.d_standard = self.hidden_size - 2 * self.d_periodic

        assert self.d_standard > 0, \
            f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"

        self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
        self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)

        self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

        # LaX linear-gate scalars for q/k/v only (o_proj deliberately excluded).
        if config.lax_enabled:
            self.lax_beta_q = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_k = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_v = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_q = None
            self.lax_beta_k = None
            self.lax_beta_v = None

    def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``concat(cos(W_p x), sin(W_p x), W_p_bar x)`` (last dim = hidden_size)."""
        periodic_proj, _ = self.fan_w_p(x)
        standard_proj, _ = self.fan_w_p_bar(x)
        return torch.cat([torch.cos(periodic_proj), torch.sin(periodic_proj), standard_proj], dim=-1)

    def _compute_flash_attention(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply RoPE to Q/K and run causal FlashAttention.

        Inputs are (batch, seq_len, heads, head_dim); K/V may have fewer heads
        than Q (GQA). Returns a (batch, seq_len, num_heads, head_dim) tensor.
        """
        # The RoPE helpers operate on (batch, heads, seq_len, head_dim).
        q_rope = query_states.transpose(1, 2)
        k_rope = key_states.transpose(1, 2)

        # value_states only supplies device/dtype for the cos/sin tables.
        cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)

        return flash_attn_func(
            q_rope.transpose(1, 2),
            k_rope.transpose(1, 2),
            value_states,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            causal=True,
        )

    def forward(self, hidden_states, position_ids=None, attention_mask=None,
                lax_buffer: Optional[Dict] = None):
        """Attend over ``hidden_states`` (batch, seq_len, hidden_size) and return the same shape."""
        batch_size, seq_len, _ = hidden_states.shape

        enhanced_input = self._fan_layer_prime(hidden_states)

        prev_q_latent = prev_k_latent = prev_v_latent = None
        if lax_buffer is not None and self.lax_beta_q is not None:
            prev_q_latent = lax_buffer.get(('attn_q', self.layer_idx - 1))
            prev_k_latent = lax_buffer.get(('attn_k', self.layer_idx - 1))
            prev_v_latent = lax_buffer.get(('attn_v', self.layer_idx - 1))

        query_states, q_latent = self.q_proj(enhanced_input, prev_q_latent, self.lax_beta_q)
        key_states, k_latent = self.k_proj(enhanced_input, prev_k_latent, self.lax_beta_k)
        value_states, v_latent = self.v_proj(enhanced_input, prev_v_latent, self.lax_beta_v)

        if lax_buffer is not None:
            # No clone needed: the latents are never modified in place afterwards.
            lax_buffer[('attn_q', self.layer_idx)] = q_latent
            lax_buffer[('attn_k', self.layer_idx)] = k_latent
            lax_buffer[('attn_v', self.layer_idx)] = v_latent

        # Split heads, then RMS-normalise each head vector (RMSNorm acts on the last dim).
        query_states = self.q_norm(query_states.view(batch_size, seq_len, self.num_heads, self.head_dim))
        key_states = self.k_norm(key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim))
        value_states = self.v_norm(value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim))

        attn_output = self._compute_flash_attention(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            seq_len=seq_len,
            position_ids=position_ids,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        # Output projection without LaX (by design).
        output, _ = self.o_proj(attn_output)
        return output


class DecoderLayer(nn.Module):
    """Pre-norm decoder block: [Canon-A] -> attention -> [Canon-C] -> MLP.

    Each sub-block computes ``residual + f(LNS(RMSNorm(canon(x))))``, then applies
    GPAS and a 0.01 dropout. The residual is taken before the Canon layer, so
    Canon output only reaches the residual stream through the sub-block.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        if layer_idx < 0:
            raise ValueError(f"layer_idx debe ser >= 0, got {layer_idx}")

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config, layer_idx)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = StandardMLP(
            config.hidden_size,
            config.intermediate_size,
            config.mlp_dropout,
            config,
            layer_idx,
        )
        self.dropout_output = nn.Dropout(0.01)

        # LNS scale is 1 / sqrt(layer_idx + 1).
        self.lns_attention = LayerNormScaling(layer_depth=layer_idx + 1)
        self.lns_mlp = LayerNormScaling(layer_depth=layer_idx + 1)

        self.gpas_attention = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)

        # Canon-A (before attention) and Canon-C (before MLP) only.
        if config.canon_enabled and config.canon_a_enabled:
            self.canon_a = CanonLayer(config.hidden_size, config.canon_kernel_size)
        else:
            self.canon_a = None

        if config.canon_enabled and config.canon_c_enabled:
            self.canon_c = CanonLayer(config.hidden_size, config.canon_kernel_size)
        else:
            self.canon_c = None

    def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        # --- attention sub-block ---
        residual = hidden_states
        if self.canon_a is not None:
            hidden_states = self.canon_a(hidden_states)

        attention_input = self.lns_attention(self.input_layernorm(hidden_states))
        attention_output = self.self_attn(attention_input, position_ids, attention_mask, lax_buffer)

        hidden_states = self.gpas_attention(residual + attention_output)
        hidden_states = self.dropout_output(hidden_states)

        # --- MLP sub-block ---
        residual = hidden_states
        if self.canon_c is not None:
            hidden_states = self.canon_c(hidden_states)

        mlp_input = self.lns_mlp(self.post_attention_layernorm(hidden_states))
        mlp_output = self.mlp(mlp_input, lax_buffer)

        hidden_states = self.gpas_mlp(residual + mlp_output)
        hidden_states = self.dropout_output(hidden_states)

        return hidden_states


class UnifiedModel(PreTrainedModel):
    """Causal LM built from ``DecoderLayer`` blocks, integrated with Hugging Face.

    Inherits from ``PreTrainedModel`` (save/load, weight tying via
    ``_tied_weights_keys``) and is registered below for AutoClass loading.
    """

    config_class = UnifiedModelConfig

    # Class attribute, as HF expects, naming the weight tied to the input embeddings.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: UnifiedModelConfig):
        super().__init__(config)
        self.config = config

        if config.vocab_size is None:
            raise ValueError("config.vocab_size must be set from tokenizer before model initialization")

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        self.output_dropout = nn.Dropout(0.05)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.layers = nn.ModuleList(
            [DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # HF hook: runs _init_weights on the submodules and ties weights.
        self.post_init()
        self._print_configuration()

    def tie_weights(self):
        """Share ``lm_head.weight`` with ``embed_tokens.weight`` when ``config.tie_word_embeddings`` is set."""
        if self.config.tie_word_embeddings:
            print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
            self.lm_head.weight = self.embed_tokens.weight
            print("✅ Weight tying successful: Parameters are properly shared")

    def _init_weights(self, module):
        """Initialise one submodule (HF applies this hook to submodules in ``post_init``).

        * ``CoLA_Linear``: re-applies ``init_cola_components`` to its ``A``/``B`` factors.
        * Other ``nn.Linear`` (e.g. ``lm_head``): N(0, 0.02) weights, zero bias.
        * ``nn.Embedding``: truncated N(0, 0.02) clipped to [-0.04, 0.04].

        The CoLA factors are ``nn.Linear`` modules too and are visited
        individually. They carry an ``_is_cola_factor`` tag and are skipped by
        the generic branch, so the CoLA scheme survives regardless of visit
        order. Previously they were silently reset to N(0, 0.02).
        """
        if isinstance(module, CoLA_Linear):
            init_cola_components(module.A, module.B)
        elif isinstance(module, nn.Linear):
            if getattr(module, "_is_cola_factor", False):
                return
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)

    def _print_configuration(self):
        """Print a human-readable summary of the configuration and parameter counts."""
        # named_parameters() de-duplicates shared tensors by default, so:
        #   naive  = every registered parameter, tied tensors counted twice;
        #   actual = each distinct tensor counted once.
        # (Previously the de-duplicated count was treated as "naive" and the
        # embedding size was subtracted a second time.)
        total_params_naive = sum(p.numel() for _, p in self.named_parameters(remove_duplicate=False))
        total_params_actual = sum(p.numel() for p in self.parameters())
        tied_savings = total_params_naive - total_params_actual

        total_linear_params = 0
        total_cola_params = 0
        canon_params = 0
        lax_params = 0

        for _, module in self.named_modules():
            if isinstance(module, CoLA_Linear):
                in_features = module.A.in_features
                out_features = module.B.out_features
                rank = module.rank
                total_linear_params += in_features * out_features
                total_cola_params += (in_features * rank) + (rank * out_features)
            elif isinstance(module, CanonLayer):
                # Depthwise conv1d weights + bias.
                canon_params += module.hidden_dim * module.kernel_size + module.hidden_dim
            elif getattr(module, 'lax_beta_q', None) is not None:
                lax_params += 3  # q, k, v gates
            elif getattr(module, 'lax_beta_up', None) is not None:
                lax_params += 2  # up, down gates

        cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
        canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
        lax_overhead = (lax_params / total_params_actual) * 100 if total_params_actual > 0 else 0

        print("\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")

        if self.config.tie_word_embeddings and tied_savings > 0:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
            print("📊 Parameter Breakdown:")
            print(f" • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
            print(f" • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
            print(f" • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
        else:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")

        print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
        print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")

        if self.config.tie_word_embeddings:
            if tied_savings > 0:
                print("💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
            else:
                print("💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")

        print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
        print("🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
        print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
        print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} ✅ WORKING Inter-Layer (Linear Gate)")
        if self.config.lax_enabled:
            print(" • LaX Method: Inter-Layer with Linear Gate (β scalars)")
            print(" • LaX Applied to: q_proj, k_proj, v_proj, up_proj, down_proj (NOT o_proj)")
            print(f" • LaX Parameters: {lax_params} β scalars ({lax_overhead:.6f}% overhead)")
            # The gates are created with torch.full((1,), 0.2).
            print(" • LaX Initialization: β=0.2 (conservative start)")
        print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
        if self.config.canon_enabled:
            print(f" • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
            print(" • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
            print(f" • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
            print(" • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
            print(f" • Canon Kernel Size: {self.config.canon_kernel_size}")
            print(f" • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
        print("⚡ GPAS Enabled: ALWAYS (Dynamic variance control)")
        print("📏 LNS Enabled: ALWAYS (Static depth scaling)")
        print("🔧 Variance Control: Triple-level (LNS + GPAS + Canon A+C) ALWAYS")
        print("🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
        print("🧹 CLEAN: Standard transformer architecture - CrossEntropyLoss manages PAD naturally")
        print("⚡ FlashAttention: Scaled Dot-Product Attention with GQA + automatic causal masking")
        print("🎯 TOKENIZER AGNOSTIC: Dynamic vocab_size and pad_token_id")
        print("🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = Better performance & less overhead")
        print("🔗 ✅ FIXED Weight Tying: _tied_weights_keys as class attribute (HF standard)")
        print("🎼 Canon A+C BENEFITS: Strategic horizontal information flow with minimal parameters")
        print("🔀 ✅ FIXED LaX: Functional Inter-Layer with ephemeral buffer (no broken reset)")
        print("🤗 HUGGINGFACE COMPATIBLE: Full PreTrainedModel integration v4.53.3")
        print("⚡ ✅ NATIVE TORCH COMPILE: Will be handled by TrainingArguments")
        print("🚀 ✅ AUTOCLASS SUPPORT: Compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Compute next-token logits and, when ``labels`` is given, the shifted LM loss.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            attention_mask: forwarded to the layers, which do not use it
                (attention is causal only).
            position_ids: forwarded to the layers; RoPE ignores it.
            labels: target ids aligned with ``input_ids``. Positions equal to
                ``config.pad_token_id`` are ignored in the loss. The caller's
                tensor is never modified.
            **kwargs: accepted for Trainer compatibility (e.g.
                ``num_items_in_batch``) and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and ``logits``.
        """
        # LaX: a fresh buffer per forward pass; it is dropped when forward returns.
        lax_buffer = {} if self.config.lax_enabled else None

        hidden_states = self.embedding_dropout(self.embed_tokens(input_ids))

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids=position_ids,
                                  attention_mask=attention_mask, lax_buffer=lax_buffer)

        hidden_states = self.output_dropout(self.norm(hidden_states))
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1).to(shift_logits.device)

            if self.config.pad_token_id is not None:
                # Out-of-place: when labels[..., 1:] is already contiguous (e.g. batch
                # size 1), .contiguous() returns a view and an in-place write would
                # corrupt the caller's labels tensor.
                shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_token_id, -100)

            loss = nn.CrossEntropyLoss()(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @staticmethod
    def _sample_from_logits(logits: torch.Tensor, temperature: float, top_p: float,
                            top_k: Optional[int]) -> torch.Tensor:
        """Sample one token id per row from (batch, vocab) ``logits`` after temperature/top-k/top-p filtering."""
        if temperature != 1.0:
            logits = logits / temperature

        # top_k <= 0 disables top-k; k is clamped to the vocabulary size.
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, -float('inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, -float('inf'))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """Autoregressively extend ``input_ids`` (batch, seq_len) by up to ``max_new_tokens``.

        There is no KV cache: the full prefix is re-encoded at every step.
        Once a row emits an EOS token, later positions in that row are filled
        with ``pad_token_id`` (or the first EOS id if no pad id is set).
        Generation stops early when every row has finished. ``eos_token_id``
        may be a single id or a list/tuple/set of ids. Extra ``kwargs`` are
        ignored.

        Returns:
            The prompt followed by the generated tokens, shape (batch, seq_len + n_new).
        """
        if pad_token_id is None:
            pad_token_id = self.config.pad_token_id
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        eos_list: List[int] = []
        if eos_token_id is not None:
            if isinstance(eos_token_id, (list, tuple, set)):
                eos_list = [int(t) for t in eos_token_id]
            else:
                eos_list = [int(eos_token_id)]
        eos_ids = torch.tensor(eos_list, dtype=torch.long, device=input_ids.device) if eos_list else None
        fill_id = pad_token_id if pad_token_id is not None else (eos_list[0] if eos_list else None)

        generated = input_ids.clone()
        unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

        for _ in range(max_new_tokens):
            # Each call builds a fresh LaX buffer inside forward().
            next_token_logits = self(generated).logits[:, -1, :]

            if do_sample:
                next_token = self._sample_from_logits(next_token_logits, temperature, top_p, top_k)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            if eos_ids is not None:
                # Rows that already finished receive the fill token instead.
                next_token = torch.where(
                    unfinished.unsqueeze(-1), next_token, torch.full_like(next_token, fill_id)
                )

            generated = torch.cat([generated, next_token], dim=1)

            if eos_ids is not None:
                unfinished &= ~torch.isin(next_token.squeeze(-1), eos_ids)
                if not unfinished.any():
                    break

        return generated


# AutoClass registration so AutoConfig / AutoModel / AutoModelForCausalLM can
# resolve the "unified_model" model type (required for Hub loading).
AutoConfig.register("unified_model", UnifiedModelConfig)
AutoModel.register(UnifiedModelConfig, UnifiedModel)
AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)

print("🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:")
print(" • AutoConfig.register('unified_model', UnifiedModelConfig)")
print(" • AutoModel.register(UnifiedModelConfig, UnifiedModel)")
print(" • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)")
print(" • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)")