""" Hybrid ASPP-Attention Architecture (Asterisk Model) Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms to enhance model expressiveness while maintaining efficiency. Architecture Design: - Hybrid layers: Standard attention + ASPP operator in parallel - Gate mechanism for dynamic fusion - Knowledge distillation from SmolLM2-135M base model """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP, ) from transformers import AutoConfig, AutoModelForCausalLM from typing import Optional, Tuple, List class AsteriskConfig(LlamaConfig): """ Configuration class for Asterisk model. Inherits from LlamaConfig with custom model_type. """ model_type = "asterisk" def __init__( self, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, **kwargs ): super().__init__(**kwargs) self.hybrid_layer_indices = hybrid_layer_indices self.aspp_hidden_dim = aspp_hidden_dim self.aspp_num_steps = aspp_num_steps self.aspp_dropout = aspp_dropout class ASPPOperator(nn.Module): """ Asterisk Operator (ASPP) - Point-wise Parallel Propagation Simplified version WITHOUT neighbor gathering to reduce overfitting: - Optional dimensionality reduction for efficiency - Point-wise evolution: h_i^(t+1) = φ(h_i^(t)) [NO neighbors] - Multi-step evolution for depth without added complexity - Dropout for regularization Args: hidden_size: Dimension of hidden states (input/output) aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size) num_steps: Number of evolution steps K (default: 2) dropout: Dropout rate for regularization (default: 0.1) """ def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1): super().__init__() self.hidden_size = hidden_size self.aspp_hidden_dim = aspp_hidden_dim or hidden_size self.num_steps = num_steps # Projection to lower dimension (if specified) self.use_projection = (self.aspp_hidden_dim != hidden_size) if self.use_projection: self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim) self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size) self.proj_dropout = nn.Dropout(dropout) # Point-wise update function φ - NO neighbor gathering # Much smaller: only processes current position self.update_net = nn.Sequential( nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2), nn.SiLU(), nn.Dropout(dropout), nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim), nn.Dropout(dropout), ) # Learnable K-step parameter # sigmoid(1.0) ≈ 0.73, giving k_steps ≈ 1.5 → 2 steps initially self.k_logit = nn.Parameter(torch.tensor(1.0)) # Learnable residual scale self.residual_scale = nn.Parameter(torch.tensor(0.1)) # Layer norm for stability self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] Returns: evolved_states: [batch_size, seq_len, hidden_size] """ # Project to lower dimension if needed if self.use_projection: h_t = self.down_proj(hidden_states) h_t = self.proj_dropout(h_t) else: h_t = hidden_states # Learnable number of steps k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps)) # K-step point-wise evolution (NO neighbor gathering) for t in range(k_steps): # Apply point-wise update rule φ h_t_next = self.update_net(h_t) # Scaled residual connection for stability h_t = h_t + self.residual_scale * h_t_next h_t = self.norm(h_t) # Project back to original dimension if needed if self.use_projection: h_t = self.up_proj(h_t) h_t = self.proj_dropout(h_t) return h_t class HybridASPPAttentionLayer(LlamaDecoderLayer): """ Hybrid layer combining ASPP operator and standard attention Inherits from LlamaDecoderLayer to maintain compatibility Architecture: 1. Parallel branches: - ASPP operator for local structured reasoning - Standard LlamaAttention for global context 2. Gated fusion of both outputs 3. Feed-forward network """ def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1): # Initialize parent LlamaDecoderLayer super().__init__(config, layer_idx) # Add ASPP branch self.aspp_operator = ASPPOperator( hidden_size=config.hidden_size, aspp_hidden_dim=aspp_hidden_dim, num_steps=aspp_num_steps, dropout=aspp_dropout ) # Gated fusion mechanism with dropout self.fusion_gate = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size), nn.Dropout(aspp_dropout), nn.Sigmoid() ) # Initialize gate to be balanced (output 0.5 initially) with torch.no_grad(): self.fusion_gate[0].bias.fill_(0.0) # sigmoid(0) = 0.5 def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: """ Override LlamaDecoderLayer.forward to add ASPP branch Returns single tensor to match LlamaDecoderLayer API in transformers 4.57.6 """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # ASPP branch aspp_output = self.aspp_operator(hidden_states) # Attention branch - use parent's self_attn attn_outputs = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, ) attn_output = attn_outputs[0] # Gated fusion fusion_input = torch.cat([aspp_output, attn_output], dim=-1) gate = self.fusion_gate(fusion_input) # Combine with gating: gate * ASPP + (1-gate) * Attention fused_output = gate * aspp_output + (1 - gate) * attn_output # Residual connection hidden_states = residual + fused_output # MLP block (use parent's mlp) residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states # Return single tensor like LlamaDecoderLayer return hidden_states class AsteriskLlamaModel(LlamaModel): """ Asterisk-Llama model with full hybrid ASPP-Attention architecture All layers use hybrid ASPP+Attention by default for maximum expressiveness. """ def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1): super().__init__(config) # Determine which layers to make hybrid (default: ALL layers) if hybrid_layer_indices is None: # Use ALL layers as hybrid (full hybrid architecture) num_layers = config.num_hidden_layers hybrid_layer_indices = list(range(num_layers)) self.hybrid_layer_indices = hybrid_layer_indices # Replace specified layers with hybrid layers for idx in hybrid_layer_indices: if idx < len(self.layers): self.layers[idx] = HybridASPPAttentionLayer( config, layer_idx=idx, aspp_hidden_dim=aspp_hidden_dim, aspp_num_steps=aspp_num_steps, aspp_dropout=aspp_dropout ) # Initialize weights self.post_init() class AsteriskForCausalLM(LlamaForCausalLM): """ Asterisk Causal LM with Hybrid ASPP-Attention architecture Registered as: AsteriskForCausalLM """ config_class = AsteriskConfig def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1): # Read all ASPP parameters from config if not explicitly provided if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'): hybrid_layer_indices = config.hybrid_layer_indices if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'): aspp_hidden_dim = config.aspp_hidden_dim if hasattr(config, 'aspp_num_steps'): aspp_num_steps = config.aspp_num_steps if hasattr(config, 'aspp_dropout'): aspp_dropout = config.aspp_dropout super().__init__(config) # Replace model with Asterisk version self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout) # Store hybrid layer info in config for serialization self.config.hybrid_layer_indices = hybrid_layer_indices # Initialize weights self.post_init() @classmethod def from_pretrained_base( cls, base_model_path: str, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, **kwargs ): """ Load base model and convert to Asterisk architecture Args: base_model_path: Path to base SmolLM2 model hybrid_layer_indices: Which layers to make hybrid (None for all) aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size) aspp_num_steps: Number of evolution steps K for ASPP (default: 2) aspp_dropout: Dropout rate for ASPP regularization (default: 0.1) """ # Load base model base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs) base_config = base_model.config # Create Asterisk config from base config with ASPP params asterisk_config = AsteriskConfig( **base_config.to_dict(), hybrid_layer_indices=hybrid_layer_indices, aspp_hidden_dim=aspp_hidden_dim, aspp_num_steps=aspp_num_steps, aspp_dropout=aspp_dropout ) # Create Asterisk model asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout) # Transfer weights from base model (non-hybrid layers and embeddings) asterisk_model.load_state_dict(base_model.state_dict(), strict=False) print(f"✓ Converted base model to Asterisk architecture") print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}") aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)" print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}") return asterisk_model, base_model # Register the model for AutoModel AutoConfig.register("asterisk", AsteriskConfig) AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM) def get_model_info(model): """Print model architecture information""" total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" • Total parameters: {total_params:,}") print(f" • Trainable parameters: {trainable_params:,}") print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)") if isinstance(model, AsteriskForCausalLM): print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}") print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}") # Example usage if __name__ == "__main__": print("=" * 80) print("Asterisk Architecture - ASPP + Standard Attention") print("=" * 80) # Configuration base_model_path = "SmolLM2-135M-Instruct" # Create Asterisk model print("\n🔧 Creating Asterisk model...") asterisk_model, base_model = AsteriskForCausalLM.from_pretrained_base( base_model_path, hybrid_layer_indices=None, # Auto-select ALL layers (full hybrid) aspp_num_steps=2, # Reduced from 3 aspp_neighbor_radius=1, # Reduced from 2 aspp_dropout=0.1, # Added dropout torch_dtype=torch.bfloat16, device_map="auto", ) print("\n📊 Base model info:") get_model_info(base_model) print("\n📊 Asterisk model info:") get_model_info(asterisk_model) print("\n✨ Model ready for training!")