|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Unified Language Model with GPAS + LNS Integration + xIELU Activation + CoLA (Linear Only) + LaX + Weight Tying + Canon Layers (A+C Only) |
|
|
MIGRATED TO HUGGINGFACE TRANSFORMERS - FINAL VERSION WITH ALL FIXES + CORRECTED LaX IMPLEMENTATION |
|
|
UPDATED: Standard Transformer with advanced variance control, parameter efficiency, Canon horizontal information flow, and WORKING LaX Inter-Layer |
|
|
Combines advanced Transformer architecture with CORRECTED variance control mechanisms, |
|
|
advanced variance control via GPAS and LNS, xIELU activation function, FIXED LaX integration, and Canon Layers (A+C only) |
|
|
Based on LLaMA 3 architecture with 30M parameters |
|
|
|
|
|
MIGRATION TO HUGGINGFACE - FINAL FIXED VERSION + LaX CORRECTION: |
|
|
============================================================== |
|
|
|
|
|
1. **HUGGINGFACE INTEGRATION**: Migrated from PyTorch Lightning to Transformers v4.53.3

2. **UPDATED API**: processing_class instead of tokenizer (deprecated)

3. **UPDATED COMPUTE_LOSS**: Method updated with the num_items_in_batch parameter

4. **FIXED LOGGING**: Fixed self.log() syntax per the official HF documentation

5. **RESTORED PAD HANDLING**: pad_token_id → -100 conversion for CrossEntropyLoss (from original code)

6. **NATIVE TORCH COMPILE**: Moved to TrainingArguments (torch_compile=True)

7. **FIXED WEIGHT TYING**: Corrected _tied_weights_keys as class attribute (HF standard)

8. **VALIDATION DIAGNOSTIC**: Added simple method to diagnose validation loss issues

9. **CUSTOM CONFIGURATION**: Custom PretrainedConfig carrying all model parameters

10. **PRETRAINED MODEL**: Inherits from PreTrainedModel for full compatibility

11. **MAINTAINED OPTIMIZERS**: Hybrid Muon + AdamW preserved

12. **MAINTAINED PRECISION**: bf16-true preserved

13. **MAINTAINED TRAINING**: Custom Trainer with all metrics and logging

14. **MAINTAINED ARCHITECTURE**: All custom architecture preserved

15. **AUTO TOKENIZER**: Full integration with a dynamic AutoTokenizer

16. **AUTOCLASS SUPPORT**: Full registration for AutoConfig and AutoModel

17. **✅ FIXED LaX**: Correct Inter-Layer implementation with a working Linear Gate
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
PreTrainedModel, |
|
|
) |
|
|
import math |
|
|
import os |
|
|
from typing import Optional, Tuple, Dict, Any, cast, List |
|
|
from flash_attn import flash_attn_func |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from configuration_unified import UnifiedModelConfig |
|
|
|
|
|
|
|
|
# Silence the HF tokenizers fork-safety warning when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Allow TF32 matmuls on Ampere+ GPUs: faster fp32 GEMMs at slightly lower precision.
torch.set_float32_matmul_precision('high')
|
|
|
|
|
def init_cola_components(A: nn.Linear, B: nn.Linear):
    """Initialize the two halves of a CoLA low-rank pair.

    The down-projection ``A`` gets a Kaiming (He) fan-in init suited to the
    ReLU-family nonlinearity applied to its output; the up-projection ``B``
    gets a slightly damped Xavier init, and its bias (if present) starts at
    zero so the pair begins as a tame linear map.
    """
    nn.init.kaiming_normal_(A.weight, mode='fan_in', nonlinearity='relu')
    nn.init.xavier_normal_(B.weight, gain=0.8)
    up_bias = B.bias
    if up_bias is not None:
        nn.init.zeros_(up_bias)
|
|
|
|
|
def init_embedding(embedding: nn.Embedding):
    """Draw embedding weights from N(0, 0.02**2), the usual LM table init."""
    nn.init.normal_(embedding.weight, std=0.02, mean=0.0)
