OzTianlu commited on
Commit
3356706
·
verified ·
1 Parent(s): 74de4bb

Upload 14 files

Browse files
AsteriskForCausalLM.py CHANGED
@@ -36,6 +36,7 @@ class AsteriskConfig(LlamaConfig):
36
  aspp_hidden_dim: Optional[int] = None,
37
  aspp_num_steps: int = 2,
38
  aspp_dropout: float = 0.1,
 
39
  # π-flow parameters
40
  pi_flow: bool = False,
41
  pi_flow_steps: int = 1,
@@ -48,6 +49,7 @@ class AsteriskConfig(LlamaConfig):
48
  self.aspp_hidden_dim = aspp_hidden_dim
49
  self.aspp_num_steps = aspp_num_steps
50
  self.aspp_dropout = aspp_dropout
 
51
  # π-flow config
52
  self.pi_flow = pi_flow
53
  self.pi_flow_steps = pi_flow_steps
@@ -57,26 +59,37 @@ class AsteriskConfig(LlamaConfig):
57
 
58
  class ASPPOperator(nn.Module):
59
  """
60
- Asterisk Operator (ASPP) - Point-wise Parallel Propagation
61
 
62
- Simplified version WITHOUT neighbor gathering to reduce overfitting:
63
- - Optional dimensionality reduction for efficiency
64
- - Point-wise evolution: h_i^(t+1) = φ(h_i^(t)) [NO neighbors]
65
- - Multi-step evolution for depth without added complexity
66
- - Dropout for regularization
 
 
 
 
 
 
 
 
 
67
 
68
  Args:
69
  hidden_size: Dimension of hidden states (input/output)
70
  aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
71
  num_steps: Number of evolution steps K (default: 2)
72
  dropout: Dropout rate for regularization (default: 0.1)
 
73
  """
74
 
75
- def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1):
76
  super().__init__()
77
  self.hidden_size = hidden_size
78
  self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
79
  self.num_steps = num_steps
 
80
 
81
  # Projection to lower dimension (if specified)
82
  self.use_projection = (self.aspp_hidden_dim != hidden_size)
@@ -85,10 +98,9 @@ class ASPPOperator(nn.Module):
85
  self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
86
  self.proj_dropout = nn.Dropout(dropout)
87
 
88
- # Point-wise update function φ - NO neighbor gathering
89
- # Much smaller: only processes current position
90
- self.update_net = nn.Sequential(
91
- nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
92
  nn.SiLU(),
93
  nn.Dropout(dropout),
94
  nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
@@ -96,7 +108,6 @@ class ASPPOperator(nn.Module):
96
  )
97
 
98
  # Learnable K-step parameter
99
- # sigmoid(1.0) ≈ 0.73, giving k_steps ≈ 1.5 → 2 steps initially
100
  self.k_logit = nn.Parameter(torch.tensor(1.0))
101
 
102
  # Learnable residual scale
@@ -105,6 +116,28 @@ class ASPPOperator(nn.Module):
105
  # Layer norm for stability
106
  self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
109
  """
110
  Args:
@@ -112,6 +145,8 @@ class ASPPOperator(nn.Module):
112
  Returns:
113
  evolved_states: [batch_size, seq_len, hidden_size]
114
  """
 
 
115
  # Project to lower dimension if needed
116
  if self.use_projection:
117
  h_t = self.down_proj(hidden_states)
@@ -122,12 +157,21 @@ class ASPPOperator(nn.Module):
122
  # Learnable number of steps
123
  k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
124
 
125
- # K-step point-wise evolution (NO neighbor gathering)
126
  for t in range(k_steps):
127
- # Apply point-wise update rule φ
128
- h_t_next = self.update_net(h_t)
 
 
 
 
 
 
 
 
 
129
 
130
- # Scaled residual connection for stability
131
  h_t = h_t + self.residual_scale * h_t_next
132
  h_t = self.norm(h_t)
133
 
@@ -153,7 +197,7 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
153
  4. Feed-forward network
