OzTianlu commited on
Commit
6ce05cc
·
verified ·
1 Parent(s): 8376f90

Upload 14 files

Browse files
AsteriskForCausalLM.py CHANGED
@@ -36,6 +36,7 @@ class AsteriskConfig(LlamaConfig):
36
  aspp_hidden_dim: Optional[int] = None,
37
  aspp_num_steps: int = 2,
38
  aspp_dropout: float = 0.1,
 
39
  # π-flow parameters
40
  pi_flow: bool = False,
41
  pi_flow_steps: int = 1,
@@ -48,6 +49,7 @@ class AsteriskConfig(LlamaConfig):
48
  self.aspp_hidden_dim = aspp_hidden_dim
49
  self.aspp_num_steps = aspp_num_steps
50
  self.aspp_dropout = aspp_dropout
 
51
  # π-flow config
52
  self.pi_flow = pi_flow
53
  self.pi_flow_steps = pi_flow_steps
@@ -57,26 +59,31 @@ class AsteriskConfig(LlamaConfig):
57
 
58
  class ASPPOperator(nn.Module):
59
  """
60
- Asterisk Operator (ASPP) - Point-wise Parallel Propagation
61
 
62
- Simplified version WITHOUT neighbor gathering to reduce overfitting:
63
- - Optional dimensionality reduction for efficiency
64
- - Point-wise evolution: h_i^(t+1) = φ(h_i^(t)) [NO neighbors]
65
- - Multi-step evolution for depth without added complexity
66
- - Dropout for regularization
 
 
 
67
 
68
  Args:
69
  hidden_size: Dimension of hidden states (input/output)
70
  aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
71
  num_steps: Number of evolution steps K (default: 2)
72
  dropout: Dropout rate for regularization (default: 0.1)
 
73
  """
74
 
75
- def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1):
76
  super().__init__()
77
  self.hidden_size = hidden_size
78
  self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
79
  self.num_steps = num_steps
 
80
 
81
  # Projection to lower dimension (if specified)
82
  self.use_projection = (self.aspp_hidden_dim != hidden_size)
@@ -85,10 +92,22 @@ class ASPPOperator(nn.Module):
85
  self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
86
  self.proj_dropout = nn.Dropout(dropout)
87
 
88
- # Point-wise update function φ - NO neighbor gathering
89
- # Much smaller: only processes current position
90
- self.update_net = nn.Sequential(
91
- nn.Linear(self.aspp_hidden_dim, self.aspp_hidden_dim * 2),
 
 
 
 
 
 
 
 
 
 
 
 
92
  nn.SiLU(),
93
  nn.Dropout(dropout),
94
  nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
@@ -96,7 +115,6 @@ class ASPPOperator(nn.Module):
96
  )
97
 
98
  # Learnable K-step parameter
99
- # sigmoid(1.0) ≈ 0.73, giving k_steps ≈ 1.5 → 2 steps initially
100
  self.k_logit = nn.Parameter(torch.tensor(1.0))
101
 
102
  # Learnable residual scale
@@ -112,6 +130,8 @@ class ASPPOperator(nn.Module):
112
  Returns:
113
  evolved_states: [batch_size, seq_len, hidden_size]
114
  """
 
 
115
  # Project to lower dimension if needed
116
  if self.use_projection:
117
  h_t = self.down_proj(hidden_states)
@@ -122,12 +142,70 @@ class ASPPOperator(nn.Module):
122
  # Learnable number of steps
123
  k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
124
 
125
- # K-step point-wise evolution (NO neighbor gathering)
126
  for t in range(k_steps):
127
- # Apply point-wise update rule φ
128
- h_t_next = self.update_net(h_t)
129
-
130
- # Scaled residual connection for stability
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  h_t = h_t + self.residual_scale * h_t_next
132
  h_t = self.norm(h_t)
133
 
@@ -153,7 +231,7 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
153
  4. Feed-forward network
154
  """
