Upload 14 files
Browse files- AsteriskForCausalLM.py +78 -26
- README.md +114 -3
- chat_template.jinja +1 -1
- config.json +2 -1
- model.safetensors +2 -2
- training_args.bin +1 -1
AsteriskForCausalLM.py
CHANGED
|
@@ -36,6 +36,7 @@ class AsteriskConfig(LlamaConfig):
|
|
| 36 |
aspp_hidden_dim: Optional[int] = None,
|
| 37 |
aspp_num_steps: int = 2,
|
| 38 |
aspp_dropout: float = 0.1,
|
|
|
|
| 39 |
# π-flow parameters
|
| 40 |
pi_flow: bool = False,
|
| 41 |
pi_flow_steps: int = 1,
|
|
@@ -48,6 +49,7 @@ class AsteriskConfig(LlamaConfig):
|
|
| 48 |
self.aspp_hidden_dim = aspp_hidden_dim
|
| 49 |
self.aspp_num_steps = aspp_num_steps
|
| 50 |
self.aspp_dropout = aspp_dropout
|
|
|
|
| 51 |
# π-flow config
|
| 52 |
self.pi_flow = pi_flow
|
| 53 |
self.pi_flow_steps = pi_flow_steps
|
|
@@ -57,26 +59,37 @@ class AsteriskConfig(LlamaConfig):
|
|
| 57 |
|
| 58 |
class ASPPOperator(nn.Module):
|
| 59 |
"""
|
| 60 |
-
Asterisk Operator (ASPP) -
|
| 61 |
|
| 62 |
-
|
| 63 |
-
-
|
| 64 |
-
-
|
| 65 |
-
-
|
| 66 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
Args:
|
| 69 |
hidden_size: Dimension of hidden states (input/output)
|
| 70 |
aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
|
| 71 |
num_steps: Number of evolution steps K (default: 2)
|
| 72 |
dropout: Dropout rate for regularization (default: 0.1)
|
|
|
|
| 73 |
"""
|
| 74 |
|
| 75 |
-
def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1):
|
| 76 |
super().__init__()
|
| 77 |
self.hidden_size = hidden_size
|
| 78 |
self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
|
| 79 |
self.num_steps = num_steps
|
|
|
|
| 80 |
|
| 81 |
# Projection to lower dimension (if specified)
|
| 82 |
self.use_projection = (self.aspp_hidden_dim != hidden_size)
|
|
@@ -85,10 +98,9 @@ class ASPPOperator(nn.Module):
|
|
| 85 |
self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
|
| 86 |
self.proj_dropout = nn.Dropout(dropout)
|
| 87 |
|
| 88 |
-
#
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
|
| 92 |
nn.SiLU(),
|
| 93 |
nn.Dropout(dropout),
|
| 94 |
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
|
|
@@ -96,7 +108,6 @@ class ASPPOperator(nn.Module):
|
|
| 96 |
)
|
| 97 |
|
| 98 |
# Learnable K-step parameter
|
| 99 |
-
# sigmoid(1.0) ≈ 0.73, giving k_steps ≈ 1.5 → 2 steps initially
|
| 100 |
self.k_logit = nn.Parameter(torch.tensor(1.0))
|
| 101 |
|
| 102 |
# Learnable residual scale
|
|
@@ -105,6 +116,28 @@ class ASPPOperator(nn.Module):
|
|
| 105 |
# Layer norm for stability
|
| 106 |
self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 109 |
"""
|
| 110 |
Args:
|
|
@@ -112,6 +145,8 @@ class ASPPOperator(nn.Module):
|
|
| 112 |
Returns:
|
| 113 |
evolved_states: [batch_size, seq_len, hidden_size]
|
| 114 |
"""
|
|
|
|
|
|
|
| 115 |
# Project to lower dimension if needed
|
| 116 |
if self.use_projection:
|
| 117 |
h_t = self.down_proj(hidden_states)
|
|
@@ -122,12 +157,21 @@ class ASPPOperator(nn.Module):
|
|
| 122 |
# Learnable number of steps
|
| 123 |
k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
|
| 124 |
|
| 125 |
-
# K-step
|
| 126 |
for t in range(k_steps):
|
| 127 |
-
#
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
# Scaled residual connection for stability
|
| 131 |
h_t = h_t + self.residual_scale * h_t_next
|
| 132 |
h_t = self.norm(h_t)
|
| 133 |
|
|
@@ -153,7 +197,7 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 153 |
4. Feed-forward network
|
| 154 |
"""
|
| 155 |
|
| 156 |
-
def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 157 |
# Initialize parent LlamaDecoderLayer
|
| 158 |
super().__init__(config, layer_idx)
|
| 159 |
|
|
@@ -162,7 +206,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 162 |
hidden_size=config.hidden_size,
|
| 163 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 164 |
num_steps=aspp_num_steps,
|
| 165 |
-
dropout=aspp_dropout
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
# Gated fusion mechanism with dropout
|
|
@@ -182,7 +227,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 182 |
hidden_size=config.hidden_size,
|
| 183 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 184 |
num_steps=aspp_num_steps,
|
| 185 |
-
dropout=aspp_dropout
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
# Learnable flow scale (per-layer)
|
|
@@ -276,7 +322,7 @@ class AsteriskLlamaModel(LlamaModel):
|
|
| 276 |
All layers use hybrid ASPP+Attention by default for maximum expressiveness.
|
| 277 |
"""
|
| 278 |
|
| 279 |
-
def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 280 |
super().__init__(config)
|
| 281 |
|
| 282 |
# Determine which layers to make hybrid (default: ALL layers)
|
|
@@ -295,7 +341,8 @@ class AsteriskLlamaModel(LlamaModel):
|
|
| 295 |
layer_idx=idx,
|
| 296 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 297 |
aspp_num_steps=aspp_num_steps,
|
| 298 |
-
aspp_dropout=aspp_dropout
|
|
|
|
| 299 |
)
|
| 300 |
|
| 301 |
# Initialize weights
|
|
@@ -311,7 +358,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 311 |
|
| 312 |
config_class = AsteriskConfig
|
| 313 |
|
| 314 |
-
def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 315 |
# Read all ASPP parameters from config if not explicitly provided
|
| 316 |
if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
|
| 317 |
hybrid_layer_indices = config.hybrid_layer_indices
|
|
@@ -321,11 +368,13 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 321 |
aspp_num_steps = config.aspp_num_steps
|
| 322 |
if hasattr(config, 'aspp_dropout'):
|
| 323 |
aspp_dropout = config.aspp_dropout
|
|
|
|
|
|
|
| 324 |
|
| 325 |
super().__init__(config)
|
| 326 |
|
| 327 |
# Replace model with Asterisk version
|
| 328 |
-
self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
|
| 329 |
|
| 330 |
# Store hybrid layer info in config for serialization
|
| 331 |
self.config.hybrid_layer_indices = hybrid_layer_indices
|
|
@@ -341,6 +390,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 341 |
aspp_hidden_dim: Optional[int] = None,
|
| 342 |
aspp_num_steps: int = 2,
|
| 343 |
aspp_dropout: float = 0.1,
|
|
|
|
| 344 |
# π-flow parameters
|
| 345 |
pi_flow: bool = False,
|
| 346 |
pi_flow_steps: int = 1,
|
|
@@ -357,6 +407,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 357 |
aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
|
| 358 |
aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
|
| 359 |
aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
|
|
|
|
| 360 |
pi_flow: Enable π-flow refinement step (default: False)
|
| 361 |
pi_flow_steps: Number of flow refinement steps (default: 1)
|
| 362 |
pi_flow_scale: Initial flow scale parameter (default: 0.2)
|
|
@@ -373,6 +424,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 373 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 374 |
aspp_num_steps=aspp_num_steps,
|
| 375 |
aspp_dropout=aspp_dropout,
|
|
|
|
| 376 |
pi_flow=pi_flow,
|
| 377 |
pi_flow_steps=pi_flow_steps,
|
| 378 |
pi_flow_scale=pi_flow_scale,
|
|
@@ -380,15 +432,15 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 380 |
)
|
| 381 |
|
| 382 |
# Create Asterisk model
|
| 383 |
-
asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
|
| 384 |
|
| 385 |
# Transfer weights from base model (non-hybrid layers and embeddings)
|
| 386 |
asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
|
| 387 |
|
| 388 |
-
print(f"✓ Converted base model to Asterisk architecture")
|
| 389 |
print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
|
| 390 |
aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
|
| 391 |
-
print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")
|
| 392 |
if pi_flow:
|
| 393 |
print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
|
| 394 |
|
|
|
|
| 36 |
aspp_hidden_dim: Optional[int] = None,
|
| 37 |
aspp_num_steps: int = 2,
|
| 38 |
aspp_dropout: float = 0.1,
|
| 39 |
+
aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent)
|
| 40 |
# π-flow parameters
|
| 41 |
pi_flow: bool = False,
|
| 42 |
pi_flow_steps: int = 1,
|
|
|
|
| 49 |
self.aspp_hidden_dim = aspp_hidden_dim
|
| 50 |
self.aspp_num_steps = aspp_num_steps
|
| 51 |
self.aspp_dropout = aspp_dropout
|
| 52 |
+
self.aspp_num_neighbors = aspp_num_neighbors
|
| 53 |
# π-flow config
|
| 54 |
self.pi_flow = pi_flow
|
| 55 |
self.pi_flow_steps = pi_flow_steps
|
|
|
|
| 59 |
|
| 60 |
class ASPPOperator(nn.Module):
|
| 61 |
"""
|
| 62 |
+
Asterisk Operator (ASPP) - Union-Find Graph Propagation
|
| 63 |
|
| 64 |
+
Uses Union-Find (Disjoint Set Union) structure for dynamic parent connections:
|
| 65 |
+
- Each position maintains a parent pointer: parent[i]
|
| 66 |
+
- Initial structure: parent[i] = max(0, i-1) (linear chain)
|
| 67 |
+
- Message passing: aggregate self + parent features
|
| 68 |
+
- Can apply path compression for optimization
|
| 69 |
+
|
| 70 |
+
Advantages:
|
| 71 |
+
- O(n) complexity with simple indexing
|
| 72 |
+
- Dynamic grouping of related positions
|
| 73 |
+
- Efficient parent-only propagation (no complex gather)
|
| 74 |
+
- Nearly constant time find with path compression
|
| 75 |
+
|
| 76 |
+
Complexity: O(n) with α(n) ≈ O(1) per operation
|
| 77 |
+
Message passing: h_i^(t+1) = φ(h_i^(t), h_parent[i])
|
| 78 |
|
| 79 |
Args:
|
| 80 |
hidden_size: Dimension of hidden states (input/output)
|
| 81 |
aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
|
| 82 |
num_steps: Number of evolution steps K (default: 2)
|
| 83 |
dropout: Dropout rate for regularization (default: 0.1)
|
| 84 |
+
num_neighbors: Fixed at 1 (only parent) for Union-Find structure
|
| 85 |
"""
|
| 86 |
|
| 87 |
+
def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
|
| 88 |
super().__init__()
|
| 89 |
self.hidden_size = hidden_size
|
| 90 |
self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
|
| 91 |
self.num_steps = num_steps
|
| 92 |
+
self.num_neighbors = 1 # Fixed: only parent
|
| 93 |
|
| 94 |
# Projection to lower dimension (if specified)
|
| 95 |
self.use_projection = (self.aspp_hidden_dim != hidden_size)
|
|
|
|
| 98 |
self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
|
| 99 |
self.proj_dropout = nn.Dropout(dropout)
|
| 100 |
|
| 101 |
+
# Message aggregation function: combines self + parent
|
| 102 |
+
self.message_net = nn.Sequential(
|
| 103 |
+
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
|
|
|
|
| 104 |
nn.SiLU(),
|
| 105 |
nn.Dropout(dropout),
|
| 106 |
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
# Learnable K-step parameter
|
|
|
|
| 111 |
self.k_logit = nn.Parameter(torch.tensor(1.0))
|
| 112 |
|
| 113 |
# Learnable residual scale
|
|
|
|
| 116 |
# Layer norm for stability
|
| 117 |
self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)
|
| 118 |
|
| 119 |
+
def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
|
| 120 |
+
"""
|
| 121 |
+
Compute parent index for each position using Union-Find structure
|
| 122 |
+
|
| 123 |
+
Simple implementation: parent[i] = i-1 (linear chain)
|
| 124 |
+
- Position 0 points to itself (root)
|
| 125 |
+
- All others point to previous position
|
| 126 |
+
|
| 127 |
+
Can be extended with dynamic union operations based on:
|
| 128 |
+
- Semantic similarity
|
| 129 |
+
- Positional heuristics
|
| 130 |
+
- Learned grouping
|
| 131 |
+
|
| 132 |
+
Returns: [seq_len] tensor of parent indices
|
| 133 |
+
"""
|
| 134 |
+
# Initialize: parent[i] = max(0, i-1)
|
| 135 |
+
parent_indices = torch.arange(seq_len, device=device) - 1
|
| 136 |
+
parent_indices[0] = 0 # Root points to itself
|
| 137 |
+
parent_indices = torch.clamp(parent_indices, 0, seq_len - 1)
|
| 138 |
+
|
| 139 |
+
return parent_indices
|
| 140 |
+
|
| 141 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 142 |
"""
|
| 143 |
Args:
|
|
|
|
| 145 |
Returns:
|
| 146 |
evolved_states: [batch_size, seq_len, hidden_size]
|
| 147 |
"""
|
| 148 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 149 |
+
|
| 150 |
# Project to lower dimension if needed
|
| 151 |
if self.use_projection:
|
| 152 |
h_t = self.down_proj(hidden_states)
|
|
|
|
| 157 |
# Learnable number of steps
|
| 158 |
k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
|
| 159 |
|
| 160 |
+
# K-step Union-Find graph propagation
|
| 161 |
for t in range(k_steps):
|
| 162 |
+
# 1. Compute parent indices using Union-Find structure
|
| 163 |
+
parent_indices = self.compute_parent_indices(seq_len, h_t.device) # [L]
|
| 164 |
+
|
| 165 |
+
# 2. Gather parent features (super simple indexing!)
|
| 166 |
+
# h_t: [B, L, D], parent_indices: [L]
|
| 167 |
+
# Just gather from parent positions
|
| 168 |
+
parent_features = h_t[:, parent_indices, :] # [B, L, D]
|
| 169 |
+
|
| 170 |
+
# 3. Message passing: combine self + parent
|
| 171 |
+
message_input = torch.cat([h_t, parent_features], dim=-1) # [B, L, 2D]
|
| 172 |
+
h_t_next = self.message_net(message_input) # [B, L, D]
|
| 173 |
|
| 174 |
+
# 4. Scaled residual connection for stability
|
| 175 |
h_t = h_t + self.residual_scale * h_t_next
|
| 176 |
h_t = self.norm(h_t)
|
| 177 |
|
|
|
|
| 197 |
4. Feed-forward network
|
| 198 |
"""
|
| 199 |
|
| 200 |
+
def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
|
| 201 |
# Initialize parent LlamaDecoderLayer
|
| 202 |
super().__init__(config, layer_idx)
|
| 203 |
|
|
|
|
| 206 |
hidden_size=config.hidden_size,
|
| 207 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 208 |
num_steps=aspp_num_steps,
|
| 209 |
+
dropout=aspp_dropout,
|
| 210 |
+
num_neighbors=aspp_num_neighbors
|
| 211 |
)
|
| 212 |
|
| 213 |
# Gated fusion mechanism with dropout
|
|
|
|
| 227 |
hidden_size=config.hidden_size,
|
| 228 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 229 |
num_steps=aspp_num_steps,
|
| 230 |
+
dropout=aspp_dropout,
|
| 231 |
+
num_neighbors=aspp_num_neighbors
|
| 232 |
)
|
| 233 |
|
| 234 |
# Learnable flow scale (per-layer)
|
|
|
|
| 322 |
All layers use hybrid ASPP+Attention by default for maximum expressiveness.
|
| 323 |
"""
|
| 324 |
|
| 325 |
+
def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
|
| 326 |
super().__init__(config)
|
| 327 |
|
| 328 |
# Determine which layers to make hybrid (default: ALL layers)
|
|
|
|
| 341 |
layer_idx=idx,
|
| 342 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 343 |
aspp_num_steps=aspp_num_steps,
|
| 344 |
+
aspp_dropout=aspp_dropout,
|
| 345 |
+
aspp_num_neighbors=aspp_num_neighbors
|
| 346 |
)
|
| 347 |
|
| 348 |
# Initialize weights
|
|
|
|
| 358 |
|
| 359 |
config_class = AsteriskConfig
|
| 360 |
|
| 361 |
+
def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
|
| 362 |
# Read all ASPP parameters from config if not explicitly provided
|
| 363 |
if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
|
| 364 |
hybrid_layer_indices = config.hybrid_layer_indices
|
|
|
|
| 368 |
aspp_num_steps = config.aspp_num_steps
|
| 369 |
if hasattr(config, 'aspp_dropout'):
|
| 370 |
aspp_dropout = config.aspp_dropout
|
| 371 |
+
if hasattr(config, 'aspp_num_neighbors'):
|
| 372 |
+
aspp_num_neighbors = config.aspp_num_neighbors
|
| 373 |
|
| 374 |
super().__init__(config)
|
| 375 |
|
| 376 |
# Replace model with Asterisk version
|
| 377 |
+
self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
|
| 378 |
|
| 379 |
# Store hybrid layer info in config for serialization
|
| 380 |
self.config.hybrid_layer_indices = hybrid_layer_indices
|
|
|
|
| 390 |
aspp_hidden_dim: Optional[int] = None,
|
| 391 |
aspp_num_steps: int = 2,
|
| 392 |
aspp_dropout: float = 0.1,
|
| 393 |
+
aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent)
|
| 394 |
# π-flow parameters
|
| 395 |
pi_flow: bool = False,
|
| 396 |
pi_flow_steps: int = 1,
|
|
|
|
| 407 |
aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
|
| 408 |
aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
|
| 409 |
aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
|
| 410 |
+
aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent)
|
| 411 |
pi_flow: Enable π-flow refinement step (default: False)
|
| 412 |
pi_flow_steps: Number of flow refinement steps (default: 1)
|
| 413 |
pi_flow_scale: Initial flow scale parameter (default: 0.2)
|
|
|
|
| 424 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 425 |
aspp_num_steps=aspp_num_steps,
|
| 426 |
aspp_dropout=aspp_dropout,
|
| 427 |
+
aspp_num_neighbors=aspp_num_neighbors,
|
| 428 |
pi_flow=pi_flow,
|
| 429 |
pi_flow_steps=pi_flow_steps,
|
| 430 |
pi_flow_scale=pi_flow_scale,
|
|
|
|
| 432 |
)
|
| 433 |
|
| 434 |
# Create Asterisk model
|
| 435 |
+
asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
|
| 436 |
|
| 437 |
# Transfer weights from base model (non-hybrid layers and embeddings)
|
| 438 |
asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
|
| 439 |
|
| 440 |
+
print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
|
| 441 |
print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
|
| 442 |
aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
|
| 443 |
+
print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
|
| 444 |
if pi_flow:
|
| 445 |
print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
|
| 446 |
|
README.md
CHANGED
|
@@ -100,18 +100,129 @@ class HybridASPPAttentionLayer:
|
|
| 100 |
Combines ASPP operator with standard attention
|
| 101 |
|
| 102 |
Components:
|
| 103 |
-
- ASPP operator: Local structured reasoning
|
| 104 |
- Standard attention: Global context
|
| 105 |
- Gated fusion: Dynamic balancing
|
| 106 |
"""
|
| 107 |
```
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
**Fusion mechanism:**
|
| 110 |
```
|
| 111 |
-
aspp_out = ASPP(hidden_states)
|
| 112 |
-
attn_out = Attention(hidden_states, mask, ...)
|
| 113 |
gate = sigmoid(linear([aspp_out || attn_out]))
|
| 114 |
fused = gate * aspp_out + (1 - gate) * attn_out
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
```
|
| 116 |
|
| 117 |
### 2. π-Flow Refinement (Per-Layer)
|
|
|
|
| 100 |
Combines ASPP operator with standard attention
|
| 101 |
|
| 102 |
Components:
|
| 103 |
+
- ASPP operator: Local structured reasoning with Union-Find graph propagation
|
| 104 |
- Standard attention: Global context
|
| 105 |
- Gated fusion: Dynamic balancing
|
| 106 |
"""
|
| 107 |
```
|
| 108 |
|
| 109 |
+
#### ASPP Operator: Union-Find Graph Propagation
|
| 110 |
+
|
| 111 |
+
The ASPP operator uses a **Union-Find (Disjoint Set Union)**-style parent-pointer structure for efficient graph-based message passing. Unlike traditional attention's O(n²) complexity or skip-list's O(n log n), parent-only propagation achieves **O(n) complexity per propagation step, with a constant-time parent lookup at each position**.
|
| 112 |
+
|
| 113 |
+
**Graph Structure - Union-Find Parent Chain:**
|
| 114 |
+
|
| 115 |
+
```
|
| 116 |
+
Position: [0] [1] [2] [3] [4] [5] ... [n-1]
|
| 117 |
+
Parent: [0] ← 0 ← 1 ← 2 ← 3 ← 4 ... ← n-2
|
| 118 |
+
(root)
|
| 119 |
+
|
| 120 |
+
- Position 0: points to itself (root of the tree)
|
| 121 |
+
- Position i (i>0): points to position i-1 (parent)
|
| 122 |
+
- Forms a linear chain structure for sequential token relationships
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
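This creates a **chain-shaped directed graph** (acyclic apart from the root's self-pointer at position 0) in which each position reads its parent's features, so information flows from parents to children — that is, from earlier tokens to later ones — naturally capturing left-to-right sequential dependencies in language modeling.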
|
| 126 |
+
|
| 127 |
+
**Graph Propagation Aggregation:**
|
| 128 |
+
|
| 129 |
+
Each ASPP evolution step performs parent-based message passing:
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
# Pseudocode for one ASPP propagation step
|
| 133 |
+
for position i in sequence:
|
| 134 |
+
# 1. Find parent using Union-Find structure
|
| 135 |
+
parent_idx = compute_parent_indices()[i]  # O(1) lookup; the current chain is static (parent[i] = max(0, i-1)), no path compression yet
|
| 136 |
+
|
| 137 |
+
# 2. Gather parent features
|
| 138 |
+
parent_features = hidden_states[parent_idx]
|
| 139 |
+
|
| 140 |
+
# 3. Message aggregation: combine self + parent
|
| 141 |
+
message_input = concat([hidden_states[i], parent_features])
|
| 142 |
+
|
| 143 |
+
# 4. Update via learned transformation
|
| 144 |
+
new_state = message_net(message_input) # 2-layer MLP
|
| 145 |
+
|
| 146 |
+
# 5. Scaled residual connection
|
| 147 |
+
hidden_states[i] = hidden_states[i] + residual_scale * new_state
|
| 148 |
+
hidden_states[i] = layer_norm(hidden_states[i])
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
**Key properties of Union-Find propagation:**
|
| 152 |
+
|
| 153 |
+
1. **O(n) Complexity**: Each position performs exactly one parent lookup and one aggregation
|
| 154 |
+
- No expensive attention computation (O(n²))
|
| 155 |
+
- No multi-level skip connections (O(n log n))
|
| 156 |
+
- Simple indexing operation: `parent_features = h[parent_indices]`
|
| 157 |
+
|
| 158 |
+
2. **Hierarchical Information Flow**: After K steps, position i can access information from positions [i-K, i]
|
| 159 |
+
- K=1: immediate parent only
|
| 160 |
+
- K=2: grandparent (2 positions back)
|
| 161 |
+
- K=4 (this checkpoint's `aspp_num_steps`; the code default is 2): great-great-grandparent (4 positions back)
|
| 162 |
+
- Information propagates through the chain structure
|
| 163 |
+
|
| 164 |
+
3. **Learnable Aggregation**: The `message_net` MLP learns how to combine self and parent features
|
| 165 |
+
- Input: `[self_features || parent_features]` (2D dimensions)
|
| 166 |
+
- Output: `D` dimensional update vector
|
| 167 |
+
- Dropout regularization for robustness
|
| 168 |
+
|
| 169 |
+
4. **Path Compression Potential**: Can extend to dynamic parent reassignment
|
| 170 |
+
- Current implementation: static `parent[i] = i-1` chain
|
| 171 |
+
- Future extension: learn parent assignments based on semantic similarity
|
| 172 |
+
- Enables adaptive graph structure during forward pass
|
| 173 |
+
|
| 174 |
+
**Union-Find vs. Other Graph Structures:**
|
| 175 |
+
|
| 176 |
+
| Structure | Complexity | Receptive Field | Connections per Node |
|
| 177 |
+
|-----------|------------|-----------------|----------------------|
|
| 178 |
+
| **Full Attention** | O(n²) | Global | n-1 (all positions) |
|
| 179 |
+
| **Skip-List** | O(n log n) | Multi-scale | O(log n) (multiple levels) |
|
| 180 |
+
| **Union-Find** | O(n) | Local chain | 1 (parent only) |
|
| 181 |
+
| **Dilated Conv** | O(n·k) | Sparse | k (fixed window) |
|
| 182 |
+
|
| 183 |
+
Union-Find achieves the **lowest complexity** while maintaining effective information propagation through iterative K-step evolution.
|
| 184 |
+
|
| 185 |
+
**Theoretical Foundation - Union-Find in Graph Algorithms:**
|
| 186 |
+
|
| 187 |
+
Union-Find is a classic data structure for disjoint set operations:
|
| 188 |
+
- **Find**: Determine which set an element belongs to (with path compression: O(α(n)) ≈ O(1))
|
| 189 |
+
- **Union**: Merge two sets into one
|
| 190 |
+
- **Applications**: Kruskal's MST algorithm, connected components, cycle detection
|
| 191 |
+
|
| 192 |
+
In Asterisk-Pi:
|
| 193 |
+
- Each token position is a node in the graph
|
| 194 |
+
- Parent pointers define the tree structure
|
| 195 |
+
- Message passing simulates "Find" operations (traversing to ancestors)
|
| 196 |
+
- Can extend to dynamic "Union" operations (merging related tokens)
|
| 197 |
+
|
| 198 |
+
**Multi-Step Propagation:**
|
| 199 |
+
|
| 200 |
+
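With K=4 executed evolution steps, information flow becomes as follows (the operator actually runs `max(1, int(sigmoid(k_logit) * aspp_num_steps))` steps per forward pass, so the executed count can be smaller than the configured `aspp_num_steps`):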
|
| 201 |
+
```
|
| 202 |
+
Step 1: Position i accesses parent i-1
|
| 203 |
+
Step 2: Position i now has information from i-2 (via i-1)
|
| 204 |
+
Step 3: Position i now has information from i-3 (propagated through chain)
|
| 205 |
+
Step 4: Position i now has information from i-4 (fully propagated)
|
| 206 |
+
|
| 207 |
+
Result: Each position has aggregated context from 4 previous positions
|
| 208 |
+
through efficient O(n) operations
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
This multi-step propagation is crucial for:
|
| 212 |
+
- **Local context**: Recent tokens for coherence
|
| 213 |
+
- **Gradient flow**: Direct paths for backpropagation
|
| 214 |
+
- **Efficiency**: Linear cost instead of quadratic attention
|
| 215 |
+
|
| 216 |
**Fusion mechanism:**
|
| 217 |
```
|
| 218 |
+
aspp_out = ASPP(hidden_states) # Union-Find graph propagation (O(n))
|
| 219 |
+
attn_out = Attention(hidden_states, mask, ...) # Global attention (O(n²))
|
| 220 |
gate = sigmoid(linear([aspp_out || attn_out]))
|
| 221 |
fused = gate * aspp_out + (1 - gate) * attn_out
|
| 222 |
+
|
| 223 |
+
# Combines:
|
| 224 |
+
# - Local structured reasoning (ASPP via Union-Find)
|
| 225 |
+
# - Global contextual awareness (Attention)
|
| 226 |
```
|
| 227 |
|
| 228 |
### 2. π-Flow Refinement (Per-Layer)
|
chat_template.jinja
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
|
| 2 |
-
You are a helpful AI assistant named
|
| 3 |
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
|
| 4 |
' + message['content'] + '<|im_end|>' + '
|
| 5 |
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
|
|
|
|
| 1 |
{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
|
| 2 |
+
You are a helpful AI assistant named Asterisk, trained by NoesisLab<|im_end|>
|
| 3 |
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
|
| 4 |
' + message['content'] + '<|im_end|>' + '
|
| 5 |
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
|
config.json
CHANGED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
"architectures": [
|
| 3 |
"AsteriskForCausalLM"
|
| 4 |
],
|
| 5 |
-
"aspp_dropout": 0.
|
| 6 |
"aspp_hidden_dim": 256,
|
|
|
|
| 7 |
"aspp_num_steps": 4,
|
| 8 |
"attention_bias": false,
|
| 9 |
"attention_dropout": 0.0,
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"AsteriskForCausalLM"
|
| 4 |
],
|
| 5 |
+
"aspp_dropout": 0.1,
|
| 6 |
"aspp_hidden_dim": 256,
|
| 7 |
+
"aspp_num_neighbors": 1,
|
| 8 |
"aspp_num_steps": 4,
|
| 9 |
"attention_bias": false,
|
| 10 |
"attention_dropout": 0.0,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7c7f75c4ede6e9a2f8ef54b5c5c5b0d29c773eda3c8467426fb957edf075bb5
|
| 3 |
+
size 396836528
|
training_args.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 6353
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4bf59bb69a5946540bcc9c4c08fc6cf0c903b2923e07239b606e38965dcb26a5
|
| 3 |
size 6353
|