154
  """
155
 
156
- def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
157
  # Initialize parent LlamaDecoderLayer
158
  super().__init__(config, layer_idx)
159
 
@@ -162,7 +206,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
162
  hidden_size=config.hidden_size,
163
  aspp_hidden_dim=aspp_hidden_dim,
164
  num_steps=aspp_num_steps,
165
- dropout=aspp_dropout
 
166
  )
167
 
168
  # Gated fusion mechanism with dropout
@@ -182,7 +227,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
182
  hidden_size=config.hidden_size,
183
  aspp_hidden_dim=aspp_hidden_dim,
184
  num_steps=aspp_num_steps,
185
- dropout=aspp_dropout
 
186
  )
187
 
188
  # Learnable flow scale (per-layer)
@@ -276,7 +322,7 @@ class AsteriskLlamaModel(LlamaModel):
276
  All layers use hybrid ASPP+Attention by default for maximum expressiveness.
277
  """
278
 
279
- def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
280
  super().__init__(config)
281
 
282
  # Determine which layers to make hybrid (default: ALL layers)
@@ -295,7 +341,8 @@ class AsteriskLlamaModel(LlamaModel):
295
  layer_idx=idx,
296
  aspp_hidden_dim=aspp_hidden_dim,
297
  aspp_num_steps=aspp_num_steps,
298
- aspp_dropout=aspp_dropout
 
299
  )
300
 
301
  # Initialize weights
@@ -311,7 +358,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
311
 
312
  config_class = AsteriskConfig
313
 
314
- def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
315
  # Read all ASPP parameters from config if not explicitly provided
316
  if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
317
  hybrid_layer_indices = config.hybrid_layer_indices
@@ -321,11 +368,13 @@ class AsteriskForCausalLM(LlamaForCausalLM):
321
  aspp_num_steps = config.aspp_num_steps
322
  if hasattr(config, 'aspp_dropout'):
323
  aspp_dropout = config.aspp_dropout
 
 
324
 
325
  super().__init__(config)
326
 
327
  # Replace model with Asterisk version
328
- self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
329
 
330
  # Store hybrid layer info in config for serialization
331
  self.config.hybrid_layer_indices = hybrid_layer_indices
@@ -341,6 +390,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
341
  aspp_hidden_dim: Optional[int] = None,
342
  aspp_num_steps: int = 2,
343
  aspp_dropout: float = 0.1,
 
344
  # π-flow parameters
345
  pi_flow: bool = False,
346
  pi_flow_steps: int = 1,
@@ -357,6 +407,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
357
  aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
358
  aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
359
  aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
 
360
  pi_flow: Enable π-flow refinement step (default: False)
361
  pi_flow_steps: Number of flow refinement steps (default: 1)
362
  pi_flow_scale: Initial flow scale parameter (default: 0.2)
@@ -373,6 +424,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
373
  aspp_hidden_dim=aspp_hidden_dim,
374
  aspp_num_steps=aspp_num_steps,
375
  aspp_dropout=aspp_dropout,
 
376
  pi_flow=pi_flow,
377
  pi_flow_steps=pi_flow_steps,
378
  pi_flow_scale=pi_flow_scale,
@@ -380,15 +432,15 @@ class AsteriskForCausalLM(LlamaForCausalLM):
380
  )
381
 
382
  # Create Asterisk model
383
- asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
384
 
385
  # Transfer weights from base model (non-hybrid layers and embeddings)
386
  asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
387
 
388
- print(f"✓ Converted base model to Asterisk architecture")
389
  print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
390
  aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
391
- print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")
392
  if pi_flow:
393
  print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
394
 
 
36
  aspp_hidden_dim: Optional[int] = None,
37
  aspp_num_steps: int = 2,
38
  aspp_dropout: float = 0.1,
39
+ aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent)
40
  # π-flow parameters
41
  pi_flow: bool = False,
42
  pi_flow_steps: int = 1,
 
49
  self.aspp_hidden_dim = aspp_hidden_dim
50
  self.aspp_num_steps = aspp_num_steps
51
  self.aspp_dropout = aspp_dropout
52
+ self.aspp_num_neighbors = aspp_num_neighbors
53
  # π-flow config
54
  self.pi_flow = pi_flow
55
  self.pi_flow_steps = pi_flow_steps
 
