""" Hybrid ASPP-Attention Architecture (Asterisk Model) Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms to enhance model expressiveness while maintaining efficiency. Architecture Design: - Hybrid layers: Standard attention + ASPP operator in parallel - Gate mechanism for dynamic fusion - Knowledge distillation from SmolLM2-135M base model """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaMLP, ) from transformers import AutoConfig, AutoModelForCausalLM from typing import Optional, Tuple, List class AsteriskConfig(LlamaConfig): """ Configuration class for Asterisk model. Inherits from LlamaConfig with custom model_type. """ model_type = "asterisk" def __init__( self, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent) # π-flow parameters pi_flow: bool = False, pi_flow_steps: int = 1, pi_flow_scale: float = 0.2, pi_flow_use_gate: bool = True, **kwargs ): super().__init__(**kwargs) self.hybrid_layer_indices = hybrid_layer_indices self.aspp_hidden_dim = aspp_hidden_dim self.aspp_num_steps = aspp_num_steps self.aspp_dropout = aspp_dropout self.aspp_num_neighbors = aspp_num_neighbors # π-flow config self.pi_flow = pi_flow self.pi_flow_steps = pi_flow_steps self.pi_flow_scale = pi_flow_scale self.pi_flow_use_gate = pi_flow_use_gate class ASPPOperator(nn.Module): """ Asterisk Operator (ASPP) - Union-Find Graph Propagation Uses Union-Find (Disjoint Set Union) structure for dynamic parent connections: - Each position maintains a parent pointer: parent[i] - Initial structure: parent[i] = max(0, i-1) (linear chain) - Message passing: aggregate self + parent features - Can apply path compression for optimization Advantages: - O(n) complexity with simple indexing - Dynamic grouping of related positions - Efficient parent-only propagation (no complex gather) - Nearly constant time find with path compression Complexity: O(n) with α(n) ≈ O(1) per operation Message passing: h_i^(t+1) = φ(h_i^(t), h_parent[i]) Args: hidden_size: Dimension of hidden states (input/output) aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size) num_steps: Number of evolution steps K (default: 2) dropout: Dropout rate for regularization (default: 0.1) num_neighbors: Fixed at 1 (only parent) for Union-Find structure """ def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1): super().__init__() self.hidden_size = hidden_size self.aspp_hidden_dim = aspp_hidden_dim or hidden_size self.num_steps = num_steps self.num_neighbors = 1 # Fixed: only parent # Projection to lower dimension (if specified) self.use_projection = (self.aspp_hidden_dim != hidden_size) if self.use_projection: self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim) self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size) self.proj_dropout = nn.Dropout(dropout) # Message aggregation function: combines self + parent self.message_net = nn.Sequential( nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2), nn.SiLU(), nn.Dropout(dropout), nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim), nn.Dropout(dropout), ) # Learnable K-step parameter self.k_logit = nn.Parameter(torch.tensor(1.0)) # Learnable residual scale self.residual_scale = nn.Parameter(torch.tensor(0.1)) # Layer norm for stability self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5) def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor: """ Compute parent index for each position using Union-Find structure Simple implementation: parent[i] = i-1 (linear chain) - Position 0 points to itself (root) - All others point to previous position Can be extended with dynamic union operations based on: - Semantic similarity - Positional heuristics - Learned grouping Returns: [seq_len] tensor of parent indices """ # Initialize: parent[i] = max(0, i-1) parent_indices = torch.arange(seq_len, device=device) - 1 parent_indices[0] = 0 # Root points to itself parent_indices = torch.clamp(parent_indices, 0, seq_len - 1) return parent_indices def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] Returns: evolved_states: [batch_size, seq_len, hidden_size] """ batch_size, seq_len, _ = hidden_states.shape # Project to lower dimension if needed if self.use_projection: h_t = self.down_proj(hidden_states) h_t = self.proj_dropout(h_t) else: h_t = hidden_states # Learnable number of steps k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps)) # K-step Union-Find graph propagation for t in range(k_steps): # 1. Compute parent indices using Union-Find structure parent_indices = self.compute_parent_indices(seq_len, h_t.device) # [L] # 2. Gather parent features (super simple indexing!) # h_t: [B, L, D], parent_indices: [L] # Just gather from parent positions parent_features = h_t[:, parent_indices, :] # [B, L, D] # 3. Message passing: combine self + parent message_input = torch.cat([h_t, parent_features], dim=-1) # [B, L, 2D] h_t_next = self.message_net(message_input) # [B, L, D] # 4. Scaled residual connection for stability h_t = h_t + self.residual_scale * h_t_next h_t = self.norm(h_t) # Project back to original dimension if needed if self.use_projection: h_t = self.up_proj(h_t) h_t = self.proj_dropout(h_t) return h_t class HybridASPPAttentionLayer(LlamaDecoderLayer): """ Hybrid layer combining ASPP operator and standard attention Inherits from LlamaDecoderLayer to maintain compatibility Architecture: 1. Parallel branches: - ASPP operator for local structured reasoning - Standard LlamaAttention for global context 2. Gated fusion of both outputs 3. π-flow refinement (optional, per-layer) 4. Feed-forward network """ def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1): # Initialize parent LlamaDecoderLayer super().__init__(config, layer_idx) # Add ASPP branch self.aspp_operator = ASPPOperator( hidden_size=config.hidden_size, aspp_hidden_dim=aspp_hidden_dim, num_steps=aspp_num_steps, dropout=aspp_dropout, num_neighbors=aspp_num_neighbors ) # Gated fusion mechanism with dropout self.fusion_gate = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size), nn.Dropout(aspp_dropout), nn.Sigmoid() ) # Initialize gate to be balanced (output 0.5 initially) with torch.no_grad(): self.fusion_gate[0].bias.fill_(0.0) # sigmoid(0) = 0.5 # π-flow: Per-layer refinement ASPP if getattr(config, 'pi_flow', False): self.pi_flow_aspp = ASPPOperator( hidden_size=config.hidden_size, aspp_hidden_dim=aspp_hidden_dim, num_steps=aspp_num_steps, dropout=aspp_dropout, num_neighbors=aspp_num_neighbors ) # Learnable flow scale (per-layer) self.pi_flow_scale = nn.Parameter( torch.tensor(getattr(config, 'pi_flow_scale', 0.2)) ) # Token-wise adaptive gating (optional) if getattr(config, 'pi_flow_use_gate', True): self.pi_flow_gate = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size // 4), nn.SiLU(), nn.Dropout(aspp_dropout), nn.Linear(config.hidden_size // 4, 1), nn.Sigmoid() ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: """ Override LlamaDecoderLayer.forward to add ASPP branch and π-flow Returns single tensor like LlamaDecoderLayer """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # ASPP branch aspp_output = self.aspp_operator(hidden_states) # Attention branch - use parent's self_attn (returns tuple, discard cache with _) attn_output, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, ) # Gated fusion fusion_input = torch.cat([aspp_output, attn_output], dim=-1) gate = self.fusion_gate(fusion_input) # Combine with gating: gate * ASPP + (1-gate) * Attention fused_output = gate * aspp_output + (1 - gate) * attn_output # Residual connection hidden_states = residual + fused_output # π-flow: Multi-step refinement in probability space (per-layer) if hasattr(self, 'pi_flow_aspp'): pi_flow_steps = getattr(self.config if hasattr(self, 'config') else kwargs.get('config'), 'pi_flow_steps', 1) for step in range(pi_flow_steps): # Compute velocity field v(h) using ASPP v = self.pi_flow_aspp(hidden_states) # Compute adaptive gate (per-token flow strength) if hasattr(self, 'pi_flow_gate'): gate = self.pi_flow_gate(hidden_states) # [B, L, 1] alpha = self.pi_flow_scale * gate else: alpha = self.pi_flow_scale # Euler step: h' = h + α * v(h) hidden_states = hidden_states + alpha * v # MLP block (use parent's mlp) residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states # Return only hidden_states tensor, like LlamaDecoderLayer return hidden_states class AsteriskLlamaModel(LlamaModel): """ Asterisk-Llama model with full hybrid ASPP-Attention architecture All layers use hybrid ASPP+Attention by default for maximum expressiveness. """ def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2): super().__init__(config) # Determine which layers to make hybrid (default: ALL layers) if hybrid_layer_indices is None: # Use ALL layers as hybrid (full hybrid architecture) num_layers = config.num_hidden_layers hybrid_layer_indices = list(range(num_layers)) self.hybrid_layer_indices = hybrid_layer_indices # Replace specified layers with hybrid layers (with per-layer π-flow if enabled) for idx in hybrid_layer_indices: if idx < len(self.layers): self.layers[idx] = HybridASPPAttentionLayer( config, layer_idx=idx, aspp_hidden_dim=aspp_hidden_dim, aspp_num_steps=aspp_num_steps, aspp_dropout=aspp_dropout, aspp_num_neighbors=aspp_num_neighbors ) # Initialize weights self.post_init() class AsteriskForCausalLM(LlamaForCausalLM): """ Asterisk Causal LM with Hybrid ASPP-Attention architecture Registered as: AsteriskForCausalLM """ config_class = AsteriskConfig def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2): # Read all ASPP parameters from config if not explicitly provided if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'): hybrid_layer_indices = config.hybrid_layer_indices if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'): aspp_hidden_dim = config.aspp_hidden_dim if hasattr(config, 'aspp_num_steps'): aspp_num_steps = config.aspp_num_steps if hasattr(config, 'aspp_dropout'): aspp_dropout = config.aspp_dropout if hasattr(config, 'aspp_num_neighbors'): aspp_num_neighbors = config.aspp_num_neighbors super().__init__(config) # Replace model with Asterisk version self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors) # Store hybrid layer info in config for serialization self.config.hybrid_layer_indices = hybrid_layer_indices # Initialize weights self.post_init() @classmethod def from_pretrained_base( cls, base_model_path: str, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent) # π-flow parameters pi_flow: bool = False, pi_flow_steps: int = 1, pi_flow_scale: float = 0.2, pi_flow_use_gate: bool = True, **kwargs ): """ Load base model and convert to Asterisk architecture Args: base_model_path: Path to base SmolLM2 model hybrid_layer_indices: Which layers to make hybrid (None for all) aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size) aspp_num_steps: Number of evolution steps K for ASPP (default: 2) aspp_dropout: Dropout rate for ASPP regularization (default: 0.1) aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent) pi_flow: Enable π-flow refinement step (default: False) pi_flow_steps: Number of flow refinement steps (default: 1) pi_flow_scale: Initial flow scale parameter (default: 0.2) pi_flow_use_gate: Use token-wise adaptive gating (default: True) """ # Load base model base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs) base_config = base_model.config # Create Asterisk config from base config with ASPP + π-flow params asterisk_config = AsteriskConfig( **base_config.to_dict(), hybrid_layer_indices=hybrid_layer_indices, aspp_hidden_dim=aspp_hidden_dim, aspp_num_steps=aspp_num_steps, aspp_dropout=aspp_dropout, aspp_num_neighbors=aspp_num_neighbors, pi_flow=pi_flow, pi_flow_steps=pi_flow_steps, pi_flow_scale=pi_flow_scale, pi_flow_use_gate=pi_flow_use_gate, ) # Create Asterisk model asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors) # Transfer weights from base model (non-hybrid layers and embeddings) asterisk_model.load_state_dict(base_model.state_dict(), strict=False) print(f"✓ Converted base model to Asterisk architecture with Graph Propagation") print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}") aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)" print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}") if pi_flow: print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}") return asterisk_model, base_model # Register the model for AutoModel AutoConfig.register("asterisk", AsteriskConfig) AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM) def get_model_info(model): """Print model architecture information""" total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" • Total parameters: {total_params:,}") print(f" • Trainable parameters: {trainable_params:,}") print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)") if isinstance(model, AsteriskForCausalLM): print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}") print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}")