"""
Hybrid ASPP-Attention Architecture (Asterisk Model)

Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention
mechanisms to enhance model expressiveness while maintaining efficiency.

Architecture Design:
- Hybrid layers: standard attention + ASPP operator in parallel
- Gate mechanism for dynamic fusion
- Knowledge distillation from SmolLM2-135M base model
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM, LlamaModel
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaMLP,
)
from typing import Optional, Tuple, List
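
# Overview of the hybrid layer dataflow (summarized from HybridASPPAttentionLayer.forward
# below; the symbols are informal shorthand, not identifiers used in the code):
#   x_norm = input_layernorm(x)
#   a      = ASPPOperator(x_norm)          # point-wise evolution branch
#   s      = self_attn(x_norm, ...)        # standard attention branch
#   g      = sigmoid(W @ [a ; s] + b)      # fusion gate from the concatenated branches
#   x      = x + g * a + (1 - g) * s       # gated residual fusion
# followed by the usual post-attention LayerNorm + MLP residual block.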


class AsteriskConfig(LlamaConfig):
    """
    Configuration class for the Asterisk model.
    Inherits from LlamaConfig with custom model_type.
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hybrid_layer_indices = hybrid_layer_indices
        self.aspp_hidden_dim = aspp_hidden_dim
        self.aspp_num_steps = aspp_num_steps
        self.aspp_dropout = aspp_dropout
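
# Example of building a config directly (illustrative values only; in practice
# AsteriskForCausalLM.from_pretrained_base below derives these from a base checkpoint):
#   cfg = AsteriskConfig(hybrid_layer_indices=[0, 4, 8], aspp_hidden_dim=256,
#                        aspp_num_steps=2, aspp_dropout=0.1)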


class ASPPOperator(nn.Module):
    """
    Asterisk operator (ASPP): point-wise parallel propagation.

    Simplified version WITHOUT neighbor gathering, to reduce overfitting:
    - Optional dimensionality reduction for efficiency
    - Point-wise evolution: h_i^(t+1) = LayerNorm(h_i^(t) + s * φ(h_i^(t)))  [no neighbors]
    - Multi-step evolution for depth without added parameters per step
    - Dropout for regularization

    Args:
        hidden_size: Dimension of the input/output hidden states.
        aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size).
        num_steps: Maximum number of evolution steps K (default: 2).
        dropout: Dropout rate for regularization (default: 0.1).
    """

    def __init__(
        self,
        hidden_size: int,
        aspp_hidden_dim: Optional[int] = None,
        num_steps: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps

        # Optional down/up projection when the internal ASPP dimension differs
        # from the model hidden size.
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # Point-wise update network φ, shared across all evolution steps.
        self.update_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Logit controlling the effective number of evolution steps (see forward()).
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Learnable scale on the per-step residual update, initialized small.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        # Normalization applied after each evolution step.
        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        # Optionally project down to the internal ASPP dimension.
        if self.use_projection:
            h_t = self.down_proj(hidden_states)
            h_t = self.proj_dropout(h_t)
        else:
            h_t = hidden_states

        # Effective number of evolution steps, clamped to [1, num_steps].
        # Note: the int() cast is not differentiable, so k_logit receives no
        # gradient through this path.
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # Point-wise evolution: scaled residual update followed by normalization.
        for _ in range(k_steps):
            h_t_next = self.update_net(h_t)
            h_t = h_t + self.residual_scale * h_t_next
            h_t = self.norm(h_t)

        # Optionally project back up to the model hidden size.
        if self.use_projection:
            h_t = self.up_proj(h_t)
            h_t = self.proj_dropout(h_t)

        return h_t
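
# Illustrative shape contract for ASPPOperator (the dimensions are arbitrary
# examples, not tied to any particular checkpoint):
#   op = ASPPOperator(hidden_size=576, aspp_hidden_dim=256, num_steps=2)
#   x = torch.randn(2, 8, 576)
#   op(x).shape  # -> torch.Size([2, 8, 576]); the operator preserves input shape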


class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid layer combining the ASPP operator and standard attention.
    Inherits from LlamaDecoderLayer to maintain compatibility.

    Architecture:
    1. Parallel branches:
       - ASPP operator for point-wise iterative refinement
       - Standard LlamaAttention for global context
    2. Gated fusion of both outputs
    3. Feed-forward network
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        super().__init__(config, layer_idx)

        # ASPP branch, run in parallel with self-attention.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
        )

        # Gate computed from the concatenated branch outputs; per-dimension values
        # in (0, 1) weight the ASPP branch against the attention branch.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid(),
        )

        # Zero bias so the gate starts roughly balanced (sigmoid(0) = 0.5).
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Override LlamaDecoderLayer.forward to add the ASPP branch.
        Returns a single tensor to match the LlamaDecoderLayer API in transformers 4.57.6.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # ASPP branch (point-wise, no attention).
        aspp_output = self.aspp_operator(hidden_states)

        # Standard attention branch.
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        attn_output = attn_outputs[0]

        # Gated fusion of the two branches.
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = self.fusion_gate(fusion_input)
        fused_output = gate * aspp_output + (1 - gate) * attn_output

        # Residual connection around the fused branches.
        hidden_states = residual + fused_output

        # Standard post-attention MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class AsteriskLlamaModel(LlamaModel):
    """
    Asterisk-Llama model with the full hybrid ASPP-Attention architecture.

    All layers use hybrid ASPP + attention by default for maximum expressiveness.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        super().__init__(config)

        # Default: make every decoder layer hybrid.
        if hybrid_layer_indices is None:
            num_layers = config.num_hidden_layers
            hybrid_layer_indices = list(range(num_layers))

        self.hybrid_layer_indices = hybrid_layer_indices

        # Replace the selected decoder layers with hybrid ASPP + attention layers.
        for idx in hybrid_layer_indices:
            if idx < len(self.layers):
                self.layers[idx] = HybridASPPAttentionLayer(
                    config,
                    layer_idx=idx,
                    aspp_hidden_dim=aspp_hidden_dim,
                    aspp_num_steps=aspp_num_steps,
                    aspp_dropout=aspp_dropout,
                )

        self.post_init()


class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk causal LM with the hybrid ASPP-Attention architecture.

    Registered as: AsteriskForCausalLM
    """

    config_class = AsteriskConfig

    def __init__(
        self,
        config: AsteriskConfig,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
    ):
        # Fall back to values stored in the config when arguments are not given
        # explicitly (e.g. when instantiated via AutoModelForCausalLM).
        if hybrid_layer_indices is None and hasattr(config, "hybrid_layer_indices"):
            hybrid_layer_indices = config.hybrid_layer_indices
        if aspp_hidden_dim is None and hasattr(config, "aspp_hidden_dim"):
            aspp_hidden_dim = config.aspp_hidden_dim
        if hasattr(config, "aspp_num_steps"):
            aspp_num_steps = config.aspp_num_steps
        if hasattr(config, "aspp_dropout"):
            aspp_dropout = config.aspp_dropout

        super().__init__(config)

        # Rebuild the backbone with hybrid layers in place of the standard ones.
        self.model = AsteriskLlamaModel(
            config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout
        )

        # Persist the indices on the config for serialization (None means all layers).
        self.config.hybrid_layer_indices = hybrid_layer_indices

        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        **kwargs,
    ):
        """
        Load a base model and convert it to the Asterisk architecture.

        Args:
            base_model_path: Path to the base SmolLM2 model.
            hybrid_layer_indices: Which layers to make hybrid (None for all).
            aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size).
            aspp_num_steps: Number of evolution steps K for ASPP (default: 2).
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1).

        Returns:
            Tuple of (asterisk_model, base_model).
        """
        # Load the base checkpoint; extra kwargs (dtype, device_map, ...) pass through.
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Build an Asterisk config on top of the base config.
        asterisk_config = AsteriskConfig(
            **base_config.to_dict(),
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
        )

        # Instantiate the hybrid model and copy over all matching base weights;
        # the ASPP-specific parameters keep their fresh initialization.
        asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print("✓ Converted base model to Asterisk architecture")
        print(f"  Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f"  ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")

        return asterisk_model, base_model


AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
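
# With the registrations above, a saved Asterisk checkpoint is intended to be
# reloadable through the generic Auto* API (illustrative; `path` is hypothetical):
#   asterisk_model.save_pretrained(path)
#   model = AutoModelForCausalLM.from_pretrained(path)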


def get_model_info(model):
    """Print model architecture information."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f" • Total parameters: {total_params:,}")
    print(f" • Trainable parameters: {trainable_params:,}")
    # Size estimate assumes 4 bytes per parameter (fp32 weights).
    print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    if isinstance(model, AsteriskForCausalLM):
        print(f" • Hybrid layer indices: {model.model.hybrid_layer_indices}")
        print(f" • Number of hybrid layers: {len(model.model.hybrid_layer_indices)}")


if __name__ == "__main__":
    print("=" * 80)
    print("Asterisk Architecture - ASPP + Standard Attention")
    print("=" * 80)

    base_model_path = "SmolLM2-135M-Instruct"

    print("\n🔧 Creating Asterisk model...")
    asterisk_model, base_model = AsteriskForCausalLM.from_pretrained_base(
        base_model_path,
        hybrid_layer_indices=None,  # None = make every layer hybrid
        aspp_num_steps=2,
        aspp_dropout=0.1,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print("\n📊 Base model info:")
    get_model_info(base_model)

    print("\n📊 Asterisk model info:")
    get_model_info(asterisk_model)
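
    # Optional forward-pass sanity check (a minimal sketch; assumes the model and a
    # small dummy batch fit on the resolved device and that config.vocab_size is set).
    dummy_ids = torch.randint(
        0, asterisk_model.config.vocab_size, (1, 16), device=asterisk_model.device
    )
    with torch.no_grad():
        logits = asterisk_model(input_ids=dummy_ids).logits
    print(f"\n🧪 Sanity check logits shape: {tuple(logits.shape)} (batch, seq_len, vocab_size)")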

    print("\n✨ Model ready for training!")