|
|
""" |
|
|
Hybrid ASPP-Attention Architecture (Asterisk Model) |
|
|
Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms |
|
|
to enhance model expressiveness while maintaining efficiency. |
|
|
|
|
|
Architecture Design: |
|
|
- Hybrid layers: Standard attention + ASPP operator in parallel |
|
|
- Gate mechanism for dynamic fusion |
|
|
- Knowledge distillation from SmolLM2-135M base model |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel |
|
|
from transformers.models.llama.modeling_llama import ( |
|
|
LlamaAttention, |
|
|
LlamaDecoderLayer, |
|
|
LlamaRMSNorm, |
|
|
LlamaMLP, |
|
|
) |
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
|
|
|
class AsteriskConfig(LlamaConfig):
    """
    Configuration class for Asterisk model.

    Extends LlamaConfig with the ASPP and π-flow hyperparameters and is
    registered under the custom model_type "asterisk".
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Store every Asterisk-specific hyperparameter on the config so it
        # round-trips through to_dict()/save_pretrained() like any other field.
        asterisk_params = {
            # ASPP operator settings
            "hybrid_layer_indices": hybrid_layer_indices,
            "aspp_hidden_dim": aspp_hidden_dim,
            "aspp_num_steps": aspp_num_steps,
            "aspp_dropout": aspp_dropout,
            "aspp_num_neighbors": aspp_num_neighbors,
            # π-flow refinement settings
            "pi_flow": pi_flow,
            "pi_flow_steps": pi_flow_steps,
            "pi_flow_scale": pi_flow_scale,
            "pi_flow_use_gate": pi_flow_use_gate,
        }
        for param_name, param_value in asterisk_params.items():
            setattr(self, param_name, param_value)
|
|
|
|
|
|
|
|
class ASPPOperator(nn.Module):
    """
    Asterisk Operator (ASPP) - Union-Find Graph Propagation

    Uses a Union-Find (Disjoint Set Union) style structure for dynamic
    parent connections:
    - Each position maintains a single parent pointer: parent[i]
    - Initial structure: parent[i] = max(0, i-1) (linear chain; position 0
      is its own root)
    - Message passing aggregates self + parent features:
          h_i^(t+1) = h_i^(t) + scale * phi([h_i^(t); h_parent[i]^(t)])
    - Can be extended with path compression / dynamic union operations

    Advantages:
    - O(n) complexity per step with simple indexing (one gather along the
      sequence axis, no complex neighbor scatter/gather)
    - Dynamic grouping of related positions is possible via the parent map

    Args:
        hidden_size: Dimension of hidden states (input/output).
        aspp_hidden_dim: Internal dimension for ASPP (default: None, use
            hidden_size). When it differs from hidden_size a down/up
            projection pair is added around the propagation steps.
        num_steps: Maximum number of evolution steps K (default: 2).
        dropout: Dropout rate for regularization (default: 0.1).
        num_neighbors: Accepted for API compatibility but fixed at 1
            (parent only) for the Union-Find structure.
    """

    def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps
        # Union-Find propagation only ever follows the single parent link,
        # so this is fixed at 1 regardless of the argument.
        self.num_neighbors = 1

        # Bottleneck projections are only needed when the internal dimension
        # differs from the model's hidden size.
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # phi: maps concatenated [self; parent] features to an update.
        self.message_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Logit controlling the effective number of evolution steps.
        # NOTE(review): forward() passes sigmoid(k_logit) through int(),
        # which is not differentiable, so this parameter currently receives
        # no gradient — confirm whether that is intended.
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Small residual scale keeps the operator close to identity at init.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
        """
        Compute the parent index for each position (linear-chain Union-Find).

        - Position 0 points to itself (root)
        - Every other position points to the previous position

        Can be extended with dynamic union operations based on semantic
        similarity, positional heuristics, or learned grouping.

        Returns:
            [seq_len] tensor of parent indices.
        """
        parent_indices = torch.arange(seq_len, device=device) - 1
        parent_indices[0] = 0
        # Defensive clamp; after the line above all values already lie in
        # [0, seq_len - 1].
        parent_indices = torch.clamp(parent_indices, 0, seq_len - 1)
        return parent_indices

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Optionally project down to the internal ASPP dimension.
        if self.use_projection:
            h_t = self.down_proj(hidden_states)
            h_t = self.proj_dropout(h_t)
        else:
            h_t = hidden_states

        # Effective number of evolution steps: at least 1, at most
        # num_steps - 1 in practice since sigmoid(k_logit) is in (0, 1).
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # The parent structure depends only on seq_len, so compute it once
        # instead of once per evolution step (was recomputed in the loop).
        parent_indices = self.compute_parent_indices(seq_len, h_t.device)

        for _ in range(k_steps):
            # Gather each position's parent features: [B, S, D].
            parent_features = h_t[:, parent_indices, :]

            # Message passing on [self; parent], then a scaled residual
            # update followed by normalization.
            message_input = torch.cat([h_t, parent_features], dim=-1)
            h_t_next = self.message_net(message_input)

            h_t = h_t + self.residual_scale * h_t_next
            h_t = self.norm(h_t)

        # Project back to the model's hidden size if needed.
        if self.use_projection:
            h_t = self.up_proj(h_t)
            h_t = self.proj_dropout(h_t)

        return h_t