155
 
156
- def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
157
  # Initialize parent LlamaDecoderLayer
158
  super().__init__(config, layer_idx)
159
 
@@ -162,7 +240,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
162
  hidden_size=config.hidden_size,
163
  aspp_hidden_dim=aspp_hidden_dim,
164
  num_steps=aspp_num_steps,
165
- dropout=aspp_dropout
 
166
  )
167
 
168
  # Gated fusion mechanism with dropout
@@ -182,7 +261,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
182
  hidden_size=config.hidden_size,
183
  aspp_hidden_dim=aspp_hidden_dim,
184
  num_steps=aspp_num_steps,
185
- dropout=aspp_dropout
 
186
  )
187
 
188
  # Learnable flow scale (per-layer)
@@ -276,7 +356,7 @@ class AsteriskLlamaModel(LlamaModel):
276
  All layers use hybrid ASPP+Attention by default for maximum expressiveness.
277
  """
278
 
279
- def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
280
  super().__init__(config)
281
 
282
  # Determine which layers to make hybrid (default: ALL layers)
@@ -295,7 +375,8 @@ class AsteriskLlamaModel(LlamaModel):
295
  layer_idx=idx,
296
  aspp_hidden_dim=aspp_hidden_dim,
297
  aspp_num_steps=aspp_num_steps,
298
- aspp_dropout=aspp_dropout
 
299
  )
300
 
301
  # Initialize weights
@@ -311,7 +392,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
311
 
312
  config_class = AsteriskConfig
313
 
314
- def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
315
  # Read all ASPP parameters from config if not explicitly provided
316
  if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
317
  hybrid_layer_indices = config.hybrid_layer_indices
@@ -321,11 +402,13 @@ class AsteriskForCausalLM(LlamaForCausalLM):
321
  aspp_num_steps = config.aspp_num_steps
322
  if hasattr(config, 'aspp_dropout'):
323
  aspp_dropout = config.aspp_dropout
 
 
324
 
325
  super().__init__(config)
326
 
327
  # Replace model with Asterisk version
328
- self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
329
 
330
  # Store hybrid layer info in config for serialization
331
  self.config.hybrid_layer_indices = hybrid_layer_indices
@@ -341,6 +424,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
341
  aspp_hidden_dim: Optional[int] = None,
342
  aspp_num_steps: int = 2,
343
  aspp_dropout: float = 0.1,
 
344
  # π-flow parameters
345
  pi_flow: bool = False,
346
  pi_flow_steps: int = 1,
@@ -357,6 +441,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
357
  aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
358
  aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
359
  aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
 
360
  pi_flow: Enable π-flow refinement step (default: False)
361
  pi_flow_steps: Number of flow refinement steps (default: 1)
362
  pi_flow_scale: Initial flow scale parameter (default: 0.2)
@@ -373,6 +458,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
373
  aspp_hidden_dim=aspp_hidden_dim,
374
  aspp_num_steps=aspp_num_steps,
375
  aspp_dropout=aspp_dropout,
 
376
  pi_flow=pi_flow,
377
  pi_flow_steps=pi_flow_steps,
378
  pi_flow_scale=pi_flow_scale,
@@ -380,15 +466,15 @@ class AsteriskForCausalLM(LlamaForCausalLM):
380
  )
381
 
382
  # Create Asterisk model
383
- asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
384
 
385
  # Transfer weights from base model (non-hybrid layers and embeddings)
386
  asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
387
 
388
- print(f"✓ Converted base model to Asterisk architecture")
389
  print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
390
  aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
391
- print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")
392
  if pi_flow:
393
  print(f" π-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
394
 
 
36
  aspp_hidden_dim: Optional[int] = None,
37
  aspp_num_steps: int = 2,
38
  aspp_dropout: float = 0.1,
39
+ aspp_num_neighbors: int = 8, # NEW: number of semantic neighbors
40
  # π-flow parameters
41
  pi_flow: bool = False,
42
  pi_flow_steps: int = 1,
 
49
  self.aspp_hidden_dim = aspp_hidden_dim
50
  self.aspp_num_steps = aspp_num_steps
51
  self.aspp_dropout = aspp_dropout
52
+ self.aspp_num_neighbors = aspp_num_neighbors
53
  # π-flow config
54
  self.pi_flow = pi_flow
55
  self.pi_flow_steps = pi_flow_steps
 
59
 
60
  class ASPPOperator(nn.Module):
61
  """