59
 
60
  class ASPPOperator(nn.Module):
61
  """
62
+ Asterisk Operator (ASPP) - Union-Find Graph Propagation
63
 
64
+ Uses Union-Find (Disjoint Set Union) structure for dynamic parent connections:
65
+ - Each position maintains a parent pointer: parent[i]
66
+ - Initial structure: parent[i] = max(0, i-1) (linear chain)
67
+ - Message passing: aggregate self + parent features
68
+ - Can apply path compression for optimization
69
+
70
+ Advantages:
71
+ - O(n) complexity with simple indexing
72
+ - Dynamic grouping of related positions
73
+ - Efficient parent-only propagation (no complex gather)
74
+ - Nearly constant time find with path compression
75
+
76
+ Complexity: O(n) with α(n) ≈ O(1) per operation
77
+ Message passing: h_i^(t+1) = φ(h_i^(t), h_parent[i])
78
 
79
  Args:
80
  hidden_size: Dimension of hidden states (input/output)
81
  aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
82
  num_steps: Number of evolution steps K (default: 2)
83
  dropout: Dropout rate for regularization (default: 0.1)
84
+ num_neighbors: Fixed at 1 (only parent) for Union-Find structure
85
  """
86
 
87
+ def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
88
  super().__init__()
89
  self.hidden_size = hidden_size
90
  self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
91
  self.num_steps = num_steps
92
+ self.num_neighbors = 1 # Fixed: only parent
93
 
94
  # Projection to lower dimension (if specified)
95
  self.use_projection = (self.aspp_hidden_dim != hidden_size)
 
98
  self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
99
  self.proj_dropout = nn.Dropout(dropout)
100
 
101
+ # Message aggregation function: combines self + parent
102
+ self.message_net = nn.Sequential(
103
+ nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
 
104
  nn.SiLU(),
105
  nn.Dropout(dropout),
106
  nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
 
108
  )
109
 
110
  # Learnable K-step parameter
 
111
  self.k_logit = nn.Parameter(torch.tensor(1.0))
112
 
113
  # Learnable residual scale
 
116
  # Layer norm for stability
117
  self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)
118
 
119
+ def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
120
+ """
121
+ Compute parent index for each position using Union-Find structure
122
+
123
+ Simple implementation: parent[i] = i-1 (linear chain)
124
+ - Position 0 points to itself (root)
125
+ - All others point to previous position
126
+
127
+ Can be extended with dynamic union operations based on:
128
+ - Semantic similarity
129
+ - Positional heuristics
130
+ - Learned grouping
131
+
132
+ Returns: [seq_len] tensor of parent indices
133
+ """
134
+ # Initialize: parent[i] = max(0, i-1)
135
+ parent_indices = torch.arange(seq_len, device=device) - 1
136
+ parent_indices[0] = 0 # Root points to itself
137
+ parent_indices = torch.clamp(parent_indices, 0, seq_len - 1)
138
+
139
+ return parent_indices
140
+
141
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
142
  """
143
  Args:
 
145
  Returns:
146
  evolved_states: [batch_size, seq_len, hidden_size]
147
  """
148
+ batch_size, seq_len, _ = hidden_states.shape
149
+
150
  # Project to lower dimension if needed
151
  if self.use_projection:
152
  h_t = self.down_proj(hidden_states)
 
157
  # Learnable number of steps
158
  k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
159
 
160
+ # K-step Union-Find graph propagation
161
  for t in range(k_steps):
162
+ # 1. Compute parent indices using Union-Find structure
163
+ parent_indices = self.compute_parent_indices(seq_len, h_t.device) # [L]
164
+
165
+ # 2. Gather parent features (super simple indexing!)
166
+ # h_t: [B, L, D], parent_indices: [L]
167
+ # Just gather from parent positions
168
+ parent_features = h_t[:, parent_indices, :] # [B, L, D]
169
+
170
+ # 3. Message passing: combine self + parent
171
+ message_input = torch.cat([h_t, parent_features], dim=-1) # [B, L, 2D]
172
+ h_t_next = self.message_net(message_input) # [B, L, D]
173
 
