| | """ |
| | Hybrid ASPP-Attention Architecture (Asterisk Model) |
| | Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms |
| | to enhance model expressiveness while maintaining efficiency. |
| | |
| | Architecture Design: |
| | - Hybrid layers: Standard attention + ASPP operator in parallel |
| | - Gate mechanism for dynamic fusion |
| | - Knowledge distillation from SmolLM2-135M base model |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel |
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaAttention, |
| | LlamaDecoderLayer, |
| | LlamaRMSNorm, |
| | LlamaMLP, |
| | ) |
| | from transformers import AutoConfig, AutoModelForCausalLM |
| | from typing import Optional, Tuple, List |
| |
|
| |
|
class AsteriskConfig(LlamaConfig):
    """Configuration for the Asterisk model.

    Extends ``LlamaConfig`` with the ASPP-specific settings and is
    registered with the auto classes under ``model_type`` "asterisk".

    Args:
        hybrid_layer_indices: decoder-layer indices to make hybrid
            (``None`` means "decided later by the model": all layers).
        aspp_hidden_dim: internal ASPP working dimension (``None`` keeps
            the model hidden size).
        aspp_num_steps: number of ASPP evolution steps K.
        aspp_dropout: dropout rate used inside the ASPP branch.
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Stash the ASPP knobs on the config so they round-trip through
        # save_pretrained / from_pretrained.
        self.hybrid_layer_indices = hybrid_layer_indices
        self.aspp_hidden_dim = aspp_hidden_dim
        self.aspp_num_steps = aspp_num_steps
        self.aspp_dropout = aspp_dropout
| |
|
| |
|
class ASPPOperator(nn.Module):
    """Asterisk operator: point-wise parallel propagation.

    A deliberately simple evolution module. Each position's vector is
    refined independently — there is NO neighbor gathering — for a small
    number of steps, adding depth without extra sequence-mixing capacity,
    with dropout for regularization.

    Args:
        hidden_size: input/output feature dimension.
        aspp_hidden_dim: internal working dimension; ``None`` (or equal to
            ``hidden_size``) disables the down/up projections entirely.
        num_steps: upper bound on the number of evolution steps K.
        dropout: dropout rate for the update MLP and the projections.
    """

    def __init__(
        self,
        hidden_size: int,
        aspp_hidden_dim: Optional[int] = None,
        num_steps: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps

        # Projections only exist when the internal dim differs from the
        # model dim; otherwise the input is evolved in place.
        self.use_projection = self.aspp_hidden_dim != hidden_size
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # Two-layer SiLU MLP producing the per-step update (2x expansion).
        self.update_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Learnable logit controlling the effective step count.
        # NOTE(review): forward() passes it through int(), so no gradient
        # ever reaches this parameter — confirm whether the step count is
        # actually meant to be learned.
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Small learnable scale on the residual update; starting at 0.1
        # keeps the operator close to identity at init.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Evolve every position independently.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            Tensor of the same shape as the input.
        """
        if self.use_projection:
            state = self.proj_dropout(self.down_proj(hidden_states))
        else:
            state = hidden_states

        # sigmoid squashes the logit into (0, 1); scaled by num_steps and
        # floored via int(), but never fewer than one step.
        steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # Residual point-wise evolution, LayerNorm'd after every step.
        for _ in range(steps):
            update = self.update_net(state)
            state = self.norm(state + self.residual_scale * update)

        if self.use_projection:
            state = self.proj_dropout(self.up_proj(state))

        return state
| |
|
| |
|
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """Decoder layer running ASPP and self-attention in parallel.

    Subclasses ``LlamaDecoderLayer`` so the attention/MLP weights and the
    checkpoint layout stay compatible; this class only adds an ASPP branch
    and a sigmoid-gated fusion of the two branch outputs before the usual
    feed-forward sub-layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        super().__init__(config, layer_idx)

        # Point-wise structured branch, run alongside self-attention.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
        )

        # Per-feature gate computed from both branch outputs concatenated.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid(),
        )

        # Zero bias => sigmoid(0) = 0.5, i.e. an even blend at init.
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Gated ASPP+attention sub-layer followed by the standard MLP.

        Returns a single tensor, matching the ``LlamaDecoderLayer`` API in
        transformers 4.57.6.
        """
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)

        # Branch 1: point-wise ASPP evolution (local structured reasoning).
        aspp_out = self.aspp_operator(normed)

        # Branch 2: ordinary causal self-attention (global context).
        attn_out = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )[0]

        # Convex per-feature blend of the two branches.
        gate = self.fusion_gate(torch.cat([aspp_out, attn_out], dim=-1))
        fused = gate * aspp_out + (1 - gate) * attn_out
        hidden_states = skip + fused

        # Standard post-attention feed-forward with its own residual.
        skip = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return skip + hidden_states
| |
|
| |
|
class AsteriskLlamaModel(LlamaModel):
    """``LlamaModel`` with selected layers swapped for hybrid ASPP+attention.

    With ``hybrid_layer_indices=None`` every decoder layer becomes hybrid.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        super().__init__(config)

        # Default to making every decoder layer hybrid.
        if hybrid_layer_indices is None:
            hybrid_layer_indices = list(range(config.num_hidden_layers))
        self.hybrid_layer_indices = hybrid_layer_indices

        # Replace the chosen stock decoder layers in place; indices at or
        # beyond the layer count are silently ignored (pre-existing
        # behavior, preserved here).
        for layer_idx in hybrid_layer_indices:
            if layer_idx >= len(self.layers):
                continue
            self.layers[layer_idx] = HybridASPPAttentionLayer(
                config,
                layer_idx=layer_idx,
                aspp_hidden_dim=aspp_hidden_dim,
                aspp_num_steps=aspp_num_steps,
                aspp_dropout=aspp_dropout,
            )

        # Re-run weight init / tying hooks after surgery on the layer list.
        self.post_init()