62
+ Asterisk Operator (ASPP) - Graph Propagation with Adjacency List (2D Grid)
63
 
64
+ Converts 1D sequence to 2D grid and performs 8-neighbor graph propagation:
65
+ - Adjacency list: 8 neighbors per position (chessboard-style 8 neighbors)
66
+ ↖ ↑ ↗
67
+ ← ● →
68
+ ↙ ↓ ↘
69
+ - Learnable adjacency weights (randomly initialized)
70
+ - Padding for dynamic sequence lengths
71
+ - Message passing: h_i^(t+1) = φ(h_i^(t), Σ_j∈N(i) A_j * h_j^(t))
72
 
73
  Args:
74
  hidden_size: Dimension of hidden states (input/output)
75
  aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
76
  num_steps: Number of evolution steps K (default: 2)
77
  dropout: Dropout rate for regularization (default: 0.1)
78
+ num_neighbors: Number of neighbors (default: 8, for 8-directional grid)
79
  """
80
 
81
+ def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 8):
82
  super().__init__()
83
  self.hidden_size = hidden_size
84
  self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
85
  self.num_steps = num_steps
86
+ self.num_neighbors = num_neighbors
87
 
88
  # Projection to lower dimension (if specified)
89
  self.use_projection = (self.aspp_hidden_dim != hidden_size)
 
92
  self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
93
  self.proj_dropout = nn.Dropout(dropout)
94
 
95
+ # 8-directional offsets for 2D grid neighbors (row, col)
96
+ # β†–(-1,-1) ↑(-1,0) β†—(-1,1)
97
+ # ←(0,-1) ●(0,0) β†’(0,1)
98
+ # ↙(1,-1) ↓(1,0) β†˜(1,1)
99
+ self.register_buffer('neighbor_offsets', torch.tensor([
100
+ [-1, -1], [-1, 0], [-1, 1], # top-left, top, top-right
101
+ [0, -1], [0, 1], # left, right
102
+ [1, -1], [1, 0], [1, 1] # bottom-left, bottom, bottom-right
103
+ ], dtype=torch.long)) # [8, 2]
104
+
105
+ # Learnable adjacency weights for 8 directions (randomly initialized)
106
+ self.adjacency_weights = nn.Parameter(torch.randn(num_neighbors) * 0.1)
107
+
108
+ # Message aggregation function: combines self + neighbors
109
+ self.message_net = nn.Sequential(
110
+ nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
111
  nn.SiLU(),
112
  nn.Dropout(dropout),
113
  nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
 
115
  )
116
 
117
  # Learnable K-step parameter
 
118
  self.k_logit = nn.Parameter(torch.tensor(1.0))
119
 
120
  # Learnable residual scale
 
130
  Returns:
131
  evolved_states: [batch_size, seq_len, hidden_size]
132
  """
133
+ batch_size, seq_len, _ = hidden_states.shape
134
+
135
  # Project to lower dimension if needed
136
  if self.use_projection:
137
  h_t = self.down_proj(hidden_states)
 
142
  # Learnable number of steps
143
  k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
144
 
145
+ # K-step graph propagation with 2D grid adjacency list
146
  for t in range(k_steps):