174
+ # 4. Scaled residual connection for stability
175
  h_t = h_t + self.residual_scale * h_t_next
176
  h_t = self.norm(h_t)
177
 
 
197
  4. Feed-forward network
198
  """
199
 
200
+ def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
201
  # Initialize parent LlamaDecoderLayer
202
  super().__init__(config, layer_idx)
203
 
 
206
  hidden_size=config.hidden_size,
207
  aspp_hidden_dim=aspp_hidden_dim,
208
  num_steps=aspp_num_steps,
209
+ dropout=aspp_dropout,
210
+ num_neighbors=aspp_num_neighbors
211
  )
212
 
213
  # Gated fusion mechanism with dropout
 
227
  hidden_size=config.hidden_size,
228
  aspp_hidden_dim=aspp_hidden_dim,
229
  num_steps=aspp_num_steps,
230
+ dropout=aspp_dropout,
231
+ num_neighbors=aspp_num_neighbors
232
  )
233
 
234
  # Learnable flow scale (per-layer)
 
322
  All layers use hybrid ASPP+Attention by default for maximum expressiveness.
323
  """
324
 
325
+ def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
326
  super().__init__(config)
327
 
328
  # Determine which layers to make hybrid (default: ALL layers)
 
341
  layer_idx=idx,
342
  aspp_hidden_dim=aspp_hidden_dim,
343
  aspp_num_steps=aspp_num_steps,
344
+ aspp_dropout=aspp_dropout,
345
+ aspp_num_neighbors=aspp_num_neighbors
346
  )
347
 
348
  # Initialize weights
 
358
 
359
  config_class = AsteriskConfig
360
 
361
+ def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
362
  # Read all ASPP parameters from config if not explicitly provided
363
  if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
364
  hybrid_layer_indices = config.hybrid_layer_indices
 
368
  aspp_num_steps = config.aspp_num_steps
369
  if hasattr(config, 'aspp_dropout'):
370
  aspp_dropout = config.aspp_dropout
371
+ if hasattr(config, 'aspp_num_neighbors'):
372
+ aspp_num_neighbors = config.aspp_num_neighbors
373
 
374
  super().__init__(config)
375
 
376
  # Replace model with Asterisk version
377
+ self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
378
 
379
  # Store hybrid layer info in config for serialization
380
  self.config.hybrid_layer_indices = hybrid_layer_indices
 
390
  aspp_hidden_dim: Optional[int] = None,
391
  aspp_num_steps: int = 2,
392
  aspp_dropout: float = 0.1,
393
+ aspp_num_neighbors: int = 1, # Fixed at 1 for Union-Find (only parent)
394
  # π-flow parameters
395
  pi_flow: bool = False,
396
  pi_flow_steps: int = 1,
 
407
  aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
408
  aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
409
  aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
410
+ aspp_num_neighbors: Number of neighbors for Union-Find (fixed at 1: only parent)
411
  pi_flow: Enable π-flow refinement step (default: False)
412
  pi_flow_steps: Number of flow refinement steps (default: 1)
413
  pi_flow_scale: Initial flow scale parameter (default: 0.2)
 
424
  aspp_hidden_dim=aspp_hidden_dim,
425
  aspp_num_steps=aspp_num_steps,
426
  aspp_dropout=aspp_dropout,
427
+ aspp_num_neighbors=aspp_num_neighbors,
428
  pi_flow=pi_flow,
429
  pi_flow_steps=pi_flow_steps,
430
  pi_flow_scale=pi_flow_scale,
 
432
  )
433
 
434
  # Create Asterisk model
435
+ asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
436
 
437
  # Transfer weights from base model (non-hybrid layers and embeddings)
438
  asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
439
 
440
+ print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
441
  print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
442
  aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
443
+ print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
444
  if pi_flow:
445
  print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
446
 
README.md CHANGED
@@ -100,18 +100,129 @@ class HybridASPPAttentionLayer:
100
  Combines ASPP operator with standard attention
101
 
102
  Components:
103
- - ASPP operator: Local structured reasoning
104
  - Standard attention: Global context
105
  - Gated fusion: Dynamic balancing