|
|
|
|
|
|
|
|
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid layer combining the ASPP operator and standard attention.
    Inherits from LlamaDecoderLayer to maintain compatibility.

    Architecture:
        1. Parallel branches on the same normalized input:
           - ASPP operator for local structured reasoning
           - Standard LlamaAttention for global context
        2. Gated fusion of both branch outputs
        3. π-flow refinement (optional, per-layer)
        4. Feed-forward network
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        super().__init__(config, layer_idx)

        # ASPP branch, run in parallel with self-attention.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
            num_neighbors=aspp_num_neighbors
        )

        # Token-wise fusion gate over the concatenated branch outputs.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid()
        )

        # Zero bias so the gate starts near 0.5 for small activations,
        # i.e. a roughly balanced fusion at initialization.
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

        # Optional π-flow refinement: extra ASPP passes applied after fusion.
        if getattr(config, 'pi_flow', False):
            # Capture the step count at construction time. The previous
            # implementation re-read it in forward() via
            # `self.config if hasattr(self, 'config') else kwargs.get('config')`,
            # which silently fell back to 1 step whenever the base decoder
            # layer does not expose a `config` attribute — effectively
            # ignoring config.pi_flow_steps.
            self.num_pi_flow_steps = int(getattr(config, 'pi_flow_steps', 1))

            self.pi_flow_aspp = ASPPOperator(
                hidden_size=config.hidden_size,
                aspp_hidden_dim=aspp_hidden_dim,
                num_steps=aspp_num_steps,
                dropout=aspp_dropout,
                num_neighbors=aspp_num_neighbors
            )

            # Learnable global step-size for the flow update.
            self.pi_flow_scale = nn.Parameter(
                torch.tensor(getattr(config, 'pi_flow_scale', 0.2))
            )

            # Optional token-wise adaptive gate modulating the step size.
            if getattr(config, 'pi_flow_use_gate', True):
                self.pi_flow_gate = nn.Sequential(
                    nn.Linear(config.hidden_size, config.hidden_size // 4),
                    nn.SiLU(),
                    nn.Dropout(aspp_dropout),
                    nn.Linear(config.hidden_size // 4, 1),
                    nn.Sigmoid()
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Override LlamaDecoderLayer.forward to add the ASPP branch and π-flow.
        Returns a single tensor like LlamaDecoderLayer.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Branch 1: local structured propagation.
        aspp_output = self.aspp_operator(hidden_states)

        # Branch 2: standard self-attention.
        # NOTE(review): assumes self_attn returns (hidden, attn_weights);
        # verify against the installed transformers version.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Token-wise gate in (0, 1): 1 -> pure ASPP, 0 -> pure attention.
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = self.fusion_gate(fusion_input)

        fused_output = gate * aspp_output + (1 - gate) * attn_output

        hidden_states = residual + fused_output

        # Optional π-flow refinement after the fused attention block.
        if hasattr(self, 'pi_flow_aspp'):
            for _ in range(self.num_pi_flow_steps):
                # Velocity field from the extra ASPP pass.
                v = self.pi_flow_aspp(hidden_states)

                # Step size: learnable global scale, optionally modulated
                # by the token-wise gate.
                if hasattr(self, 'pi_flow_gate'):
                    gate = self.pi_flow_gate(hidden_states)
                    alpha = self.pi_flow_scale * gate
                else:
                    alpha = self.pi_flow_scale

                hidden_states = hidden_states + alpha * v

        # Standard feed-forward block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
|
|
|
|
|
|
class AsteriskLlamaModel(LlamaModel):
    """
    Asterisk-Llama model with full hybrid ASPP-Attention architecture.

    All layers use hybrid ASPP+Attention by default for maximum
    expressiveness.
    """

    def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
        super().__init__(config)

        # Default: hybridize every decoder layer.
        self.hybrid_layer_indices = (
            list(range(config.num_hidden_layers))
            if hybrid_layer_indices is None
            else hybrid_layer_indices
        )

        # Swap the selected stock decoder layers for hybrid ones; indices
        # beyond the layer count are silently skipped.
        layer_count = len(self.layers)
        for layer_idx in self.hybrid_layer_indices:
            if layer_idx >= layer_count:
                continue
            self.layers[layer_idx] = HybridASPPAttentionLayer(
                config,
                layer_idx=layer_idx,
                aspp_hidden_dim=aspp_hidden_dim,
                aspp_num_steps=aspp_num_steps,
                aspp_dropout=aspp_dropout,
                aspp_num_neighbors=aspp_num_neighbors,
            )

        # Re-run weight init / tying for the freshly created modules.
        self.post_init()
|
|
|
|
|
|
|
|
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk Causal LM with Hybrid ASPP-Attention architecture.

    Registered as: AsteriskForCausalLM

    Hyperparameter resolution: explicit constructor arguments take
    precedence; any argument left as None falls back to the corresponding
    config attribute, then to the library default.
    """

    config_class = AsteriskConfig

    def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: Optional[int] = None, aspp_dropout: Optional[float] = None, aspp_num_neighbors: Optional[int] = None):

        # Resolve every hyperparameter the same way: explicit argument
        # first, then the config, then a default. (Previously
        # aspp_num_steps / aspp_dropout / aspp_num_neighbors were always
        # taken from the config, silently discarding explicit arguments.)
        if hybrid_layer_indices is None:
            hybrid_layer_indices = getattr(config, 'hybrid_layer_indices', None)
        if aspp_hidden_dim is None:
            aspp_hidden_dim = getattr(config, 'aspp_hidden_dim', None)
        if aspp_num_steps is None:
            aspp_num_steps = getattr(config, 'aspp_num_steps', 2)
        if aspp_dropout is None:
            aspp_dropout = getattr(config, 'aspp_dropout', 0.1)
        if aspp_num_neighbors is None:
            aspp_num_neighbors = getattr(config, 'aspp_num_neighbors', 2)

        super().__init__(config)

        # Replace the stock LlamaModel with the hybrid variant.
        self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # Persist the *resolved* layer indices (never None) so a saved
        # config faithfully describes the instantiated architecture.
        self.config.hybrid_layer_indices = self.model.hybrid_layer_indices

        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        """
        Load a base model and convert it to the Asterisk architecture.

        Args:
            base_model_path: Path to base SmolLM2 model.
            hybrid_layer_indices: Which layers to make hybrid (None for all).
            aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size).
            aspp_num_steps: Number of evolution steps K for ASPP (default: 2).
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1).
            aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent).
            pi_flow: Enable π-flow refinement step (default: False).
            pi_flow_steps: Number of flow refinement steps (default: 1).
            pi_flow_scale: Initial flow scale parameter (default: 0.2).
            pi_flow_use_gate: Use token-wise adaptive gating (default: True).

        Returns:
            (asterisk_model, base_model) — the converted model and the
            original base model (useful as a distillation teacher).
        """
        # Load the pretrained donor model; kwargs are forwarded verbatim
        # (dtype, device_map, ...).
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        # Build an Asterisk config on top of the base model's settings.
        asterisk_config = AsteriskConfig(
            **base_config.to_dict(),
            hybrid_layer_indices=hybrid_layer_indices,
            aspp_hidden_dim=aspp_hidden_dim,
            aspp_num_steps=aspp_num_steps,
            aspp_dropout=aspp_dropout,
            aspp_num_neighbors=aspp_num_neighbors,
            pi_flow=pi_flow,
            pi_flow_steps=pi_flow_steps,
            pi_flow_scale=pi_flow_scale,
            pi_flow_use_gate=pi_flow_use_gate,
        )

        asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # strict=False: the ASPP / fusion / π-flow modules have no
        # pretrained counterpart and keep their fresh initialization.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
        print(f"  Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f"  ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
        if pi_flow:
            print(f"  π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")

        return asterisk_model, base_model
|
|
|
|
|
|
|
|
|
|
|
# Register the custom architecture with the Hugging Face Auto classes so
# that AutoConfig / AutoModelForCausalLM can resolve model_type "asterisk"
# to these implementations via from_pretrained.
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
|
|
|
|
|
|
|
|
def get_model_info(model):
    """Print model architecture information"""
    # Single pass over the parameters, tallying both counters at once.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"  • Total parameters: {total_params:,}")
    print(f"  • Trainable parameters: {trainable_params:,}")
    print(f"  • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    # Extra detail for Asterisk models.
    if isinstance(model, AsteriskForCausalLM):
        hybrid_indices = model.model.hybrid_layer_indices
        print(f"  • Hybrid layer indices: {hybrid_indices}")
        print(f"  • Number of hybrid layers: {len(hybrid_indices)}")
|
|
|