"""
Hybrid ASPP-Attention Architecture (Asterisk Model)
Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms
to enhance model expressiveness while maintaining efficiency.
Architecture Design:
- Hybrid layers: Standard attention + ASPP operator in parallel
- Gate mechanism for dynamic fusion
- Knowledge distillation from SmolLM2-135M base model
"""
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaModel,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from typing import Optional, Tuple, List
class AsteriskConfig(LlamaConfig):
"""
Configuration class for Asterisk model.
Inherits from LlamaConfig with custom model_type.
"""
model_type = "asterisk"
def __init__(
self,
hybrid_layer_indices: Optional[List[int]] = None,
aspp_hidden_dim: Optional[int] = None,
aspp_num_steps: int = 2,
aspp_dropout: float = 0.1,
**kwargs
):
super().__init__(**kwargs)
self.hybrid_layer_indices = hybrid_layer_indices
self.aspp_hidden_dim = aspp_hidden_dim
self.aspp_num_steps = aspp_num_steps
self.aspp_dropout = aspp_dropout
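

# A minimal construction sketch (hypothetical sizes, not a shipped config):
# the ASPP fields ride on top of the usual LlamaConfig arguments and are
# serialized alongside them by save_pretrained / from_pretrained.
def _demo_asterisk_config():
    cfg = AsteriskConfig(
        hidden_size=64,
        num_hidden_layers=2,
        aspp_num_steps=2,
        aspp_dropout=0.1,
    )
    print(cfg.model_type, cfg.aspp_num_steps)  # -> asterisk 2
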
class ASPPOperator(nn.Module):
"""
Asterisk Operator (ASPP) - Point-wise Parallel Propagation
Simplified version WITHOUT neighbor gathering to reduce overfitting:
- Optional dimensionality reduction for efficiency
- Point-wise evolution: h_i^(t+1) = φ(h_i^(t)) [NO neighbors]
- Multi-step evolution for depth without added complexity
- Dropout for regularization
Args:
hidden_size: Dimension of hidden states (input/output)
aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
num_steps: Number of evolution steps K (default: 2)
dropout: Dropout rate for regularization (default: 0.1)
"""
def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1):
super().__init__()
self.hidden_size = hidden_size
self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
self.num_steps = num_steps
# Projection to lower dimension (if specified)
self.use_projection = (self.aspp_hidden_dim != hidden_size)
if self.use_projection:
self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
self.proj_dropout = nn.Dropout(dropout)
# Point-wise update function φ - NO neighbor gathering
# Much smaller: only processes current position
self.update_net = nn.Sequential(
nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
nn.SiLU(),
nn.Dropout(dropout),
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
nn.Dropout(dropout),
)
        # Learnable K-step parameter: sigmoid(k_logit) scales num_steps, and the
        # result is truncated to an int in forward(). sigmoid(1.0) ≈ 0.73, so
        # 0.73 * num_steps ≈ 1.46 truncates to 1 step initially for num_steps=2.
        self.k_logit = nn.Parameter(torch.tensor(1.0))
# Learnable residual scale
self.residual_scale = nn.Parameter(torch.tensor(0.1))
# Layer norm for stability
self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states: [batch_size, seq_len, hidden_size]
Returns:
evolved_states: [batch_size, seq_len, hidden_size]
"""
# Project to lower dimension if needed
if self.use_projection:
h_t = self.down_proj(hidden_states)
h_t = self.proj_dropout(h_t)
else:
h_t = hidden_states
        # Learnable number of steps (the int() truncation is not differentiable,
        # so k_logit shifts the discrete step count rather than receiving gradients)
        k_steps = max(1, int(torch.sigmoid(self.k_logit).item() * self.num_steps))
        # K-step point-wise evolution (NO neighbor gathering)
        for _ in range(k_steps):
# Apply point-wise update rule φ
h_t_next = self.update_net(h_t)
# Scaled residual connection for stability
h_t = h_t + self.residual_scale * h_t_next
h_t = self.norm(h_t)
# Project back to original dimension if needed
if self.use_projection:
h_t = self.up_proj(h_t)
h_t = self.proj_dropout(h_t)
return h_t
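

# A minimal sanity-check sketch for the operator above (hypothetical sizes):
# it verifies that ASPP preserves the [batch, seq_len, hidden] shape and that
# the evolution is genuinely point-wise (no cross-position mixing).
def _demo_aspp_operator():
    torch.manual_seed(0)
    op = ASPPOperator(hidden_size=64, aspp_hidden_dim=32, num_steps=2).eval()
    x = torch.randn(2, 8, 64)  # [batch, seq_len, hidden]
    with torch.no_grad():
        y = op(x)
    assert y.shape == x.shape
    # Point-wise update: permuting positions commutes with the operator
    perm = torch.randperm(8)
    with torch.no_grad():
        assert torch.allclose(op(x[:, perm]), y[:, perm], atol=1e-5)
    print(f"ASPP demo: {tuple(x.shape)} -> {tuple(y.shape)}")
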
class HybridASPPAttentionLayer(LlamaDecoderLayer):
"""
Hybrid layer combining ASPP operator and standard attention
Inherits from LlamaDecoderLayer to maintain compatibility
Architecture:
1. Parallel branches:
- ASPP operator for local structured reasoning
- Standard LlamaAttention for global context
2. Gated fusion of both outputs
3. Feed-forward network
"""
def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
# Initialize parent LlamaDecoderLayer
super().__init__(config, layer_idx)
# Add ASPP branch
self.aspp_operator = ASPPOperator(
hidden_size=config.hidden_size,
aspp_hidden_dim=aspp_hidden_dim,
num_steps=aspp_num_steps,
dropout=aspp_dropout
)
# Gated fusion mechanism with dropout
self.fusion_gate = nn.Sequential(
nn.Linear(config.hidden_size * 2, config.hidden_size),
nn.Dropout(aspp_dropout),
nn.Sigmoid()
)
        # Initialize the gate bias to 0 so fusion starts near balanced:
        # sigmoid(0) = 0.5, up to the randomly initialized linear weights
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
"""
Override LlamaDecoderLayer.forward to add ASPP branch
Returns single tensor to match LlamaDecoderLayer API in transformers 4.57.6
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# ASPP branch
aspp_output = self.aspp_operator(hidden_states)
# Attention branch - use parent's self_attn
attn_outputs = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
attn_output = attn_outputs[0]
# Gated fusion
fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
gate = self.fusion_gate(fusion_input)
# Combine with gating: gate * ASPP + (1-gate) * Attention
fused_output = gate * aspp_output + (1 - gate) * attn_output
# Residual connection
hidden_states = residual + fused_output
# MLP block (use parent's mlp)
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
# Return single tensor like LlamaDecoderLayer
return hidden_states
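

# A toy sketch of the gated fusion used above (names and sizes are
# illustrative only, and the Dropout in the real gate is omitted): the
# sigmoid gate interpolates element-wise between the ASPP and attention
# branches, starting near a balanced 0.5 mix with a zero gate bias.
def _demo_gated_fusion():
    torch.manual_seed(0)
    hidden = 8
    aspp_out = torch.randn(1, 4, hidden)
    attn_out = torch.randn(1, 4, hidden)
    gate_net = nn.Sequential(nn.Linear(hidden * 2, hidden), nn.Sigmoid())
    nn.init.zeros_(gate_net[0].bias)
    gate = gate_net(torch.cat([aspp_out, attn_out], dim=-1))
    fused = gate * aspp_out + (1 - gate) * attn_out
    assert fused.shape == aspp_out.shape
    print(f"mean gate at init: {gate.mean().item():.3f}")  # near 0.5
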
class AsteriskLlamaModel(LlamaModel):
"""
Asterisk-Llama model with full hybrid ASPP-Attention architecture
All layers use hybrid ASPP+Attention by default for maximum expressiveness.
"""
def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
super().__init__(config)
# Determine which layers to make hybrid (default: ALL layers)
if hybrid_layer_indices is None:
# Use ALL layers as hybrid (full hybrid architecture)
num_layers = config.num_hidden_layers
hybrid_layer_indices = list(range(num_layers))
self.hybrid_layer_indices = hybrid_layer_indices
# Replace specified layers with hybrid layers
for idx in hybrid_layer_indices:
if idx < len(self.layers):
self.layers[idx] = HybridASPPAttentionLayer(
config,
layer_idx=idx,
aspp_hidden_dim=aspp_hidden_dim,
aspp_num_steps=aspp_num_steps,
aspp_dropout=aspp_dropout
)
# Initialize weights
self.post_init()
class AsteriskForCausalLM(LlamaForCausalLM):
"""
Asterisk Causal LM with Hybrid ASPP-Attention architecture
Registered as: AsteriskForCausalLM
"""
config_class = AsteriskConfig
def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
# Read all ASPP parameters from config if not explicitly provided
if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
hybrid_layer_indices = config.hybrid_layer_indices
if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'):
aspp_hidden_dim = config.aspp_hidden_dim
if hasattr(config, 'aspp_num_steps'):
aspp_num_steps = config.aspp_num_steps
if hasattr(config, 'aspp_dropout'):
aspp_dropout = config.aspp_dropout
super().__init__(config)
# Replace model with Asterisk version
self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
        # Store the resolved hybrid layer indices in the config for serialization
        # (AsteriskLlamaModel expands None to all layers)
        self.config.hybrid_layer_indices = self.model.hybrid_layer_indices
# Initialize weights
self.post_init()
@classmethod
def from_pretrained_base(
cls,
base_model_path: str,
hybrid_layer_indices: Optional[List[int]] = None,
aspp_hidden_dim: Optional[int] = None,
aspp_num_steps: int = 2,
aspp_dropout: float = 0.1,
**kwargs
):
"""
Load base model and convert to Asterisk architecture
Args:
base_model_path: Path to base SmolLM2 model
hybrid_layer_indices: Which layers to make hybrid (None for all)
aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
"""
# Load base model
base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
base_config = base_model.config
# Create Asterisk config from base config with ASPP params
asterisk_config = AsteriskConfig(
**base_config.to_dict(),
hybrid_layer_indices=hybrid_layer_indices,
aspp_hidden_dim=aspp_hidden_dim,
aspp_num_steps=aspp_num_steps,
aspp_dropout=aspp_dropout
)
# Create Asterisk model
asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
# Transfer weights from base model (non-hybrid layers and embeddings)
asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
print(f"✓ Converted base model to Asterisk architecture")
print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")
return asterisk_model, base_model
# Register the model for AutoModel
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
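

# A hedged loading sketch: after the registrations above, a checkpoint saved
# with save_pretrained can be reloaded through the Auto classes. The path is
# a placeholder argument, not a checkpoint shipped with this file.
def _demo_auto_loading(checkpoint_path: str):
    config = AutoConfig.from_pretrained(checkpoint_path)
    assert isinstance(config, AsteriskConfig)  # model_type "asterisk" resolves here
    return AutoModelForCausalLM.from_pretrained(checkpoint_path)
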
def get_model_info(model):
"""Print model architecture information"""
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" • Total parameters: {total_params:,}")
print(f" • Trainable parameters: {trainable_params:,}")
print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")
if isinstance(model, AsteriskForCausalLM):
print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}")
print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}")
# Example usage
if __name__ == "__main__":
print("=" * 80)
print("Asterisk Architecture - ASPP + Standard Attention")
print("=" * 80)
    # Configuration: local path or Hub id of the SmolLM2 base model
    base_model_path = "SmolLM2-135M-Instruct"
# Create Asterisk model
print("\n🔧 Creating Asterisk model...")
    asterisk_model, base_model = AsteriskForCausalLM.from_pretrained_base(
        base_model_path,
        hybrid_layer_indices=None,  # None selects ALL layers (full hybrid)
        aspp_num_steps=2,           # Number of point-wise evolution steps K
        aspp_dropout=0.1,           # Dropout for ASPP regularization
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
print("\n📊 Base model info:")
get_model_info(base_model)
print("\n📊 Asterisk model info:")
get_model_info(asterisk_model)
print("\n✨ Model ready for training!")