# Asterisk-Pi / AsteriskForCausalLM.py
# OzTianlu's picture
# Upload 14 files
# 3356706 verified
"""
Hybrid ASPP-Attention Architecture (Asterisk Model)
Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms
to enhance model expressiveness while maintaining efficiency.
Architecture Design:
- Hybrid layers: Standard attention + ASPP operator in parallel
- Gate mechanism for dynamic fusion
- Knowledge distillation from SmolLM2-135M base model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaRMSNorm,
LlamaMLP,
)
from transformers import AutoConfig, AutoModelForCausalLM
from typing import Optional, Tuple, List
class AsteriskConfig(LlamaConfig):
    """
    Configuration for the Asterisk model.

    Extends ``LlamaConfig`` with the ASPP and π-flow hyper-parameters and
    declares ``model_type = "asterisk"``. Every extra argument is stored
    unchanged as an attribute of the same name; all remaining keyword
    arguments are forwarded to ``LlamaConfig``.

    Args:
        hybrid_layer_indices: Layer indices to build as hybrid layers.
        aspp_hidden_dim: Internal ASPP width.
        aspp_num_steps: Number of ASPP propagation steps.
        aspp_dropout: Dropout rate used by the ASPP modules.
        aspp_num_neighbors: Fixed at 1 for Union-Find (only parent).
        pi_flow: Whether the π-flow refinement is enabled.
        pi_flow_steps: Number of π-flow refinement steps.
        pi_flow_scale: Initial π-flow scale.
        pi_flow_use_gate: Whether π-flow uses a token-wise gate.
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,  # Fixed at 1 for Union-Find (only parent)
        # π-flow parameters
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        # Let the Llama base configuration consume everything it knows first.
        super().__init__(**kwargs)

        # --- ASPP settings ---
        self.hybrid_layer_indices = hybrid_layer_indices
        self.aspp_hidden_dim = aspp_hidden_dim
        self.aspp_num_steps = aspp_num_steps
        self.aspp_dropout = aspp_dropout
        self.aspp_num_neighbors = aspp_num_neighbors

        # --- π-flow settings ---
        self.pi_flow = pi_flow
        self.pi_flow_steps = pi_flow_steps
        self.pi_flow_scale = pi_flow_scale
        self.pi_flow_use_gate = pi_flow_use_gate