106
  """
107
  ```
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  **Fusion mechanism:**
110
  ```
111
- aspp_out = ASPP(hidden_states)
112
- attn_out = Attention(hidden_states, mask, ...)
113
  gate = sigmoid(linear([aspp_out || attn_out]))
114
  fused = gate * aspp_out + (1 - gate) * attn_out
 
 
 
 
115
  ```
116
 
117
  ### 2. π-Flow Refinement (Per-Layer)
 
100
  Combines ASPP operator with standard attention
101
 
102
  Components:
103
+ - ASPP operator: Local structured reasoning with Union-Find graph propagation
104
  - Standard attention: Global context
105
  - Gated fusion: Dynamic balancing
106
  """
107
  ```
108
 
109
+ #### ASPP Operator: Union-Find Graph Propagation
110
+
111
+ The ASPP operator uses a **Union-Find (Disjoint Set Union)** structure for efficient graph-based message passing. Unlike traditional attention's O(n²) complexity or skip-list's O(n log n), Union-Find achieves **O(n) complexity with nearly constant-time operations**.
112
+
113
+ **Graph Structure - Union-Find Parent Chain:**
114
+
115
+ ```
116
+ Position: [0] [1] [2] [3] [4] [5] ... [n-1]
117
+ Parent: [0] ← 0 ← 1 ← 2 ← 3 ← 4 ... ← n-2
118
+ (root)
119
+
120
+ - Position 0: points to itself (root of the tree)
121
+ - Position i (i>0): points to position i-1 (parent)
122
+ - Forms a linear chain structure for sequential token relationships
123
+ ```
124
+
125
+ This creates a **directed acyclic graph (DAG)** where information flows from children to parents, naturally capturing left-to-right sequential dependencies in language modeling.
126
+
127
+ **Graph Propagation Aggregation:**
128
+
129
+ Each ASPP evolution step performs parent-based message passing:
130
+
131
+ ```python
132
+ # Pseudocode for one ASPP propagation step
133
+ for position i in sequence:
134
+ # 1. Find parent using Union-Find structure
135
+ parent_idx = compute_parent_indices()[i] # O(1) with path compression
136
+
137
+ # 2. Gather parent features
138
+ parent_features = hidden_states[parent_idx]
139
+
140
+ # 3. Message aggregation: combine self + parent
141
+ message_input = concat([hidden_states[i], parent_features])
142
+
143
+ # 4. Update via learned transformation
144
+ new_state = message_net(message_input) # 2-layer MLP
145
+
146
+ # 5. Scaled residual connection
147
+ hidden_states[i] = hidden_states[i] + residual_scale * new_state
148
+ hidden_states[i] = layer_norm(hidden_states[i])
149
+ ```
150
+
151
+ **Key properties of Union-Find propagation:**
152
+
153
+ 1. **O(n) Complexity**: Each position performs exactly one parent lookup and one aggregation
154
+ - No expensive attention computation (O(n²))
155
+ - No multi-level skip connections (O(n log n))
156
+ - Simple indexing operation: `parent_features = h[parent_indices]`
157
+
158
+ 2. **Hierarchical Information Flow**: After K steps, position i can access information from positions [i-K, i]
159
+ - K=1: immediate parent only
160
+ - K=2: grandparent (2 positions back)
161
+ - K=4 (default): great-great-grandparent (4 positions back)
162
+ - Information propagates through the chain structure
163
+
164
+ 3. **Learnable Aggregation**: The `message_net` MLP learns how to combine self and parent features
165
+ - Input: `[self_features || parent_features]` (2D dimensions)
166
+ - Output: `D` dimensional update vector
167
+ - Dropout regularization for robustness
168
+
169
+ 4. **Path Compression Potential**: Can extend to dynamic parent reassignment
170
+ - Current implementation: static `parent[i] = i-1` chain
171
+ - Future extension: learn parent assignments based on semantic similarity
172
+ - Enables adaptive graph structure during forward pass
173
+
174
+ **Union-Find vs. Other Graph Structures:**
175
+
176
+ | Structure | Complexity | Receptive Field | Connections per Node |
177
+ |-----------|------------|-----------------|----------------------|
178
+ | **Full Attention** | O(n²) | Global | n-1 (all positions) |
179
+ | **Skip-List** | O(n log n) | Multi-scale | O(log n) (multiple levels) |
180
+ | **Union-Find** | O(n) | Local chain | 1 (parent only) |
181
+ | **Dilated Conv** | O(n·k) | Sparse | k (fixed window) |
182
+
183
+ Union-Find achieves the **lowest complexity** while maintaining effective information propagation through iterative K-step evolution.
184
+
185
+ **Theoretical Foundation - Union-Find in Graph Algorithms:**
186
+
187
+ Union-Find is a classic data structure for disjoint set operations:
188
+ - **Find**: Determine which set an element belongs to (with path compression: O(α(n)) ≈ O(1))
189
+ - **Union**: Merge two sets into one
190
+ - **Applications**: Kruskal's MST algorithm, connected components, cycle detection
191
+
192
+ In Asterisk-Pi:
193
+ - Each token position is a node in the graph
194
+ - Parent pointers define the tree structure
195
+ - Message passing simulates "Find" operations (traversing to ancestors)
196
+ - Can extend to dynamic "Union" operations (merging related tokens)
197
+
198
+ **Multi-Step Propagation:**
199
+
200
+ With K=4 evolution steps, information flow becomes:
201
+ ```
202
+ Step 1: Position i accesses parent i-1
203
+ Step 2: Position i now has information from i-2 (via i-1)
204
+ Step 3: Position i now has information from i-3 (propagated through chain)
205
+ Step 4: Position i now has information from i-4 (fully propagated)
206
+
207
+ Result: Each position has aggregated context from 4 previous positions
208
+ through efficient O(n) operations
209
+ ```
210
+
211
+ This multi-step propagation is crucial for:
212
+ - **Local context**: Recent tokens for coherence
213
+ - **Gradient flow**: Direct paths for backpropagation
214
+ - **Efficiency**: Linear cost instead of quadratic attention
215
+
216
  **Fusion mechanism:**
217
  ```