| |
|
| |
|
class AsteriskForCausalLM(LlamaForCausalLM):
    """Causal LM head on top of the hybrid ASPP-attention backbone.

    Registered with the Hugging Face auto classes as ``AsteriskForCausalLM``.
    """

    config_class = AsteriskConfig

    def __init__(
        self,
        config: AsteriskConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        # Fall back to values stored on the config.
        # NOTE(review): the first two settings defer to the config only
        # when the argument is None, while num_steps/dropout are
        # unconditionally overridden by the config when present — confirm
        # this asymmetry is intended.
        if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
            hybrid_layer_indices = config.hybrid_layer_indices
        if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'):
            aspp_hidden_dim = config.aspp_hidden_dim
        if hasattr(config, 'aspp_num_steps'):
            aspp_num_steps = config.aspp_num_steps
        if hasattr(config, 'aspp_dropout'):
            aspp_dropout = config.aspp_dropout

        super().__init__(config)

        # Swap the stock Llama backbone for the hybrid one.
        self.model = AsteriskLlamaModel(
            config,
            hybrid_layer_indices,
            aspp_hidden_dim,
            aspp_num_steps,
            aspp_dropout,
        )

        # Persist the requested indices so they serialize with the config.
        self.config.hybrid_layer_indices = hybrid_layer_indices

        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        **kwargs
    ):
        """Load a base Llama checkpoint and convert it to Asterisk.

        Args:
            base_model_path: path to the base SmolLM2 model.
            hybrid_layer_indices: which layers to make hybrid (None = all).
            aspp_hidden_dim: internal ASPP dimension (None = model hidden_size).
            aspp_num_steps: number of evolution steps K for ASPP (default: 2).
            aspp_dropout: dropout rate for ASPP regularization (default: 0.1).

        Returns:
            Tuple ``(asterisk_model, base_model)``.
        """
        # Load the donor model whose weights seed the hybrid one.
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Clone the base config and graft the ASPP settings onto it.
        asterisk_config = AsteriskConfig(
            **base_config.to_dict(),
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
        )

        asterisk_model = cls(
            asterisk_config,
            hybrid_layer_indices,
            aspp_hidden_dim,
            aspp_num_steps,
            aspp_dropout,
        )

        # strict=False: the ASPP/fusion modules have no counterpart in the
        # donor checkpoint and keep their fresh initialization.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print(f"✓ Converted base model to Asterisk architecture")
        print(f"  Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f"  ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")

        return asterisk_model, base_model
| |
|
| |
|
| | |
# Register with the Hugging Face auto classes so that AutoConfig /
# AutoModelForCausalLM can resolve model_type "asterisk" to these classes.
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
| |
|
| |
|
def get_model_info(model):
    """Print parameter counts and the in-memory size of *model*.

    Works for any ``nn.Module``; prints extra hybrid-layer details when the
    model is an :class:`AsteriskForCausalLM`.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Fix: the size was previously computed as total_params * 4 and labeled
    # "(fp32)", which is wrong when the model is loaded in bf16/fp16 (the
    # __main__ path loads bfloat16). Use each parameter's actual byte width.
    size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    print(f" • Total parameters: {total_params:,}")
    print(f" • Trainable parameters: {trainable_params:,}")
    print(f" • Model size: {size_bytes / 1024**2:.2f} MB")

    if isinstance(model, AsteriskForCausalLM):
        print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}")
        print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}")
| |
|
| |
|
| | |
if __name__ == "__main__":
    print("=" * 80)
    print("Asterisk Architecture - ASPP + Standard Attention")
    print("=" * 80)

    base_model_path = "SmolLM2-135M-Instruct"

    # Convert the base checkpoint into the hybrid architecture.
    # Fix: removed the stale `aspp_neighbor_radius=1` argument — ASPP no
    # longer does neighbor gathering and from_pretrained_base has no such
    # parameter, so it leaked through **kwargs into from_pretrained as a
    # bogus config override.
    print("\n🔧 Creating Asterisk model...")
    asterisk_model, base_model = AsteriskForCausalLM.from_pretrained_base(
        base_model_path,
        hybrid_layer_indices=None,  # None => every layer becomes hybrid
        aspp_num_steps=2,
        aspp_dropout=0.1,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print("\n📊 Base model info:")
    get_model_info(base_model)

    print("\n📊 Asterisk model info:")
    get_model_info(asterisk_model)

    print("\n✨ Model ready for training!")
| |
|