class ASPPOperator(nn.Module):
    """
    Asterisk Operator (ASPP): parent-pointer message passing along a sequence.

    Each position ``i`` has a single parent. In this implementation the parent
    structure is a fixed linear chain, ``parent[i] = max(0, i - 1)``; no union
    or path-compression operations are performed. At every propagation step a
    position combines its own features with its parent's features:

        h_i^(t+1) = LayerNorm(h_i^(t) + residual_scale * message_net([h_i^(t), h_parent[i]^(t)]))

    Because a position only ever reads itself and earlier positions, the
    operator never looks at later positions of the sequence it is given.

    Args:
        hidden_size: Dimension of hidden states (input/output).
        aspp_hidden_dim: Internal dimension for ASPP. ``None`` (or 0) means
            use ``hidden_size``, in which case no projections are created.
        num_steps: Upper bound K on the number of propagation steps
            (default: 2). The number actually run is derived from
            ``k_logit`` in ``forward``.
        dropout: Dropout rate for regularization (default: 0.1).
        num_neighbors: Accepted for interface compatibility and ignored; the
            operator always uses exactly one neighbor (the parent).
    """

    def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps
        self.num_neighbors = 1  # Fixed: only parent

        # Projection to lower dimension (only if an internal width was specified)
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            # The same dropout module is applied after both projections.
            self.proj_dropout = nn.Dropout(dropout)

        # Message aggregation function: maps [self ; parent] (2D) back to D
        self.message_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # K-step parameter. NOTE: forward() only uses it through int(...), so
        # no gradient reaches it from this module's computation.
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Learnable residual scale
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        # Layer norm for stability
        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
        """
        Return the parent index of every position for a linear chain.

        Position 0 is the root and points to itself; every other position
        ``i`` points to ``i - 1``. An empty sequence (``seq_len == 0``) yields
        an empty index tensor instead of raising.

        Args:
            seq_len: Number of positions.
            device: Device on which to create the index tensor.

        Returns:
            Integer tensor of shape ``[seq_len]`` holding the parent indices.
        """
        # arange - 1 gives [-1, 0, 1, ...]; clamping at 0 makes the root its
        # own parent and is a no-op on an empty tensor.
        return (torch.arange(seq_len, device=device) - 1).clamp_(min=0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run the parent-pointer propagation.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]

        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        _, seq_len, _ = hidden_states.shape

        # Project to lower dimension if needed
        if self.use_projection:
            h_t = self.down_proj(hidden_states)
            h_t = self.proj_dropout(h_t)
        else:
            h_t = hidden_states

        # Number of steps: sigmoid(k_logit) * num_steps, truncated to an int
        # and floored at 1. The int() conversion makes this a plain Python
        # number (not differentiable with respect to k_logit).
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # The parent structure depends only on seq_len, so build it once
        # rather than on every propagation step.
        parent_indices = self.compute_parent_indices(seq_len, h_t.device)  # [L]

        for _ in range(k_steps):
            # 1. Gather parent features by plain indexing along the sequence axis
            parent_features = h_t[:, parent_indices, :]  # [B, L, D]

            # 2. Message passing: combine self + parent
            message_input = torch.cat([h_t, parent_features], dim=-1)  # [B, L, 2D]
            h_t_next = self.message_net(message_input)  # [B, L, D]

            # 3. Scaled residual connection followed by normalization
            h_t = h_t + self.residual_scale * h_t_next
            h_t = self.norm(h_t)

        # Project back to original dimension if needed
        if self.use_projection:
            h_t = self.up_proj(h_t)
            h_t = self.proj_dropout(h_t)

        return h_t
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid decoder layer combining an ASPP operator with standard attention.

    Inherits from LlamaDecoderLayer and reuses its ``input_layernorm``,
    ``self_attn``, ``post_attention_layernorm`` and ``mlp`` sub-modules.

    Architecture:
        1. Parallel branches on the normalized input:
           - ASPP operator
           - the parent's attention module
        2. Gated fusion of both outputs, added to the residual
        3. π-flow refinement (only when ``config.pi_flow`` is truthy)
        4. Feed-forward network with residual

    Args:
        config: Model configuration. ``pi_flow``, ``pi_flow_steps``,
            ``pi_flow_scale`` and ``pi_flow_use_gate`` are read from it with
            defaults when absent.
        layer_idx: Index of this layer, forwarded to the parent constructor.
        aspp_hidden_dim: Internal ASPP width (``None`` = ``config.hidden_size``).
        aspp_num_steps: Number of ASPP propagation steps.
        aspp_dropout: Dropout rate for the ASPP operators and both gates.
        aspp_num_neighbors: Forwarded to ``ASPPOperator``.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        # Initialize parent LlamaDecoderLayer
        super().__init__(config, layer_idx)

        # Add ASPP branch
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
            num_neighbors=aspp_num_neighbors
        )

        # Gated fusion mechanism with dropout (dropout acts on the gate logits)
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid()
        )
        # Zero the gate bias: sigmoid(0) = 0.5
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

        # Capture the π-flow step count here. forward() must not depend on
        # `self.config` existing on the parent layer (nothing in this class
        # sets it) nor on a `config` entry in **kwargs; relying on either can
        # silently fall back to a single step.
        self.pi_flow_steps = getattr(config, 'pi_flow_steps', 1)

        # π-flow: Per-layer refinement ASPP
        if getattr(config, 'pi_flow', False):
            self.pi_flow_aspp = ASPPOperator(
                hidden_size=config.hidden_size,
                aspp_hidden_dim=aspp_hidden_dim,
                num_steps=aspp_num_steps,
                dropout=aspp_dropout,
                num_neighbors=aspp_num_neighbors
            )
            # Learnable flow scale (per-layer)
            self.pi_flow_scale = nn.Parameter(
                torch.tensor(getattr(config, 'pi_flow_scale', 0.2))
            )
            # Token-wise adaptive gating (optional)
            if getattr(config, 'pi_flow_use_gate', True):
                self.pi_flow_gate = nn.Sequential(
                    nn.Linear(config.hidden_size, config.hidden_size // 4),
                    nn.SiLU(),
                    nn.Dropout(aspp_dropout),
                    nn.Linear(config.hidden_size // 4, 1),
                    nn.Sigmoid()
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the hybrid layer: ASPP + attention, gated fusion, optional π-flow, MLP.

        Args:
            hidden_states: Layer input; the ASPP branch treats it as
                [batch_size, seq_len, hidden_size].
            attention_mask, position_ids, past_key_values, cache_position,
            position_embeddings: Forwarded unchanged to ``self.self_attn``.
            use_cache: Accepted but not forwarded to the attention module.
            **kwargs: Accepted but not forwarded to the attention module.

        Returns:
            The updated hidden-state tensor only (no attention weights or cache).
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # ASPP branch.
        # NOTE(review): the operator only sees the positions passed in this
        # call; when a call carries just the newest token(s) (as presumably
        # happens with KV caching), the first position's parent is itself
        # rather than the previous token — confirm this is acceptable.
        aspp_output = self.aspp_operator(hidden_states)

        # Attention branch - use parent's self_attn; the second element of the
        # returned tuple is discarded.
        # NOTE(review): `use_cache` and `**kwargs` are intentionally left as in
        # the original call (not forwarded) — confirm against the installed
        # transformers version.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Gated fusion: gate * ASPP + (1 - gate) * Attention
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = self.fusion_gate(fusion_input)
        fused_output = gate * aspp_output + (1 - gate) * attn_output

        # Residual connection
        hidden_states = residual + fused_output

        # π-flow: multi-step refinement (only if the branch was built in __init__)
        if hasattr(self, 'pi_flow_aspp'):
            for _ in range(self.pi_flow_steps):
                # Velocity field v(h) from the dedicated ASPP operator
                v = self.pi_flow_aspp(hidden_states)

                # Per-token flow strength when the adaptive gate exists,
                # otherwise the shared learnable scale.
                if hasattr(self, 'pi_flow_gate'):
                    flow_gate = self.pi_flow_gate(hidden_states)  # [B, L, 1]
                    alpha = self.pi_flow_scale * flow_gate
                else:
                    alpha = self.pi_flow_scale

                # Euler step: h' = h + α * v(h)
                hidden_states = hidden_states + alpha * v

        # MLP block (use parent's mlp)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Return only the hidden_states tensor
        return hidden_states
class AsteriskLlamaModel(LlamaModel):
    """
    Llama backbone whose decoder layers are swapped for hybrid ASPP-Attention layers.

    When no indices are given, every layer (``range(config.num_hidden_layers)``)
    is replaced. Indices at or beyond the number of existing layers are skipped,
    but the index list is stored on ``self.hybrid_layer_indices`` as given.
    """

    def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
        super().__init__(config)

        # Default: full hybrid architecture (all layers)
        if hybrid_layer_indices is None:
            hybrid_layer_indices = list(range(config.num_hidden_layers))
        self.hybrid_layer_indices = hybrid_layer_indices

        # Shared ASPP settings handed to every hybrid layer
        aspp_settings = dict(
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
            aspp_num_neighbors=aspp_num_neighbors,
        )

        # Swap in hybrid layers (π-flow is decided per layer from the config)
        existing_layers = len(self.layers)
        for layer_position in hybrid_layer_indices:
            if layer_position >= existing_layers:
                continue
            self.layers[layer_position] = HybridASPPAttentionLayer(
                config,
                layer_idx=layer_position,
                **aspp_settings,
            )

        # Initialize weights
        self.post_init()
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk Causal LM with Hybrid ASPP-Attention architecture.

    Registered as: AsteriskForCausalLM
    """

    config_class = AsteriskConfig

    def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
        """
        Build the causal LM and replace its backbone with ``AsteriskLlamaModel``.

        Precedence of the ASPP settings:
            - ``hybrid_layer_indices`` / ``aspp_hidden_dim``: the explicit
              argument is used unless it is ``None``, in which case the config
              attribute (if present) is used.
            - ``aspp_num_steps`` / ``aspp_dropout`` / ``aspp_num_neighbors``:
              the config attribute wins whenever the config defines it; the
              argument is only a fallback for configs lacking the attribute.

        The resolved values are written back to ``self.config`` so that the
        stored configuration describes the modules that were actually built.
        """
        if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
            hybrid_layer_indices = config.hybrid_layer_indices
        if aspp_hidden_dim is None and hasattr(config, 'aspp_hidden_dim'):
            aspp_hidden_dim = config.aspp_hidden_dim
        if hasattr(config, 'aspp_num_steps'):
            aspp_num_steps = config.aspp_num_steps
        if hasattr(config, 'aspp_dropout'):
            aspp_dropout = config.aspp_dropout
        if hasattr(config, 'aspp_num_neighbors'):
            aspp_num_neighbors = config.aspp_num_neighbors

        super().__init__(config)

        # Replace model with Asterisk version
        self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # Store the resolved ASPP settings in the config for serialization.
        # Without this, an explicitly passed aspp_hidden_dim would be missing
        # from the saved config and a reloaded model would be built with
        # differently shaped ASPP modules.
        self.config.hybrid_layer_indices = hybrid_layer_indices
        self.config.aspp_hidden_dim = aspp_hidden_dim
        self.config.aspp_num_steps = aspp_num_steps
        self.config.aspp_dropout = aspp_dropout
        self.config.aspp_num_neighbors = aspp_num_neighbors

        # Initialize weights
        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,  # Fixed at 1 for Union-Find (only parent)
        # π-flow parameters
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        """
        Load a base Llama-style model and convert it to the Asterisk architecture.

        Args:
            base_model_path: Path to base SmolLM2 model
            hybrid_layer_indices: Which layers to make hybrid (None for all)
            aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
            aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
            aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent)
            pi_flow: Enable π-flow refinement step (default: False)
            pi_flow_steps: Number of flow refinement steps (default: 1)
            pi_flow_scale: Initial flow scale parameter (default: 0.2)
            pi_flow_use_gate: Use token-wise adaptive gating (default: True)
            **kwargs: Forwarded to ``LlamaForCausalLM.from_pretrained``.

        Returns:
            Tuple ``(asterisk_model, base_model)``: the converted model and the
            loaded base model it was initialized from.
        """
        # Load base model
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Merge through a dict rather than `AsteriskConfig(**d, key=...)`:
        # if the base config already carries any ASPP / π-flow key, the
        # keyword form raises "got multiple values for keyword argument".
        # The values requested by the caller take precedence.
        config_kwargs = base_config.to_dict()
        # Do not carry the base model_type into the new config; AsteriskConfig
        # declares its own model_type as a class attribute.
        config_kwargs.pop("model_type", None)
        config_kwargs.update(
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
            aspp_num_neighbors=aspp_num_neighbors,
            pi_flow=pi_flow,
            pi_flow_steps=pi_flow_steps,
            pi_flow_scale=pi_flow_scale,
            pi_flow_use_gate=pi_flow_use_gate,
        )
        asterisk_config = AsteriskConfig(**config_kwargs)

        # Create Asterisk model
        asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # Transfer weights from the base model. strict=False: parameters that
        # have no counterpart in the base checkpoint (presumably the ASPP,
        # fusion-gate and π-flow additions) keep their fresh initialization.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print("✓ Converted base model to Asterisk architecture with Graph Propagation")
        print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
        if pi_flow:
            print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")

        return asterisk_model, base_model
# Register the custom classes with the transformers Auto* factories:
# the "asterisk" model type is mapped to AsteriskConfig, and AsteriskConfig is
# mapped to AsteriskForCausalLM for AutoModelForCausalLM.
# NOTE: both calls execute at import time as a side effect of loading this module.
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
def get_model_info(model):
    """
    Print a short parameter summary for ``model``.

    Reports the total and trainable parameter counts and a size estimate that
    assumes 4 bytes per parameter. For an ``AsteriskForCausalLM`` instance the
    hybrid layer indices and their count are printed as well.
    """
    # Collect (element count, trainable?) once per parameter.
    param_stats = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total_params = sum(count for count, _ in param_stats)
    trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)

    print(f" • Total parameters: {total_params:,}")
    print(f" • Trainable parameters: {trainable_params:,}")
    print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    if not isinstance(model, AsteriskForCausalLM):
        return
    print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}")
    print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}")