147
+ # 1. Reshape 1D sequence to 2D grid
148
+ # Dynamic grid size: H ≈ W ≈ sqrt(seq_len)
149
+ H = int(torch.ceil(torch.sqrt(torch.tensor(seq_len, dtype=torch.float32))).item())
150
+ W = int(torch.ceil(torch.tensor(seq_len, dtype=torch.float32) / H).item())
151
+ grid_size = H * W
152
+
153
+ # Pad sequence to grid_size
154
+ if seq_len < grid_size:
155
+ padding = grid_size - seq_len
156
+ h_t_padded = F.pad(h_t, (0, 0, 0, padding), mode='constant', value=0) # [B, H*W, D]
157
+ else:
158
+ h_t_padded = h_t
159
+
160
+ # Reshape to 2D grid: [B, H*W, D] -> [B, H, W, D]
161
+ h_grid = h_t_padded.view(batch_size, H, W, self.aspp_hidden_dim)
162
+
163
+ # 2. Add boundary padding for neighbor gathering (pad with zeros)
164
+ # Pad 1 row/col on each side: [B, H, W, D] -> [B, H+2, W+2, D]
165
+ h_grid_padded = F.pad(h_grid, (0, 0, 1, 1, 1, 1), mode='constant', value=0)
166
+
167
+ # 3. Gather neighbors using adjacency list (offsets)
168
+ # neighbor_offsets: [8, 2] with (row_offset, col_offset)
169
+ # Only use first num_neighbors offsets
170
+ neighbors_list = []
171
+ for offset in self.neighbor_offsets[:self.num_neighbors]:
172
+ # Offset is relative to center, but we need absolute indices in padded grid
173
+ # Center at (i+1, j+1) in padded grid, neighbor at (i+1+di, j+1+dj)
174
+ # Use roll to shift the grid
175
+ di, dj = offset[0].item(), offset[1].item()
176
+
177
+ # Create index tensors for gathering
178
+ # For each position (i,j) in original grid, get neighbor at (i+di, j+dj) in padded grid
179
+ row_indices = torch.arange(H, device=h_t.device).view(-1, 1).expand(H, W) + 1 + di
180
+ col_indices = torch.arange(W, device=h_t.device).view(1, -1).expand(H, W) + 1 + dj
181
+
182
+ # Gather neighbor features: [B, H, W, D]
183
+ neighbor = h_grid_padded[:, row_indices, col_indices, :] # [B, H, W, D]
184
+ neighbors_list.append(neighbor)
185
+
186
+ # Stack neighbors: [B, H, W, num_neighbors, D]
187
+ neighbors = torch.stack(neighbors_list, dim=3) # [B, H, W, num_neighbors, D]
188
+
189
+ # 4. Apply learnable adjacency weights
190
+ # adjacency_weights: [num_neighbors] -> normalize with softmax
191
+ adj_weights = F.softmax(self.adjacency_weights, dim=0) # [num_neighbors]
192
+
193
+ # Weighted aggregation: [B, H, W, D]
194
+ aggregated_neighbors = torch.sum(neighbors * adj_weights.view(1, 1, 1, self.num_neighbors, 1), dim=3)
195
+
196
+ # 5. Message passing: combine self + neighbors
197
+ # Flatten back to sequence: [B, H, W, D] -> [B, H*W, D]
198
+ h_grid_flat = h_grid.view(batch_size, grid_size, self.aspp_hidden_dim)
199
+ aggregated_flat = aggregated_neighbors.view(batch_size, grid_size, self.aspp_hidden_dim)
200
+
201
+ # Concat and pass through message net
202
+ message_input = torch.cat([h_grid_flat, aggregated_flat], dim=-1) # [B, H*W, 2D]
203
+ h_t_next = self.message_net(message_input) # [B, H*W, D]
204
+
205
+ # Remove padding to restore original seq_len
206
+ h_t_next = h_t_next[:, :seq_len, :] # [B, L, D]
207
+
208
+ # 6. Scaled residual connection for stability
209
  h_t = h_t + self.residual_scale * h_t_next