|
|
|
|
|
class CanonLayer(nn.Module):
    """Residual depthwise causal 1-D convolution ("Canon" horizontal mixing).

    The convolution is zero-initialized, so at the start of training the
    layer is an exact identity (output == input); the model can then learn
    to mix information from up to ``kernel_size - 1`` previous positions
    into each token.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size

        # Depthwise (groups == channels) convolution; causality is enforced
        # by left-padding in forward() rather than by the conv itself.
        self.causal_conv1d = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            groups=hidden_dim,
            padding=0,
            bias=True
        )

        # Zero init -> the conv branch contributes nothing initially.
        nn.init.zeros_(self.causal_conv1d.weight)
        nn.init.zeros_(self.causal_conv1d.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the causal conv to (batch, seq, hidden) and add it back."""
        # Conv1d expects (batch, channels, seq).
        channels_first = h.transpose(1, 2)

        # Pad only on the left so position t never sees positions > t.
        left_padded = F.pad(channels_first, (self.kernel_size - 1, 0))

        mixed = self.causal_conv1d(left_padded).transpose(1, 2)

        # Residual connection: layer starts as identity (zero-init conv).
        return h + mixed
|
|
|
|
|
class CoLA_Linear(nn.Module):
    """Low-rank (CoLA) replacement for nn.Linear: ``B(activation(A(x)))``.

    ``A`` projects ``in_features`` down to ``rank`` (default
    ``in_features // 4``) and ``B`` projects back up to ``out_features``.
    The activated bottleneck ("latent") is returned alongside the output so
    the matching module in the next layer can blend it back in through the
    LaX inter-layer gate.
    """

    def __init__(self, in_features: int, out_features: int, rank: Optional[int] = None, activation=F.gelu, bias: bool = True):
        super().__init__()
        self.rank = in_features // 4 if rank is None else rank
        self.activation = activation

        self.A = nn.Linear(in_features, self.rank, bias=False)
        self.B = nn.Linear(self.rank, out_features, bias=bias)

        init_cola_components(self.A, self.B)

    def forward(self, x: torch.Tensor, prev_latent: Optional[torch.Tensor] = None, lax_beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``x`` through the low-rank pair, optionally mixing in LaX.

        Args:
            x: Input tensor.
            prev_latent: Activated latent from the same module type one layer
                below (LaX); ignored unless its shape matches exactly.
            lax_beta: Learned scalar gate weighting ``prev_latent``.

        Returns:
            ``(output, latent)`` — the projected output and the (possibly
            LaX-mixed) activated bottleneck to publish for the next layer.
        """
        latent = self.activation(self.A(x))

        # LaX inter-layer mixing: only when both gate and latent are present
        # and the shapes line up (they can differ across module types).
        use_lax = (
            prev_latent is not None
            and lax_beta is not None
            and prev_latent.shape == latent.shape
        )
        if use_lax:
            latent = latent + lax_beta * prev_latent

        return self.B(latent), latent
|
|
|
|
|
class LayerNormScaling(nn.Module):
    """LNS: scale a normalized activation by ``1 / sqrt(layer_depth)``.

    Deeper layers contribute progressively smaller residual updates, which
    keeps the variance of the residual stream from growing with depth.
    """

    def __init__(self, layer_depth: int):
        super().__init__()
        if layer_depth < 1:
            raise ValueError(f"layer_depth debe ser ≥ 1, got {layer_depth}")

        self.layer_depth = layer_depth
        self.scaling_factor = 1.0 / math.sqrt(float(layer_depth))

    def forward(self, normalized_input: torch.Tensor) -> torch.Tensor:
        # Pure static scaling — no learned parameters.
        return normalized_input * self.scaling_factor
|
|
|
|
|
class GPAS(nn.Module):
    """Gradient-Preserving Activation Scaling.

    Computes ``x - silu(alpha) * x.detach()``: the forward magnitude is
    scaled down, but because the subtracted term is detached, gradients
    flow through ``x`` unscaled.  ``alpha`` starts at 0 (silu(0) == 0),
    so the module is initially an exact identity.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Learned scalar gate, zero-initialized -> identity at start.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.alpha)
        # Detach so only the forward value is scaled, not the gradient.
        return x - gate * x.detach()
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) frequency tables.

    Precomputes the per-dimension inverse frequencies and, on each forward,
    returns ``(cos, sin)`` tables of shape ``(seq_len, dim)`` in the input's
    dtype.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq[i] = base^(-2i/dim) for each dimension pair.
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin) for ``seq_len`` positions (default: x.shape[-2])."""
        if seq_len is None:
            seq_len = x.shape[-2]
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table covers the full head dimension.
        table = torch.cat((angles, angles), dim=-1)
        return table.cos().to(x.dtype), table.sin().to(x.dtype)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Rotate query/key tensors by the given cos/sin tables (RoPE).

    ``position_ids`` is accepted for interface compatibility but unused:
    the cos/sin tables are assumed to already be laid out per position.
    """
    def rotate_half(t):
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
|
|
|
|
|
class XIELU(nn.Module):
    """xIELU activation: quadratic for x > 0, expanded-ELU for x <= 0.

    Positive branch: ``alpha_p * x**2 + beta * x``.
    Negative branch: ``alpha_n * expm1(x) - alpha_n * x + beta * x``.

    Both alphas are learned; parameters are stored as the inverse-softplus
    of the desired initial values so ``softplus(param)`` recovers them and
    the effective alphas stay positive (``alpha_n`` additionally offset so
    it stays above ``beta``).

    Args:
        alpha_p_init: initial positive-side gain.
        alpha_n_init: initial negative-side gain (must be >= beta).
        beta: fixed linear slope shared by both branches.
    """

    def __init__(self, alpha_p_init: float = 0.8, alpha_n_init: float = 0.8, beta: float = 0.5):
        super().__init__()

        self.beta = beta

        # Inverse softplus so that softplus(alpha_p) == alpha_p_init, etc.
        self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init)) - 1))
        self.alpha_n = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_n_init - self.beta)) - 1))

        # Small negative cap for the expm1 argument (see forward()).
        self.register_buffer('eps', torch.tensor(-1e-6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = F.softplus(self.alpha_p)
        alpha_n = self.beta + F.softplus(self.alpha_n)

        # BUGFIX: cap x at eps with clamp(max=...), not clamp(min=...).
        # torch.where evaluates BOTH branches everywhere, so without the
        # cap large positive x would overflow expm1 (inf / NaN gradients).
        # The previous clamp(min=eps) forced every negative input up to
        # -1e-6, flattening the whole expm1 branch to ~0 and reducing the
        # negative side to a plain linear map — the opposite of intended.
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            alpha_n * torch.expm1(torch.clamp(x, max=self.eps)) - alpha_n * x + self.beta * x
        )
|
|
|
|
|
class StandardMLP(nn.Module):
    """Feed-forward block: CoLA up-projection -> xIELU -> CoLA down-projection.

    When LaX is enabled, the activated bottleneck latents of both projections
    are published into ``lax_buffer`` keyed by ``(kind, layer_idx)`` and the
    previous layer's latents are blended in through learned scalar gates.
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.0, config=None, layer_idx: int = 0):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.config = config
        self.layer_idx = layer_idx

        self.up_proj = CoLA_Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = CoLA_Linear(intermediate_size, hidden_size, bias=False)

        if config is None:
            self.activation = XIELU(alpha_p_init=0.8, alpha_n_init=0.8, beta=0.5)
        else:
            self.activation = XIELU(
                alpha_p_init=config.xielu_alpha_p_init,
                alpha_n_init=config.xielu_alpha_n_init,
                beta=config.xielu_beta
            )

        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)

        # LaX scalar gates for the two projections (None when LaX is off).
        lax_on = config is not None and config.lax_enabled
        self.lax_beta_up = nn.Parameter(torch.full((1,), 0.2)) if lax_on else None
        self.lax_beta_down = nn.Parameter(torch.full((1,), 0.2)) if lax_on else None

    def forward(self, x: torch.Tensor, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        """Run the MLP, consuming/publishing LaX latents when a buffer is given."""
        prev_up = None
        prev_down = None
        if lax_buffer is not None and self.lax_beta_up is not None:
            prev_up = lax_buffer.get(('mlp_up', self.layer_idx - 1))
            prev_down = lax_buffer.get(('mlp_down', self.layer_idx - 1))

        hidden, up_latent = self.up_proj(x, prev_up, self.lax_beta_up)
        if lax_buffer is not None:
            # Clone so later in-place work cannot corrupt the stored latent.
            lax_buffer[('mlp_up', self.layer_idx)] = up_latent.clone()

        hidden = self.dropout(self.activation(hidden))

        out, down_latent = self.down_proj(hidden, prev_down, self.lax_beta_down)
        if lax_buffer is not None:
            lax_buffer[('mlp_down', self.layer_idx)] = down_latent.clone()

        return out
|
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention with FANformer input enhancement, per-head
    RMSNorm on Q/K/V, RoPE, FlashAttention, and optional LaX latent reuse.

    The input is first passed through a FANformer "layer prime": a periodic
    branch (cos/sin of a learned projection) concatenated with a standard
    projection, sized so the result is again ``hidden_size`` wide.
    """

    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Queries per KV head; not used directly below because
        # flash_attn_func handles the GQA head-count mismatch natively.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Fraction of hidden_size devoted to the periodic (cos/sin) branch.
        self.fanformer_p = getattr(config, 'fanformer_p', 0.15)

        # cos and sin each contribute d_periodic dims; the remainder is the
        # plain ("standard") projection, so the concat is hidden_size wide.
        self.d_periodic = int(self.hidden_size * self.fanformer_p)
        self.d_standard = self.hidden_size - 2 * self.d_periodic

        assert self.d_standard > 0, \
            f"fanformer_p={self.fanformer_p} is too high. d_standard={self.d_standard} must be > 0"

        # FANformer projections (low-rank CoLA linears).
        self.fan_w_p = CoLA_Linear(self.hidden_size, self.d_periodic, bias=False)
        self.fan_w_p_bar = CoLA_Linear(self.hidden_size, self.d_standard, bias=False)

        # Q/K/V/O projections; K and V use the (smaller) KV head count.
        self.q_proj = CoLA_Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = CoLA_Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = CoLA_Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Per-head-dim RMSNorms. NOTE(review): normalizing V as well as Q/K
        # is unusual (QK-norm normally covers only queries and keys) —
        # confirm this is intentional.
        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta
        )

        # LaX scalar gates for Q/K/V latents (None when LaX is disabled).
        # o_proj deliberately gets no gate.
        if config.lax_enabled:
            self.lax_beta_q = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_k = nn.Parameter(torch.full((1,), 0.2))
            self.lax_beta_v = nn.Parameter(torch.full((1,), 0.2))
        else:
            self.lax_beta_q = None
            self.lax_beta_k = None
            self.lax_beta_v = None

    def _fan_layer_prime(self, x: torch.Tensor) -> torch.Tensor:
        """FANformer enhancement: concat [cos(Wp·x), sin(Wp·x), Wp̄·x].

        Output width is d_periodic + d_periodic + d_standard == hidden_size.
        The CoLA latents of these projections are discarded (no LaX here).
        """
        periodic_proj, _ = self.fan_w_p(x)
        standard_proj, _ = self.fan_w_p_bar(x)

        cos_component = torch.cos(periodic_proj)
        sin_component = torch.sin(periodic_proj)

        x_f = torch.cat([cos_component, sin_component, standard_proj], dim=-1)

        return x_f

    def _compute_flash_attention(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply RoPE then causal FlashAttention.

        Inputs are (batch, seq, heads, head_dim); the output keeps that
        layout. ``position_ids`` is accepted but not used by RoPE here —
        positions are assumed to be 0..seq_len-1.
        """
        batch_size = query_states.shape[0]

        # RoPE is applied in (batch, heads, seq, head_dim) layout so the
        # (seq, head_dim) cos/sin tables broadcast over batch and heads.
        q_rope = query_states.transpose(1, 2)
        k_rope = key_states.transpose(1, 2)

        # seq_len is passed explicitly because value_states here is
        # (batch, seq, kv_heads, head_dim) and dim -2 is NOT the sequence.
        cos, sin = self.rotary_emb(value_states, seq_len=seq_len)
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin, position_ids)

        # Back to the (batch, seq, heads, head_dim) layout flash-attn expects.
        query_states = q_rope.transpose(1, 2)
        key_states = k_rope.transpose(1, 2)

        from flash_attn import flash_attn_func

        # causal=True applies the lower-triangular mask inside the kernel.
        # NOTE(review): no padding mask is passed — padded positions attend
        # like real tokens; confirm inputs are packed/unpadded upstream.
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            causal=True,
        )

        return attn_output

    def forward(self, hidden_states, position_ids=None, attention_mask=None, lax_buffer: Optional[Dict] = None):
        """Attention forward pass.

        ``attention_mask`` is accepted for interface compatibility but is
        not applied (see _compute_flash_attention). When ``lax_buffer`` is
        given and LaX is enabled, Q/K/V latents from layer_idx-1 are
        consumed and this layer's latents are published under layer_idx.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # FANformer input enhancement before the Q/K/V projections.
        enhanced_input = self._fan_layer_prime(hidden_states)

        # Fetch previous-layer latents (None for the first layer / LaX off).
        prev_q_latent = None
        prev_k_latent = None
        prev_v_latent = None
        if lax_buffer is not None and self.lax_beta_q is not None:
            prev_q_latent = lax_buffer.get(('attn_q', self.layer_idx - 1))
            prev_k_latent = lax_buffer.get(('attn_k', self.layer_idx - 1))
            prev_v_latent = lax_buffer.get(('attn_v', self.layer_idx - 1))

        query_states, q_latent = self.q_proj(enhanced_input, prev_q_latent, self.lax_beta_q)
        key_states, k_latent = self.k_proj(enhanced_input, prev_k_latent, self.lax_beta_k)
        value_states, v_latent = self.v_proj(enhanced_input, prev_v_latent, self.lax_beta_v)

        # Publish this layer's latents for the next layer; clone() guards
        # against later in-place modification.
        if lax_buffer is not None:
            lax_buffer[('attn_q', self.layer_idx)] = q_latent.clone()
            lax_buffer[('attn_k', self.layer_idx)] = k_latent.clone()
            lax_buffer[('attn_v', self.layer_idx)] = v_latent.clone()

        # Split the flat projections into heads: (batch, seq, heads, head_dim).
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        # RMSNorm over the head dimension; flatten so the norm sees one
        # head vector per row, then restore the head layout.
        q_flat = query_states.reshape(-1, self.head_dim)
        k_flat = key_states.reshape(-1, self.head_dim)
        v_flat = value_states.reshape(-1, self.head_dim)

        q_normalized = self.q_norm(q_flat)
        k_normalized = self.k_norm(k_flat)
        v_normalized = self.v_norm(v_flat)

        query_states = q_normalized.view(batch_size, seq_len, self.num_heads, self.head_dim)
        key_states = k_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        value_states = v_normalized.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        attn_output = self._compute_flash_attention(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            seq_len=seq_len,
            position_ids=position_ids
        )

        # Merge heads back into the model dimension.
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        # Output projection; its CoLA latent is intentionally unused (no LaX).
        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
class DecoderLayer(nn.Module):
    """One transformer block: (Canon-A) -> attention -> (Canon-C) -> MLP.

    Each sub-block is pre-normalized (RMSNorm then LNS depth scaling); the
    residual sum is post-processed by GPAS and a light dropout. Canon-A/C
    are optional residual causal convolutions applied before each norm.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        if layer_idx < 0:
            raise ValueError(f"layer_idx debe ser >= 0, got {layer_idx}")

        depth = layer_idx + 1

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config, layer_idx)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.mlp = StandardMLP(
            config.hidden_size,
            config.intermediate_size,
            config.mlp_dropout,
            config,
            layer_idx
        )

        self.dropout_output = nn.Dropout(0.01)

        # Static 1/sqrt(depth) scaling per sub-block.
        self.lns_attention = LayerNormScaling(layer_depth=depth)
        self.lns_mlp = LayerNormScaling(layer_depth=depth)

        # Learned, gradient-preserving variance control per sub-block.
        self.gpas_attention = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)

        canon = config.canon_enabled
        self.canon_a = CanonLayer(config.hidden_size, config.canon_kernel_size) if canon and config.canon_a_enabled else None
        self.canon_c = CanonLayer(config.hidden_size, config.canon_kernel_size) if canon and config.canon_c_enabled else None

    def forward(self, hidden_states: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, lax_buffer: Optional[Dict] = None) -> torch.Tensor:
        # --- attention sub-block ---
        residual = hidden_states
        if self.canon_a is not None:
            hidden_states = self.canon_a(hidden_states)

        normed = self.lns_attention(self.input_layernorm(hidden_states))
        hidden_states = residual + self.self_attn(normed, position_ids, attention_mask, lax_buffer)
        hidden_states = self.dropout_output(self.gpas_attention(hidden_states))

        # --- MLP sub-block ---
        residual = hidden_states
        if self.canon_c is not None:
            hidden_states = self.canon_c(hidden_states)

        normed = self.lns_mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = residual + self.mlp(normed, lax_buffer)
        hidden_states = self.dropout_output(self.gpas_mlp(hidden_states))

        return hidden_states
|
|
|
|
|
class UnifiedModel(PreTrainedModel):
    """
    UnifiedModel that inherits from PreTrainedModel for full HuggingFace compatibility.
    With AutoClass support for seamless Hub integration.

    Decoder-only causal LM: token embedding -> N DecoderLayers -> RMSNorm
    -> (optionally tied) lm_head. An ephemeral LaX buffer is created per
    forward pass so layers can exchange CoLA latents.
    """
    config_class = UnifiedModelConfig

    # HF convention: weights listed here are re-tied from the embedding
    # instead of being loaded/saved separately.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: UnifiedModelConfig):
        super().__init__(config)
        self.config = config

        if config.vocab_size is None:
            raise ValueError("config.vocab_size must be set from tokenizer before model initialization")

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        self.output_dropout = nn.Dropout(0.05)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.layers.append(DecoderLayer(config, i))

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Runs _init_weights over submodules and applies weight tying when
        # config.tie_word_embeddings is set.
        self.post_init()

        self._print_configuration()

    def tie_weights(self):
        """
        ✅ FIXED: Simplified tie_weights method following HuggingFace standard.
        Tie the word embeddings and the output layer.
        This is called automatically if config.tie_word_embeddings is True.
        """
        if self.config.tie_word_embeddings:
            print("🔗 Applying weight tying: lm_head.weight = embed_tokens.weight")
            self.lm_head.weight = self.embed_tokens.weight
            print("✅ Weight tying successful: Parameters are properly shared")

    def _init_weights(self, module):
        """Initialize weights following the custom initialization scheme.

        CoLA_Linear modules are skipped here because they self-initialize
        via init_cola_components() in their constructor.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)
        elif isinstance(module, CoLA_Linear):
            pass

    def _print_configuration(self):
        """Print a human-readable summary of the architecture and its
        parameter budget (CoLA savings, Canon/LaX overhead, weight tying)."""
        total_params_naive = sum(p.numel() for p in self.parameters())

        total_params_actual = total_params_naive
        vocab_params = self.config.vocab_size * self.config.hidden_size
        tied_savings = 0

        # When tied, embed_tokens.weight and lm_head.weight are the SAME
        # tensor, but named_parameters() would still report it once per name
        # only — the naive sum above is correct either way; we just report
        # the saving when tying is actually active.
        if self.config.tie_word_embeddings:
            embed_weight = self.embed_tokens.weight
            lm_head_weight = self.lm_head.weight

            if embed_weight is lm_head_weight:
                tied_savings = vocab_params
                total_params_actual = total_params_naive - tied_savings
            else:
                tied_savings = 0

        total_linear_params = 0
        total_cola_params = 0
        canon_params = 0
        lax_params = 0

        for name, module in self.named_modules():
            if isinstance(module, CoLA_Linear):
                in_features = module.A.in_features
                out_features = module.B.out_features
                rank = module.rank

                # Compare against what a dense nn.Linear would have cost.
                standard_params = in_features * out_features
                cola_params = (in_features * rank) + (rank * out_features)

                total_linear_params += standard_params
                total_cola_params += cola_params
            elif isinstance(module, CanonLayer):
                # Depthwise conv weight (hidden*kernel) plus bias (hidden).
                canon_layer_params = module.hidden_dim * module.kernel_size + module.hidden_dim
                canon_params += canon_layer_params
            elif hasattr(module, 'lax_beta_q') and module.lax_beta_q is not None:
                # Attention module: q/k/v gates.
                lax_params += 3
            elif hasattr(module, 'lax_beta_up') and module.lax_beta_up is not None:
                # MLP module: up/down gates.
                lax_params += 2

        cola_reduction = ((total_linear_params - total_cola_params) / total_linear_params) * 100 if total_linear_params > 0 else 0
        canon_overhead = (canon_params / total_params_actual) * 100 if total_params_actual > 0 else 0
        lax_overhead = (lax_params / total_params_actual) * 100 if total_params_actual > 0 else 0

        print(f"\n📊 UNIFIED Model + GPAS + LNS + xIELU + CoLA (Linear Only) + LaX + Canon (A+C) + Weight Tying:")

        if self.config.tie_word_embeddings and tied_savings > 0:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M (effective)")
            print(f"📊 Parameter Breakdown:")
            print(f"   • Naive count: {total_params_naive/1e6:.2f}M (all registered params)")
            print(f"   • Actual count: {total_params_actual/1e6:.2f}M (after weight tying)")
            print(f"   • Weight tying savings: {tied_savings/1e6:.2f}M ({tied_savings/total_params_naive*100:.1f}%)")
        else:
            print(f"🎯 Total Parameters: {total_params_actual/1e6:.2f}M")

        print(f"📚 DYNAMIC Vocabulary Size: {self.config.vocab_size} (from tokenizer)")
        print(f"🔗 ✅ PROPER Weight Tying: {'ENABLED' if self.config.tie_word_embeddings else 'DISABLED'}")

        if self.config.tie_word_embeddings:
            if tied_savings > 0:
                print(f"💾 Weight Tying Status: ✅ ACTIVE (tensors are shared in memory)")
            else:
                print(f"💾 Weight Tying Status: ⏳ CONFIGURED (will be applied during post_init)")

        print(f"🚀 ACTIVATION: xIELU (αp_init={self.config.xielu_alpha_p_init}, αn_init={self.config.xielu_alpha_n_init}, β={self.config.xielu_beta})")
        print(f"🔄 UPGRADE: SwiGLU → StandardMLP + xIELU (better efficiency & adaptability)")
        print(f"🗜️ CoLA Integration: {cola_reduction:.1f}% parameter reduction in internal projections")
        print(f"🔀 LaX Enabled: {'YES' if self.config.lax_enabled else 'NO'} ✅ WORKING Inter-Layer (Linear Gate)")
        if self.config.lax_enabled:
            print(f"   • LaX Method: Inter-Layer with Linear Gate (β scalars)")
            print(f"   • LaX Applied to: q_proj, k_proj, v_proj, up_proj, down_proj (NOT o_proj)")
            print(f"   • LaX Parameters: {lax_params} β scalars ({lax_overhead:.6f}% overhead)")
            # FIXED: the gates are initialized to 0.2 (torch.full((1,), 0.2))
            # everywhere in this file; the old message claimed β=0.0.
            print(f"   • LaX Initialization: β=0.2 (conservative start)")
        print(f"🎼 Canon Layers Enabled: {'YES' if self.config.canon_enabled else 'NO'} (A+C ONLY)")
        if self.config.canon_enabled:
            print(f"   • Canon-A (Before Attention): {'✅' if self.config.canon_a_enabled else '❌'}")
            print(f"   • Canon-B (Inside Attention): ❌ PERMANENTLY DISABLED")
            print(f"   • Canon-C (Before MLP): {'✅' if self.config.canon_c_enabled else '❌'}")
            print(f"   • Canon-D (Inside MLP): ❌ PERMANENTLY DISABLED")
            print(f"   • Canon Kernel Size: {self.config.canon_kernel_size}")
            print(f"   • Canon Parameters Overhead: {canon_overhead:.3f}% ({canon_params/1e3:.1f}K params)")
        print(f"⚡ GPAS Enabled: ALWAYS (Dynamic variance control)")
        print(f"📏 LNS Enabled: ALWAYS (Static depth scaling)")
        print(f"🔧 Variance Control: Triple-level (LNS + GPAS + Canon A+C) ALWAYS")
        print(f"🔗 Residual Connections: STANDARD + HORIZONTAL (Canon A+C only)")
        print(f"🧹 CLEAN: Standard transformer architecture - CrossEntropyLoss manages PAD naturally")
        print(f"⚡ FlashAttention: Scaled Dot-Product Attention with GQA + automatic causal masking")
        print(f"🎯 TOKENIZER AGNOSTIC: Dynamic vocab_size and pad_token_id")
        print(f"🎯 SIMPLIFIED: CoLA Linear Only + Canon A+C Only = Better performance & less overhead")
        print(f"🔗 ✅ FIXED Weight Tying: _tied_weights_keys as class attribute (HF standard)")
        print(f"🎼 Canon A+C BENEFITS: Strategic horizontal information flow with minimal parameters")
        print(f"🔀 ✅ FIXED LaX: Functional Inter-Layer with ephemeral buffer (no broken reset)")
        print(f"🤗 HUGGINGFACE COMPATIBLE: Full PreTrainedModel integration v4.53.3")
        print(f"⚡ ✅ NATIVE TORCH COMPILE: Will be handled by TrainingArguments")
        print(f"🚀 ✅ AUTOCLASS SUPPORT: Compatible with AutoConfig.from_pretrained() and AutoModel.from_pretrained()")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Run the decoder stack and optionally compute the shifted LM loss.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: passed through to layers (the flash-attn path
                does not apply it — see GroupedQueryAttention).
            position_ids: optional positions (RoPE currently assumes 0..L-1).
            labels: (batch, seq) targets; PAD positions are remapped to
                -100 so CrossEntropyLoss ignores them.

        Returns:
            CausalLMOutputWithPast with ``loss`` (when labels given) and
            ``logits``; no KV cache is maintained.
        """
        batch_size, seq_len = input_ids.shape  # also validates a 2-D input

        # Fresh per-forward LaX buffer: layers publish/consume latents here
        # and the dict is discarded when the pass ends (ephemeral by design).
        lax_buffer = {} if self.config.lax_enabled else None

        hidden_states = self.embed_tokens(input_ids)
        hidden_states = self.embedding_dropout(hidden_states)

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask, lax_buffer=lax_buffer)

        hidden_states = self.norm(hidden_states)
        hidden_states = self.output_dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)

            # Remap PAD to -100 (CrossEntropyLoss ignore_index). The
            # in-place write is safe: shift_labels is a fresh copy produced
            # by .contiguous() on a non-contiguous slice.
            if self.config.pad_token_id is not None:
                shift_labels[shift_labels == self.config.pad_token_id] = -100

            loss = loss_fct(shift_logits, shift_labels)

        from transformers.modeling_outputs import CausalLMOutputWithPast
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate sequences using the model.
        Compatible with AutoModelForCausalLM interface.

        Note: there is no KV cache, so each step re-runs the full forward
        over the whole generated prefix (O(L^2) per sequence).
        """
        if pad_token_id is None:
            pad_token_id = self.config.pad_token_id
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id

        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            outputs = self.forward(generated)
            logits = outputs.logits

            # Only the distribution of the last position matters.
            next_token_logits = logits[:, -1, :]

            if do_sample:
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Top-k: keep only the k highest-scoring tokens.
                if top_k is not None:
                    values, indices = torch.topk(next_token_logits, top_k)
                    next_token_logits[next_token_logits < values[:, [-1]]] = -float('inf')

                # Top-p (nucleus): drop the low-probability tail beyond p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token above the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = -float('inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop early only when EVERY sequence in the batch emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the custom architecture with the AutoClass machinery so that
# AutoConfig / AutoModel / AutoModelForCausalLM resolve model_type
# "unified_model" to these classes (e.g. when loading from the Hub with
# trust_remote_code=True).
AutoConfig.register("unified_model", UnifiedModelConfig)
AutoModel.register(UnifiedModelConfig, UnifiedModel)
AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)

print("🚀 ✅ AUTOCLASS REGISTRATION COMPLETE:")
print("   • AutoConfig.register('unified_model', UnifiedModelConfig)")
print("   • AutoModel.register(UnifiedModelConfig, UnifiedModel)")
print("   • AutoModelForCausalLM.register(UnifiedModelConfig, UnifiedModel)")
print("   • Users can now load with: AutoModel.from_pretrained('your-repo', trust_remote_code=True)")