218
+ aspp_out = ASPP(hidden_states) # Union-Find graph propagation (O(n))
219
+ attn_out = Attention(hidden_states, mask, ...) # Global attention (O(n²))
220
  gate = sigmoid(linear([aspp_out || attn_out]))
221
  fused = gate * aspp_out + (1 - gate) * attn_out
222
+
223
+ # Combines:
224
+ # - Local structured reasoning (ASPP via Union-Find)
225
+ # - Global contextual awareness (Attention)
226
  ```
227
 
228
  ### 2. π-Flow Refinement (Per-Layer)
chat_template.jinja CHANGED
@@ -1,5 +1,5 @@
1
  {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
2
- You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
3
  ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
4
  ' + message['content'] + '<|im_end|>' + '
5
  '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
 
1
  {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
2
+ You are a helpful AI assistant named Asterisk, trained by NoesisLab<|im_end|>
3
  ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
4
  ' + message['content'] + '<|im_end|>' + '
5
  '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
config.json CHANGED
@@ -2,8 +2,9 @@
2
  "architectures": [
3
  "AsteriskForCausalLM"
4
  ],
5
- "aspp_dropout": 0.2,
6
  "aspp_hidden_dim": 256,
 
7
  "aspp_num_steps": 4,
8
  "attention_bias": false,
9
  "attention_dropout": 0.0,
 
2
  "architectures": [
3
  "AsteriskForCausalLM"
4
  ],
5
+ "aspp_dropout": 0.1,
6
  "aspp_hidden_dim": 256,
7
+ "aspp_num_neighbors": 1,
8
  "aspp_num_steps": 4,
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd3411332c19c27ac340b99a92d91e0b93f224b62fa3e0cccf7777b4e126b802
3
- size 381107624
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7c7f75c4ede6e9a2f8ef54b5c5c5b0d29c773eda3c8467426fb957edf075bb5
3
+ size 396836528
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:357a1e8bcbd247f80b9437f6d4dd9e81a29edbafaa6fea075a7380b6927773f4
3
  size 6353
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bf59bb69a5946540bcc9c4c08fc6cf0c903b2923e07239b606e38965dcb26a5
3
  size 6353