210
  h_t = self.norm(h_t)
211
 
 
231
  4. Feed-forward network
232
  """
233
 
234
+ def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
235
  # Initialize parent LlamaDecoderLayer
236
  super().__init__(config, layer_idx)
237
 
 
240
  hidden_size=config.hidden_size,
241
  aspp_hidden_dim=aspp_hidden_dim,
242
  num_steps=aspp_num_steps,
243
+ dropout=aspp_dropout,
244
+ num_neighbors=aspp_num_neighbors
245
  )
246
 
247
  # Gated fusion mechanism with dropout
 
261
  hidden_size=config.hidden_size,
262
  aspp_hidden_dim=aspp_hidden_dim,
263
  num_steps=aspp_num_steps,
264
+ dropout=aspp_dropout,
265
+ num_neighbors=aspp_num_neighbors
266
  )
267
 
268
  # Learnable flow scale (per-layer)
 
356
  All layers use hybrid ASPP+Attention by default for maximum expressiveness.
357
  """
358
 
359
+ def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
360
  super().__init__(config)
361
 
362
  # Determine which layers to make hybrid (default: ALL layers)
 
375
  layer_idx=idx,
376
  aspp_hidden_dim=aspp_hidden_dim,
377
  aspp_num_steps=aspp_num_steps,
378
+ aspp_dropout=aspp_dropout,
379
+ aspp_num_neighbors=aspp_num_neighbors
380
  )
381
 
382
  # Initialize weights
 
392
 
393
  config_class = AsteriskConfig
394
 
395
+ def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
396
  # Read all ASPP parameters from config if not explicitly provided
397
  if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
398
  hybrid_layer_indices = config.hybrid_layer_indices
 
402
  aspp_num_steps = config.aspp_num_steps
403
  if hasattr(config, 'aspp_dropout'):
404
  aspp_dropout = config.aspp_dropout
405
+ if hasattr(config, 'aspp_num_neighbors'):
406
+ aspp_num_neighbors = config.aspp_num_neighbors
407
 
408
  super().__init__(config)
409
 
410
  # Replace model with Asterisk version
411
+ self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
412
 
413
  # Store hybrid layer info in config for serialization
414
  self.config.hybrid_layer_indices = hybrid_layer_indices
 
424
  aspp_hidden_dim: Optional[int] = None,
425
  aspp_num_steps: int = 2,
426
  aspp_dropout: float = 0.1,
427
+ aspp_num_neighbors: int = 8, # NEW: number of semantic neighbors
428
  # π-flow parameters
429
  pi_flow: bool = False,
430
  pi_flow_steps: int = 1,
 
441
  aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
442
  aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
443
  aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
444
+ aspp_num_neighbors: Number of semantic neighbors for graph propagation (default: 8)
445
  pi_flow: Enable π-flow refinement step (default: False)
446
  pi_flow_steps: Number of flow refinement steps (default: 1)
447
  pi_flow_scale: Initial flow scale parameter (default: 0.2)
 
458
  aspp_hidden_dim=aspp_hidden_dim,
459
  aspp_num_steps=aspp_num_steps,
460
  aspp_dropout=aspp_dropout,
461
+ aspp_num_neighbors=aspp_num_neighbors,
462
  pi_flow=pi_flow,
463
  pi_flow_steps=pi_flow_steps,
464
  pi_flow_scale=pi_flow_scale,
 
466
  )
467
 
468
  # Create Asterisk model
469
+ asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
470
 
471
  # Transfer weights from base model (non-hybrid layers and embeddings)
472
  asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
473
 
474
+ print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
475
  print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
476
  aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
477
+ print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
478
  if pi_flow:
479
  print(f" Ο€-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
480
 
README.md CHANGED
@@ -100,12 +100,41 @@ class HybridASPPAttentionLayer:
100
  Combines ASPP operator with standard attention
101
 
102
  Components:
103
- - ASPP operator: Local structured reasoning
104
  - Standard attention: Global context
105
  - Gated fusion: Dynamic balancing
106
  """
107
  ```
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  **Fusion mechanism:**
110
  ```
111
  aspp_out = ASPP(hidden_states)
@@ -240,9 +269,9 @@ Total: ~10,148 training samples
240
  ### Training Configuration
241
 
242
  - **Starting Point**: Asterisk checkpoint (base ASPP-Attention model)
243
- - **Optimizer**: AdamW (lr=5e-4, weight_decay=0.1)
244
- - **Batch Size**: 2 per device, gradient accumulation=4 (effective batch=8)
245
- - **Epochs**: 2
246
  - **Scheduler**: Linear warmup (10% of steps)
247
  - **Mixed Precision**: bfloat16
248
  - **Gradient Checkpointing**: Enabled
@@ -263,9 +292,16 @@ pi_flow_use_gate = True # Token-wise adaptive gating
263
  aspp_hidden_dim = 256 # Internal dimension (vs 576 model hidden_size)
264
  aspp_num_steps = 4 # Evolution steps for ASPP
265
  aspp_dropout = 0.2 # Regularization
 
266
  hybrid_layer_indices = None # All 30 layers
267
  ```
268
 
 
 
 
 
 
 
269
  ## Model Creation from Base Asterisk
270
 
271
  ```python
@@ -283,6 +319,10 @@ config.pi_flow_steps = 2
283
  config.pi_flow_scale = 1.0
284
  config.pi_flow_use_gate = True
285
 
 
 
 
 
286
  # Create model with Ο€-flow
287
  model = AsteriskForCausalLM(config)
288
 
@@ -334,6 +374,40 @@ This creates a **hierarchical refinement cascade** enabling gradual convergence
334
 
335
  ## Implementation Details
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  ### Return Type Handling
338
 
339
  Critical for Transformers compatibility:
@@ -386,22 +460,41 @@ Asterisk-Pi/
386
 
387
  ## Known Issues & Solutions
388
 
389
- ### 1. Return Type Errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
  **Issue**: `AttributeError: 'tuple' object has no attribute 'dtype'`
392
 
393
  **Solution**: `HybridASPPAttentionLayer.forward()` must return `torch.Tensor` only, not tuple. This matches the `LlamaDecoderLayer` API in transformers 4.57.6.
394
 
395
- ### 2. π-Flow in All Layers vs Final Layer
396
 
397
  **Initial approach**: Ο€-flow only in final layer (limited expressiveness)
398
 
399
  **Current approach**: Ο€-flow in all 30 hybrid layers for maximum refinement capability.
400
 
401
- ### 3. Training Stability
402
 
403
  π-Flow can cause instability with high learning rates. Use:
404
- - Lower learning rate (5e-4 vs 2e-5 for base)
405
  - Gradient clipping (max_norm=1.0)
406
  - Conservative initial flow scale (0.2-1.0)
407
 
 
100
  Combines ASPP operator with standard attention
101
 
102
  Components:
103
+ - ASPP operator: Local structured reasoning with graph propagation
104
  - Standard attention: Global context
105
  - Gated fusion: Dynamic balancing
106
  """
107
  ```
108
 
109
+ **ASPP Operator - Graph Propagation:**
110
+
111
+ The ASPP operator converts the 1D sequence into a 2D grid and performs graph-based message passing:
112
+
113
+ ```
114
+ Sequence [1, 2, 3, 4, ...] → 2D Grid:
115
+ ┌───┬───┬───┐
116
+ │ 1 │ 2 │ 3 │
117
+ ├───┼───┼───┤
118
+ │ 4 │ 5 │ 6 │
119
+ └───┴───┴───┘
120
+
121
+ 8-directional neighbors (default):
122
+ ↖ ↑ ↗
123
+ ← ● →
124
+ ↙ ↓ ↘
125
+
126
+ 4-directional neighbors (optional):
127
+ ↑
128
+ ← ● →
129
+ ↓
130
+ ```
131
+
132
+ **Key features:**
133
+ - **Configurable neighbors**: `aspp_num_neighbors` (default: 8)
134
+ - **Learnable adjacency weights**: Each direction has a learnable weight
135
+ - **K-step evolution**: Iterative message passing for `aspp_num_steps` (default: 4)
136
+ - **Dynamic grid**: Grid dimensions adapt to sequence length (H ≈ W ≈ √seq_len)
137
+
138
  **Fusion mechanism:**
139
  ```
140
  aspp_out = ASPP(hidden_states)
 
269
  ### Training Configuration
270
 
271
  - **Starting Point**: Asterisk checkpoint (base ASPP-Attention model)
272
+ - **Optimizer**: AdamW (lr=1e-4, weight_decay=0.1)
273
+ - **Batch Size**: 4 per device, gradient accumulation=4 (effective batch=16)
274
+ - **Epochs**: 2.5
275
  - **Scheduler**: Linear warmup (10% of steps)
276
  - **Mixed Precision**: bfloat16
277
  - **Gradient Checkpointing**: Enabled
 
292
  aspp_hidden_dim = 256 # Internal dimension (vs 576 model hidden_size)
293
  aspp_num_steps = 4 # Evolution steps for ASPP
294
  aspp_dropout = 0.2 # Regularization
295
+ aspp_num_neighbors = 8 # Number of semantic neighbors for graph propagation (default: 8)
296
  hybrid_layer_indices = None # All 30 layers
297
  ```
298
 
299
+ **Graph Propagation Neighbors:**
300
+ - Default: 8-directional grid (↖↑↗←→↙↓↘)
301
+ - Configurable: Can use fewer neighbors (e.g., 4-directional: ↑←→↓)
302
+ - The neighbor offsets are defined in a buffer, and only the first `num_neighbors` are used
303
+ - Learnable adjacency weights adapt importance of each direction during training
304
+
305
  ## Model Creation from Base Asterisk
306
 
307
  ```python
 
319
  config.pi_flow_scale = 1.0
320
  config.pi_flow_use_gate = True
321
 
322
+ # Optional: Configure ASPP graph propagation neighbors
323
+ # config.aspp_num_neighbors = 8 # Default: 8 (full 8-directional grid)
324
+ # config.aspp_num_neighbors = 4 # Alternative: 4 (cardinal directions only)
325
+
326
  # Create model with Ο€-flow
327
  model = AsteriskForCausalLM(config)
328
 
 
374
 
375
  ## Implementation Details
376
 
377
+ ### ASPP Graph Propagation Configuration
378
+
379
+ The ASPP operator supports configurable neighbor connectivity:
380
+
381
+ ```python
382
+ class ASPPOperator(nn.Module):
383
+ def __init__(
384
+ self,
385
+ hidden_size: int,
386
+ aspp_hidden_dim: Optional[int] = None,
387
+ num_steps: int = 2,
388
+ dropout: float = 0.1,
389
+ num_neighbors: int = 8 # Configurable: 4, 8, etc.
390
+ ):
391
+ # Neighbor offsets buffer (8 directions hardcoded)
392
+ self.register_buffer('neighbor_offsets', torch.tensor([
393
+ [-1, -1], [-1, 0], [-1, 1], # ↖ ↑ ↗
394
+ [0, -1], [0, 1], # ← →
395
+ [1, -1], [1, 0], [1, 1] # ↙ ↓ ↘
396
+ ]))
397
+
398
+ # Only first num_neighbors are used
399
+ self.num_neighbors = num_neighbors
400
+
401
+ # Learnable adjacency weights
402
+ self.adjacency_weights = nn.Parameter(torch.randn(num_neighbors) * 0.1)
403
+ ```
404
+
405
+ **Usage:**
406
+ - `num_neighbors=8`: Full 8-directional connectivity (default)
407
+ - `num_neighbors=4`: Cardinal directions only (↑←→↓)
408
+ - The buffer always contains 8 offsets, but only the first `num_neighbors` are used
409
+ - Adjacency weights are softmax-normalized for stability
410
+
411
  ### Return Type Handling
412
 
413
  Critical for Transformers compatibility:
 
460
 
461
  ## Known Issues & Solutions
462
 
463
+ ### 1. Neighbor Count Configuration (Fixed in Latest Version)
464
+
465
+ **Previous issue**: ASPP operator hardcoded 8 neighbors in reshape operations, causing errors when using different neighbor counts.
466
+
467
+ **Solution**: Updated implementation to use `self.num_neighbors` dynamically:
468
+ ```python
469
+ # Fixed: Dynamic neighbor count
470
+ for offset in self.neighbor_offsets[:self.num_neighbors]: # Only use first N
471
+ # ...gather neighbors...
472
+
473
+ # Fixed: Dynamic reshape
474
+ aggregated_neighbors = torch.sum(
475
+ neighbors * adj_weights.view(1, 1, 1, self.num_neighbors, 1),
476
+ dim=3
477
+ )
478
+ ```
479
+
480
+ Now supports any neighbor count (4, 8, etc.) without modification.
481
+
482
+ ### 2. Return Type Errors
483
 
484
  **Issue**: `AttributeError: 'tuple' object has no attribute 'dtype'`
485
 
486
  **Solution**: `HybridASPPAttentionLayer.forward()` must return `torch.Tensor` only, not tuple. This matches the `LlamaDecoderLayer` API in transformers 4.57.6.
487
 
488
+ ### 3. π-Flow in All Layers vs Final Layer
489
 
490
  **Initial approach**: Ο€-flow only in final layer (limited expressiveness)
491
 
492
  **Current approach**: Ο€-flow in all 30 hybrid layers for maximum refinement capability.
493
 
494
+ ### 4. Training Stability
495
 
496
  π-Flow can cause instability with high learning rates. Use:
497
+ - Lower learning rate (1e-4 recommended for stability)
498
  - Gradient clipping (max_norm=1.0)
499
  - Conservative initial flow scale (0.2-1.0)
500
 
chat_template.jinja CHANGED
@@ -1,5 +1,5 @@
1
  {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
2
- You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
3
  ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
4
  ' + message['content'] + '<|im_end|>' + '
5
  '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
 
1
  {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
2
+ You are a helpful AI assistant named Asterisk, trained by NoesisLab<|im_end|>
3
  ' }}{% endif %}{{'<|im_start|>' + message['role'] + '
4
  ' + message['content'] + '<|im_end|>' + '
5
  '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
config.json CHANGED
@@ -2,8 +2,9 @@
2
  "architectures": [
3
  "AsteriskForCausalLM"
4
  ],
5
- "aspp_dropout": 0.2,
6
  "aspp_hidden_dim": 256,
 
7
  "aspp_num_steps": 4,
8
  "attention_bias": false,
9
  "attention_dropout": 0.0,
 
2
  "architectures": [
3
  "AsteriskForCausalLM"
4
  ],
5
+ "aspp_dropout": 0.1,
6
  "aspp_hidden_dim": 256,
7
+ "aspp_num_neighbors": 8,
8
  "aspp_num_steps": 4,
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd3411332c19c27ac340b99a92d91e0b93f224b62fa3e0cccf7777b4e126b802
3
- size 381107624
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41550e1413295a2b5e02127758543c905650e9c834c6b53e382238f4021f3668
3
+ size 396858360
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:357a1e8bcbd247f80b9437f6d4dd9e81a29edbafaa6fea075a7380b6927773f4
3
  size 6353
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79668c78f13b1c865f88ffab2a80bc10c893e3262c29261ee8d447ec47b717d3
3
  size 6353