Upload 14 files
Browse files- AsteriskForCausalLM.py +113 -27
- README.md +101 -8
- chat_template.jinja +1 -1
- config.json +2 -1
- model.safetensors +2 -2
- training_args.bin +1 -1
AsteriskForCausalLM.py
CHANGED
|
@@ -36,6 +36,7 @@ class AsteriskConfig(LlamaConfig):
|
|
| 36 |
aspp_hidden_dim: Optional[int] = None,
|
| 37 |
aspp_num_steps: int = 2,
|
| 38 |
aspp_dropout: float = 0.1,
|
|
|
|
| 39 |
# π-flow parameters
|
| 40 |
pi_flow: bool = False,
|
| 41 |
pi_flow_steps: int = 1,
|
|
@@ -48,6 +49,7 @@ class AsteriskConfig(LlamaConfig):
|
|
| 48 |
self.aspp_hidden_dim = aspp_hidden_dim
|
| 49 |
self.aspp_num_steps = aspp_num_steps
|
| 50 |
self.aspp_dropout = aspp_dropout
|
|
|
|
| 51 |
# π-flow config
|
| 52 |
self.pi_flow = pi_flow
|
| 53 |
self.pi_flow_steps = pi_flow_steps
|
|
@@ -57,26 +59,31 @@ class AsteriskConfig(LlamaConfig):
|
|
| 57 |
|
| 58 |
class ASPPOperator(nn.Module):
|
| 59 |
"""
|
| 60 |
-
Asterisk Operator (ASPP) -
|
| 61 |
|
| 62 |
-
|
| 63 |
-
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
Args:
|
| 69 |
hidden_size: Dimension of hidden states (input/output)
|
| 70 |
aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
|
| 71 |
num_steps: Number of evolution steps K (default: 2)
|
| 72 |
dropout: Dropout rate for regularization (default: 0.1)
|
|
|
|
| 73 |
"""
|
| 74 |
|
| 75 |
-
def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1):
|
| 76 |
super().__init__()
|
| 77 |
self.hidden_size = hidden_size
|
| 78 |
self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
|
| 79 |
self.num_steps = num_steps
|
|
|
|
| 80 |
|
| 81 |
# Projection to lower dimension (if specified)
|
| 82 |
self.use_projection = (self.aspp_hidden_dim != hidden_size)
|
|
@@ -85,10 +92,22 @@ class ASPPOperator(nn.Module):
|
|
| 85 |
self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
|
| 86 |
self.proj_dropout = nn.Dropout(dropout)
|
| 87 |
|
| 88 |
-
#
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
nn.SiLU(),
|
| 93 |
nn.Dropout(dropout),
|
| 94 |
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
|
|
@@ -96,7 +115,6 @@ class ASPPOperator(nn.Module):
|
|
| 96 |
)
|
| 97 |
|
| 98 |
# Learnable K-step parameter
|
| 99 |
-
# sigmoid(1.0) ≈ 0.73, giving k_steps ≈ 1.5 → 2 steps initially
|
| 100 |
self.k_logit = nn.Parameter(torch.tensor(1.0))
|
| 101 |
|
| 102 |
# Learnable residual scale
|
|
@@ -112,6 +130,8 @@ class ASPPOperator(nn.Module):
|
|
| 112 |
Returns:
|
| 113 |
evolved_states: [batch_size, seq_len, hidden_size]
|
| 114 |
"""
|
|
|
|
|
|
|
| 115 |
# Project to lower dimension if needed
|
| 116 |
if self.use_projection:
|
| 117 |
h_t = self.down_proj(hidden_states)
|
|
@@ -122,12 +142,70 @@ class ASPPOperator(nn.Module):
|
|
| 122 |
# Learnable number of steps
|
| 123 |
k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
|
| 124 |
|
| 125 |
-
# K-step
|
| 126 |
for t in range(k_steps):
|
| 127 |
-
#
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
h_t = h_t + self.residual_scale * h_t_next
|
| 132 |
h_t = self.norm(h_t)
|
| 133 |
|
|
@@ -153,7 +231,7 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 153 |
4. Feed-forward network
|
| 154 |
"""
|
| 155 |
|
| 156 |
-
def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 157 |
# Initialize parent LlamaDecoderLayer
|
| 158 |
super().__init__(config, layer_idx)
|
| 159 |
|
|
@@ -162,7 +240,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 162 |
hidden_size=config.hidden_size,
|
| 163 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 164 |
num_steps=aspp_num_steps,
|
| 165 |
-
dropout=aspp_dropout
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
# Gated fusion mechanism with dropout
|
|
@@ -182,7 +261,8 @@ class HybridASPPAttentionLayer(LlamaDecoderLayer):
|
|
| 182 |
hidden_size=config.hidden_size,
|
| 183 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 184 |
num_steps=aspp_num_steps,
|
| 185 |
-
dropout=aspp_dropout
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
# Learnable flow scale (per-layer)
|
|
@@ -276,7 +356,7 @@ class AsteriskLlamaModel(LlamaModel):
|
|
| 276 |
All layers use hybrid ASPP+Attention by default for maximum expressiveness.
|
| 277 |
"""
|
| 278 |
|
| 279 |
-
def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 280 |
super().__init__(config)
|
| 281 |
|
| 282 |
# Determine which layers to make hybrid (default: ALL layers)
|
|
@@ -295,7 +375,8 @@ class AsteriskLlamaModel(LlamaModel):
|
|
| 295 |
layer_idx=idx,
|
| 296 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 297 |
aspp_num_steps=aspp_num_steps,
|
| 298 |
-
aspp_dropout=aspp_dropout
|
|
|
|
| 299 |
)
|
| 300 |
|
| 301 |
# Initialize weights
|
|
@@ -311,7 +392,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 311 |
|
| 312 |
config_class = AsteriskConfig
|
| 313 |
|
| 314 |
-
def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1):
|
| 315 |
# Read all ASPP parameters from config if not explicitly provided
|
| 316 |
if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
|
| 317 |
hybrid_layer_indices = config.hybrid_layer_indices
|
|
@@ -321,11 +402,13 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 321 |
aspp_num_steps = config.aspp_num_steps
|
| 322 |
if hasattr(config, 'aspp_dropout'):
|
| 323 |
aspp_dropout = config.aspp_dropout
|
|
|
|
|
|
|
| 324 |
|
| 325 |
super().__init__(config)
|
| 326 |
|
| 327 |
# Replace model with Asterisk version
|
| 328 |
-
self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
|
| 329 |
|
| 330 |
# Store hybrid layer info in config for serialization
|
| 331 |
self.config.hybrid_layer_indices = hybrid_layer_indices
|
|
@@ -341,6 +424,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 341 |
aspp_hidden_dim: Optional[int] = None,
|
| 342 |
aspp_num_steps: int = 2,
|
| 343 |
aspp_dropout: float = 0.1,
|
|
|
|
| 344 |
# π-flow parameters
|
| 345 |
pi_flow: bool = False,
|
| 346 |
pi_flow_steps: int = 1,
|
|
@@ -357,6 +441,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 357 |
aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
|
| 358 |
aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
|
| 359 |
aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
|
|
|
|
| 360 |
pi_flow: Enable π-flow refinement step (default: False)
|
| 361 |
pi_flow_steps: Number of flow refinement steps (default: 1)
|
| 362 |
pi_flow_scale: Initial flow scale parameter (default: 0.2)
|
|
@@ -373,6 +458,7 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 373 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 374 |
aspp_num_steps=aspp_num_steps,
|
| 375 |
aspp_dropout=aspp_dropout,
|
|
|
|
| 376 |
pi_flow=pi_flow,
|
| 377 |
pi_flow_steps=pi_flow_steps,
|
| 378 |
pi_flow_scale=pi_flow_scale,
|
|
@@ -380,15 +466,15 @@ class AsteriskForCausalLM(LlamaForCausalLM):
|
|
| 380 |
)
|
| 381 |
|
| 382 |
# Create Asterisk model
|
| 383 |
-
asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout)
|
| 384 |
|
| 385 |
# Transfer weights from base model (non-hybrid layers and embeddings)
|
| 386 |
asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
|
| 387 |
|
| 388 |
-
print(f"β Converted base model to Asterisk architecture")
|
| 389 |
print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
|
| 390 |
aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
|
| 391 |
-
print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}")
|
| 392 |
if pi_flow:
|
| 393 |
print(f" Ο-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
|
| 394 |
|
|
|
|
| 36 |
aspp_hidden_dim: Optional[int] = None,
|
| 37 |
aspp_num_steps: int = 2,
|
| 38 |
aspp_dropout: float = 0.1,
|
| 39 |
+
aspp_num_neighbors: int = 8, # NEW: number of semantic neighbors
|
| 40 |
# π-flow parameters
|
| 41 |
pi_flow: bool = False,
|
| 42 |
pi_flow_steps: int = 1,
|
|
|
|
| 49 |
self.aspp_hidden_dim = aspp_hidden_dim
|
| 50 |
self.aspp_num_steps = aspp_num_steps
|
| 51 |
self.aspp_dropout = aspp_dropout
|
| 52 |
+
self.aspp_num_neighbors = aspp_num_neighbors
|
| 53 |
# π-flow config
|
| 54 |
self.pi_flow = pi_flow
|
| 55 |
self.pi_flow_steps = pi_flow_steps
|
|
|
|
| 59 |
|
| 60 |
class ASPPOperator(nn.Module):
|
| 61 |
"""
|
| 62 |
+
Asterisk Operator (ASPP) - Graph Propagation with Adjacency List (2D Grid)
|
| 63 |
|
| 64 |
+
Converts 1D sequence to 2D grid and performs 8-neighbor graph propagation:
|
| 65 |
+
- Adjacency list: 8 neighbors per position (chessboard-style 8-neighborhood)
|
| 66 |
+
↖ ↑ ↗
|
| 67 |
+
← ● →
|
| 68 |
+
↙ ↓ ↘
|
| 69 |
+
- Learnable adjacency weights (randomly initialized)
|
| 70 |
+
- Padding for dynamic sequence lengths
|
| 71 |
+
- Message passing: h_i^(t+1) = φ(h_i^(t), Σ_{j∈N(i)} A_j * h_j^(t))
|
| 72 |
|
| 73 |
Args:
|
| 74 |
hidden_size: Dimension of hidden states (input/output)
|
| 75 |
aspp_hidden_dim: Internal dimension for ASPP (default: None, use hidden_size)
|
| 76 |
num_steps: Number of evolution steps K (default: 2)
|
| 77 |
dropout: Dropout rate for regularization (default: 0.1)
|
| 78 |
+
num_neighbors: Number of neighbors (default: 8, for 8-directional grid)
|
| 79 |
"""
|
| 80 |
|
| 81 |
+
def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 8):
|
| 82 |
super().__init__()
|
| 83 |
self.hidden_size = hidden_size
|
| 84 |
self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
|
| 85 |
self.num_steps = num_steps
|
| 86 |
+
self.num_neighbors = num_neighbors
|
| 87 |
|
| 88 |
# Projection to lower dimension (if specified)
|
| 89 |
self.use_projection = (self.aspp_hidden_dim != hidden_size)
|
|
|
|
| 92 |
self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
|
| 93 |
self.proj_dropout = nn.Dropout(dropout)
|
| 94 |
|
| 95 |
+
# 8-directional offsets for 2D grid neighbors (row, col)
|
| 96 |
+
# ↖(-1,-1) ↑(-1,0) ↗(-1,1)
|
| 97 |
+
# ←(0,-1) ●(0,0) →(0,1)
|
| 98 |
+
# ↙(1,-1) ↓(1,0) ↘(1,1)
|
| 99 |
+
self.register_buffer('neighbor_offsets', torch.tensor([
|
| 100 |
+
[-1, -1], [-1, 0], [-1, 1], # top-left, top, top-right
|
| 101 |
+
[0, -1], [0, 1], # left, right
|
| 102 |
+
[1, -1], [1, 0], [1, 1] # bottom-left, bottom, bottom-right
|
| 103 |
+
], dtype=torch.long)) # [8, 2]
|
| 104 |
+
|
| 105 |
+
# Learnable adjacency weights for 8 directions (randomly initialized)
|
| 106 |
+
self.adjacency_weights = nn.Parameter(torch.randn(num_neighbors) * 0.1)
|
| 107 |
+
|
| 108 |
+
# Message aggregation function: combines self + neighbors
|
| 109 |
+
self.message_net = nn.Sequential(
|
| 110 |
+
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
|
| 111 |
nn.SiLU(),
|
| 112 |
nn.Dropout(dropout),
|
| 113 |
nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
# Learnable K-step parameter
|
|
|
|
| 118 |
self.k_logit = nn.Parameter(torch.tensor(1.0))
|
| 119 |
|
| 120 |
# Learnable residual scale
|
|
|
|
| 130 |
Returns:
|
| 131 |
evolved_states: [batch_size, seq_len, hidden_size]
|
| 132 |
"""
|
| 133 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 134 |
+
|
| 135 |
# Project to lower dimension if needed
|
| 136 |
if self.use_projection:
|
| 137 |
h_t = self.down_proj(hidden_states)
|
|
|
|
| 142 |
# Learnable number of steps
|
| 143 |
k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))
|
| 144 |
|
| 145 |
+
# K-step graph propagation with 2D grid adjacency list
|
| 146 |
for t in range(k_steps):
|
| 147 |
+
# 1. Reshape 1D sequence to 2D grid
|
| 148 |
+
# Dynamic grid size: H ≈ W ≈ sqrt(seq_len)
|
| 149 |
+
H = int(torch.ceil(torch.sqrt(torch.tensor(seq_len, dtype=torch.float32))).item())
|
| 150 |
+
W = int(torch.ceil(torch.tensor(seq_len, dtype=torch.float32) / H).item())
|
| 151 |
+
grid_size = H * W
|
| 152 |
+
|
| 153 |
+
# Pad sequence to grid_size
|
| 154 |
+
if seq_len < grid_size:
|
| 155 |
+
padding = grid_size - seq_len
|
| 156 |
+
h_t_padded = F.pad(h_t, (0, 0, 0, padding), mode='constant', value=0) # [B, H*W, D]
|
| 157 |
+
else:
|
| 158 |
+
h_t_padded = h_t
|
| 159 |
+
|
| 160 |
+
# Reshape to 2D grid: [B, H*W, D] -> [B, H, W, D]
|
| 161 |
+
h_grid = h_t_padded.view(batch_size, H, W, self.aspp_hidden_dim)
|
| 162 |
+
|
| 163 |
+
# 2. Add boundary padding for neighbor gathering (pad with zeros)
|
| 164 |
+
# Pad 1 row/col on each side: [B, H, W, D] -> [B, H+2, W+2, D]
|
| 165 |
+
h_grid_padded = F.pad(h_grid, (0, 0, 1, 1, 1, 1), mode='constant', value=0)
|
| 166 |
+
|
| 167 |
+
# 3. Gather neighbors using adjacency list (offsets)
|
| 168 |
+
# neighbor_offsets: [8, 2] with (row_offset, col_offset)
|
| 169 |
+
# Only use first num_neighbors offsets
|
| 170 |
+
neighbors_list = []
|
| 171 |
+
for offset in self.neighbor_offsets[:self.num_neighbors]:
|
| 172 |
+
# Offset is relative to center, but we need absolute indices in padded grid
|
| 173 |
+
# Center at (i+1, j+1) in padded grid, neighbor at (i+1+di, j+1+dj)
|
| 174 |
+
# Use roll to shift the grid
|
| 175 |
+
di, dj = offset[0].item(), offset[1].item()
|
| 176 |
+
|
| 177 |
+
# Create index tensors for gathering
|
| 178 |
+
# For each position (i,j) in original grid, get neighbor at (i+di, j+dj) in padded grid
|
| 179 |
+
row_indices = torch.arange(H, device=h_t.device).view(-1, 1).expand(H, W) + 1 + di
|
| 180 |
+
col_indices = torch.arange(W, device=h_t.device).view(1, -1).expand(H, W) + 1 + dj
|
| 181 |
+
|
| 182 |
+
# Gather neighbor features: [B, H, W, D]
|
| 183 |
+
neighbor = h_grid_padded[:, row_indices, col_indices, :] # [B, H, W, D]
|
| 184 |
+
neighbors_list.append(neighbor)
|
| 185 |
+
|
| 186 |
+
# Stack neighbors: [B, H, W, num_neighbors, D]
|
| 187 |
+
neighbors = torch.stack(neighbors_list, dim=3) # [B, H, W, num_neighbors, D]
|
| 188 |
+
|
| 189 |
+
# 4. Apply learnable adjacency weights
|
| 190 |
+
# adjacency_weights: [num_neighbors] -> normalize with softmax
|
| 191 |
+
adj_weights = F.softmax(self.adjacency_weights, dim=0) # [num_neighbors]
|
| 192 |
+
|
| 193 |
+
# Weighted aggregation: [B, H, W, D]
|
| 194 |
+
aggregated_neighbors = torch.sum(neighbors * adj_weights.view(1, 1, 1, self.num_neighbors, 1), dim=3)
|
| 195 |
+
|
| 196 |
+
# 5. Message passing: combine self + neighbors
|
| 197 |
+
# Flatten back to sequence: [B, H, W, D] -> [B, H*W, D]
|
| 198 |
+
h_grid_flat = h_grid.view(batch_size, grid_size, self.aspp_hidden_dim)
|
| 199 |
+
aggregated_flat = aggregated_neighbors.view(batch_size, grid_size, self.aspp_hidden_dim)
|
| 200 |
+
|
| 201 |
+
# Concat and pass through message net
|
| 202 |
+
message_input = torch.cat([h_grid_flat, aggregated_flat], dim=-1) # [B, H*W, 2D]
|
| 203 |
+
h_t_next = self.message_net(message_input) # [B, H*W, D]
|
| 204 |
+
|
| 205 |
+
# Remove padding to restore original seq_len
|
| 206 |
+
h_t_next = h_t_next[:, :seq_len, :] # [B, L, D]
|
| 207 |
+
|
| 208 |
+
# 6. Scaled residual connection for stability
|
| 209 |
h_t = h_t + self.residual_scale * h_t_next
|
| 210 |
h_t = self.norm(h_t)
|
| 211 |
|
|
|
|
| 231 |
4. Feed-forward network
|
| 232 |
"""
|
| 233 |
|
| 234 |
+
def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
|
| 235 |
# Initialize parent LlamaDecoderLayer
|
| 236 |
super().__init__(config, layer_idx)
|
| 237 |
|
|
|
|
| 240 |
hidden_size=config.hidden_size,
|
| 241 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 242 |
num_steps=aspp_num_steps,
|
| 243 |
+
dropout=aspp_dropout,
|
| 244 |
+
num_neighbors=aspp_num_neighbors
|
| 245 |
)
|
| 246 |
|
| 247 |
# Gated fusion mechanism with dropout
|
|
|
|
| 261 |
hidden_size=config.hidden_size,
|
| 262 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 263 |
num_steps=aspp_num_steps,
|
| 264 |
+
dropout=aspp_dropout,
|
| 265 |
+
num_neighbors=aspp_num_neighbors
|
| 266 |
)
|
| 267 |
|
| 268 |
# Learnable flow scale (per-layer)
|
|
|
|
| 356 |
All layers use hybrid ASPP+Attention by default for maximum expressiveness.
|
| 357 |
"""
|
| 358 |
|
| 359 |
+
def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
|
| 360 |
super().__init__(config)
|
| 361 |
|
| 362 |
# Determine which layers to make hybrid (default: ALL layers)
|
|
|
|
| 375 |
layer_idx=idx,
|
| 376 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 377 |
aspp_num_steps=aspp_num_steps,
|
| 378 |
+
aspp_dropout=aspp_dropout,
|
| 379 |
+
aspp_num_neighbors=aspp_num_neighbors
|
| 380 |
)
|
| 381 |
|
| 382 |
# Initialize weights
|
|
|
|
| 392 |
|
| 393 |
config_class = AsteriskConfig
|
| 394 |
|
| 395 |
+
def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 8):
|
| 396 |
# Read all ASPP parameters from config if not explicitly provided
|
| 397 |
if hybrid_layer_indices is None and hasattr(config, 'hybrid_layer_indices'):
|
| 398 |
hybrid_layer_indices = config.hybrid_layer_indices
|
|
|
|
| 402 |
aspp_num_steps = config.aspp_num_steps
|
| 403 |
if hasattr(config, 'aspp_dropout'):
|
| 404 |
aspp_dropout = config.aspp_dropout
|
| 405 |
+
if hasattr(config, 'aspp_num_neighbors'):
|
| 406 |
+
aspp_num_neighbors = config.aspp_num_neighbors
|
| 407 |
|
| 408 |
super().__init__(config)
|
| 409 |
|
| 410 |
# Replace model with Asterisk version
|
| 411 |
+
self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
|
| 412 |
|
| 413 |
# Store hybrid layer info in config for serialization
|
| 414 |
self.config.hybrid_layer_indices = hybrid_layer_indices
|
|
|
|
| 424 |
aspp_hidden_dim: Optional[int] = None,
|
| 425 |
aspp_num_steps: int = 2,
|
| 426 |
aspp_dropout: float = 0.1,
|
| 427 |
+
aspp_num_neighbors: int = 8, # NEW: number of semantic neighbors
|
| 428 |
# π-flow parameters
|
| 429 |
pi_flow: bool = False,
|
| 430 |
pi_flow_steps: int = 1,
|
|
|
|
| 441 |
aspp_hidden_dim: Internal dimension for ASPP (None = use model hidden_size)
|
| 442 |
aspp_num_steps: Number of evolution steps K for ASPP (default: 2)
|
| 443 |
aspp_dropout: Dropout rate for ASPP regularization (default: 0.1)
|
| 444 |
+
aspp_num_neighbors: Number of semantic neighbors for graph propagation (default: 8)
|
| 445 |
pi_flow: Enable π-flow refinement step (default: False)
|
| 446 |
pi_flow_steps: Number of flow refinement steps (default: 1)
|
| 447 |
pi_flow_scale: Initial flow scale parameter (default: 0.2)
|
|
|
|
| 458 |
aspp_hidden_dim=aspp_hidden_dim,
|
| 459 |
aspp_num_steps=aspp_num_steps,
|
| 460 |
aspp_dropout=aspp_dropout,
|
| 461 |
+
aspp_num_neighbors=aspp_num_neighbors,
|
| 462 |
pi_flow=pi_flow,
|
| 463 |
pi_flow_steps=pi_flow_steps,
|
| 464 |
pi_flow_scale=pi_flow_scale,
|
|
|
|
| 466 |
)
|
| 467 |
|
| 468 |
# Create Asterisk model
|
| 469 |
+
asterisk_model = cls(asterisk_config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)
|
| 470 |
|
| 471 |
# Transfer weights from base model (non-hybrid layers and embeddings)
|
| 472 |
asterisk_model.load_state_dict(base_model.state_dict(), strict=False)
|
| 473 |
|
| 474 |
+
print(f"β Converted base model to Asterisk architecture with Graph Propagation")
|
| 475 |
print(f" Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
|
| 476 |
aspp_dim_str = f"{aspp_hidden_dim}" if aspp_hidden_dim else f"{base_config.hidden_size} (full)"
|
| 477 |
+
print(f" ASPP config: dim={aspp_dim_str}, steps={aspp_num_steps}, dropout={aspp_dropout}, neighbors={aspp_num_neighbors}")
|
| 478 |
if pi_flow:
|
| 479 |
print(f" Ο-flow enabled: steps={pi_flow_steps}, scale={pi_flow_scale}, gate={pi_flow_use_gate}")
|
| 480 |
|
README.md
CHANGED
|
@@ -100,12 +100,41 @@ class HybridASPPAttentionLayer:
|
|
| 100 |
Combines ASPP operator with standard attention
|
| 101 |
|
| 102 |
Components:
|
| 103 |
-
- ASPP operator: Local structured reasoning
|
| 104 |
- Standard attention: Global context
|
| 105 |
- Gated fusion: Dynamic balancing
|
| 106 |
"""
|
| 107 |
```
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
**Fusion mechanism:**
|
| 110 |
```
|
| 111 |
aspp_out = ASPP(hidden_states)
|
|
@@ -240,9 +269,9 @@ Total: ~10,148 training samples
|
|
| 240 |
### Training Configuration
|
| 241 |
|
| 242 |
- **Starting Point**: Asterisk checkpoint (base ASPP-Attention model)
|
| 243 |
-
- **Optimizer**: AdamW (lr=
|
| 244 |
-
- **Batch Size**:
|
| 245 |
-
- **Epochs**: 2
|
| 246 |
- **Scheduler**: Linear warmup (10% of steps)
|
| 247 |
- **Mixed Precision**: bfloat16
|
| 248 |
- **Gradient Checkpointing**: Enabled
|
|
@@ -263,9 +292,16 @@ pi_flow_use_gate = True # Token-wise adaptive gating
|
|
| 263 |
aspp_hidden_dim = 256 # Internal dimension (vs 576 model hidden_size)
|
| 264 |
aspp_num_steps = 4 # Evolution steps for ASPP
|
| 265 |
aspp_dropout = 0.2 # Regularization
|
|
|
|
| 266 |
hybrid_layer_indices = None # All 30 layers
|
| 267 |
```
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
## Model Creation from Base Asterisk
|
| 270 |
|
| 271 |
```python
|
|
@@ -283,6 +319,10 @@ config.pi_flow_steps = 2
|
|
| 283 |
config.pi_flow_scale = 1.0
|
| 284 |
config.pi_flow_use_gate = True
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
# Create model with π-flow
|
| 287 |
model = AsteriskForCausalLM(config)
|
| 288 |
|
|
@@ -334,6 +374,40 @@ This creates a **hierarchical refinement cascade** enabling gradual convergence
|
|
| 334 |
|
| 335 |
## Implementation Details
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
### Return Type Handling
|
| 338 |
|
| 339 |
Critical for Transformers compatibility:
|
|
@@ -386,22 +460,41 @@ Asterisk-Pi/
|
|
| 386 |
|
| 387 |
## Known Issues & Solutions
|
| 388 |
|
| 389 |
-
### 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
**Issue**: `AttributeError: 'tuple' object has no attribute 'dtype'`
|
| 392 |
|
| 393 |
**Solution**: `HybridASPPAttentionLayer.forward()` must return `torch.Tensor` only, not tuple. This matches the `LlamaDecoderLayer` API in transformers 4.57.6.
|
| 394 |
|
| 395 |
-
###
|
| 396 |
|
| 397 |
**Initial approach**: π-flow only in final layer (limited expressiveness)
|
| 398 |
|
| 399 |
**Current approach**: π-flow in all 30 hybrid layers for maximum refinement capability.
|
| 400 |
|
| 401 |
-
###
|
| 402 |
|
| 403 |
π-Flow can cause instability with high learning rates. Use:
|
| 404 |
-
- Lower learning rate (
|
| 405 |
- Gradient clipping (max_norm=1.0)
|
| 406 |
- Conservative initial flow scale (0.2-1.0)
|
| 407 |
|
|
|
|
| 100 |
Combines ASPP operator with standard attention
|
| 101 |
|
| 102 |
Components:
|
| 103 |
+
- ASPP operator: Local structured reasoning with graph propagation
|
| 104 |
- Standard attention: Global context
|
| 105 |
- Gated fusion: Dynamic balancing
|
| 106 |
"""
|
| 107 |
```
|
| 108 |
|
| 109 |
+
**ASPP Operator - Graph Propagation:**
|
| 110 |
+
|
| 111 |
+
The ASPP operator converts the 1D sequence into a 2D grid and performs graph-based message passing:
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
Sequence [1, 2, 3, 4, ...] → 2D Grid:
|
| 115 |
+
┌───┬───┬───┐
|
| 116 |
+
│ 1 │ 2 │ 3 │
|
| 117 |
+
├───┼───┼───┤
|
| 118 |
+
│ 4 │ 5 │ 6 │
|
| 119 |
+
└───┴───┴───┘
|
| 120 |
+
|
| 121 |
+
8-directional neighbors (default):
|
| 122 |
+
↖ ↑ ↗
|
| 123 |
+
← ● →
|
| 124 |
+
↙ ↓ ↘
|
| 125 |
+
|
| 126 |
+
4-directional neighbors (optional):
|
| 127 |
+
  ↑
|
| 128 |
+
← ● →
|
| 129 |
+
  ↓
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
**Key features:**
|
| 133 |
+
- **Configurable neighbors**: `aspp_num_neighbors` (default: 8)
|
| 134 |
+
- **Learnable adjacency weights**: Each direction has a learnable weight
|
| 135 |
+
- **K-step evolution**: Iterative message passing for `aspp_num_steps` (default: 4)
|
| 136 |
+
- **Dynamic grid**: Grid dimensions adapt to sequence length (H ≈ W ≈ √seq_len)
|
| 137 |
+
|
| 138 |
**Fusion mechanism:**
|
| 139 |
```
|
| 140 |
aspp_out = ASPP(hidden_states)
|
|
|
|
| 269 |
### Training Configuration
|
| 270 |
|
| 271 |
- **Starting Point**: Asterisk checkpoint (base ASPP-Attention model)
|
| 272 |
+
- **Optimizer**: AdamW (lr=1e-4, weight_decay=0.1)
|
| 273 |
+
- **Batch Size**: 4 per device, gradient accumulation=4 (effective batch=16)
|
| 274 |
+
- **Epochs**: 2.5
|
| 275 |
- **Scheduler**: Linear warmup (10% of steps)
|
| 276 |
- **Mixed Precision**: bfloat16
|
| 277 |
- **Gradient Checkpointing**: Enabled
|
|
|
|
| 292 |
aspp_hidden_dim = 256 # Internal dimension (vs 576 model hidden_size)
|
| 293 |
aspp_num_steps = 4 # Evolution steps for ASPP
|
| 294 |
aspp_dropout = 0.2 # Regularization
|
| 295 |
+
aspp_num_neighbors = 8 # Number of semantic neighbors for graph propagation (default: 8)
|
| 296 |
hybrid_layer_indices = None # All 30 layers
|
| 297 |
```
|
| 298 |
|
| 299 |
+
**Graph Propagation Neighbors:**
|
| 300 |
+
- Default: 8-directional grid (↖↑↗←→↙↓↘)
|
| 301 |
+
- Configurable: Can use fewer neighbors (e.g., 4-directional: ↑↓←→)
|
| 302 |
+
- The neighbor offsets are defined in a buffer, and only the first `num_neighbors` are used
|
| 303 |
+
- Learnable adjacency weights adapt importance of each direction during training
|
| 304 |
+
|
| 305 |
## Model Creation from Base Asterisk
|
| 306 |
|
| 307 |
```python
|
|
|
|
| 319 |
config.pi_flow_scale = 1.0
|
| 320 |
config.pi_flow_use_gate = True
|
| 321 |
|
| 322 |
+
# Optional: Configure ASPP graph propagation neighbors
|
| 323 |
+
# config.aspp_num_neighbors = 8 # Default: 8 (full 8-directional grid)
|
| 324 |
+
# config.aspp_num_neighbors = 4 # Alternative: 4 (cardinal directions only)
|
| 325 |
+
|
| 326 |
# Create model with π-flow
|
| 327 |
model = AsteriskForCausalLM(config)
|
| 328 |
|
|
|
|
| 374 |
|
| 375 |
## Implementation Details
|
| 376 |
|
| 377 |
+
### ASPP Graph Propagation Configuration
|
| 378 |
+
|
| 379 |
+
The ASPP operator supports configurable neighbor connectivity:
|
| 380 |
+
|
| 381 |
+
```python
|
| 382 |
+
class ASPPOperator(nn.Module):
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
hidden_size: int,
|
| 386 |
+
aspp_hidden_dim: Optional[int] = None,
|
| 387 |
+
num_steps: int = 2,
|
| 388 |
+
dropout: float = 0.1,
|
| 389 |
+
num_neighbors: int = 8 # Configurable: 4, 8, etc.
|
| 390 |
+
):
|
| 391 |
+
# Neighbor offsets buffer (8 directions hardcoded)
|
| 392 |
+
self.register_buffer('neighbor_offsets', torch.tensor([
|
| 393 |
+
[-1, -1], [-1, 0], [-1, 1], # ↖ ↑ ↗
|
| 394 |
+
[0, -1], [0, 1], # ← →
|
| 395 |
+
[1, -1], [1, 0], [1, 1] # ↙ ↓ ↘
|
| 396 |
+
]))
|
| 397 |
+
|
| 398 |
+
# Only first num_neighbors are used
|
| 399 |
+
self.num_neighbors = num_neighbors
|
| 400 |
+
|
| 401 |
+
# Learnable adjacency weights
|
| 402 |
+
self.adjacency_weights = nn.Parameter(torch.randn(num_neighbors) * 0.1)
|
| 403 |
+
```
|
| 404 |
+
|
| 405 |
+
**Usage:**
|
| 406 |
+
- `num_neighbors=8`: Full 8-directional connectivity (default)
|
| 407 |
+
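- `num_neighbors=4`: Intended as cardinal directions only (↑↓←→); note that because only the first `num_neighbors` entries of the offsets buffer are used, and the buffer is ordered ↖ ↑ ↗ ← → ↙ ↓ ↘, the four offsets actually selected are ↖ ↑ ↗ ←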
|
| 408 |
+
- The buffer always contains 8 offsets, but only the first `num_neighbors` are used
|
| 409 |
+
- Adjacency weights are softmax-normalized for stability
|
| 410 |
+
|
| 411 |
### Return Type Handling
|
| 412 |
|
| 413 |
Critical for Transformers compatibility:
|
|
|
|
| 460 |
|
| 461 |
## Known Issues & Solutions
|
| 462 |
|
| 463 |
+
### 1. Neighbor Count Configuration (Fixed in Latest Version)
|
| 464 |
+
|
| 465 |
+
**Previous issue**: ASPP operator hardcoded 8 neighbors in reshape operations, causing errors when using different neighbor counts.
|
| 466 |
+
|
| 467 |
+
**Solution**: Updated implementation to use `self.num_neighbors` dynamically:
|
| 468 |
+
```python
|
| 469 |
+
# Fixed: Dynamic neighbor count
|
| 470 |
+
for offset in self.neighbor_offsets[:self.num_neighbors]: # Only use first N
|
| 471 |
+
# ...gather neighbors...
|
| 472 |
+
|
| 473 |
+
# Fixed: Dynamic reshape
|
| 474 |
+
aggregated_neighbors = torch.sum(
|
| 475 |
+
neighbors * adj_weights.view(1, 1, 1, self.num_neighbors, 1),
|
| 476 |
+
dim=3
|
| 477 |
+
)
|
| 478 |
+
```
|
| 479 |
+
|
| 480 |
+
Now supports any neighbor count (4, 8, etc.) without modification.
|
| 481 |
+
|
| 482 |
+
### 2. Return Type Errors
|
| 483 |
|
| 484 |
**Issue**: `AttributeError: 'tuple' object has no attribute 'dtype'`
|
| 485 |
|
| 486 |
**Solution**: `HybridASPPAttentionLayer.forward()` must return `torch.Tensor` only, not tuple. This matches the `LlamaDecoderLayer` API in transformers 4.57.6.
|
| 487 |
|
| 488 |
+
### 3. π-Flow in All Layers vs Final Layer
|
| 489 |
|
| 490 |
**Initial approach**: π-flow only in final layer (limited expressiveness)
|
| 491 |
|
| 492 |
**Current approach**: π-flow in all 30 hybrid layers for maximum refinement capability.
|
| 493 |
|
| 494 |
+
### 4. Training Stability
|
| 495 |
|
| 496 |
π-Flow can cause instability with high learning rates. Use:
|
| 497 |
+
- Lower learning rate (1e-4 recommended for stability)
|
| 498 |
- Gradient clipping (max_norm=1.0)
|
| 499 |
- Conservative initial flow scale (0.2-1.0)
|
| 500 |
|
chat_template.jinja
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
|
| 2 |
-
You are a helpful AI assistant named
|
| 3 |
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
|
| 4 |
' + message['content'] + '<|im_end|>' + '
|
| 5 |
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
|
|
|
|
| 1 |
{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
|
| 2 |
+
You are a helpful AI assistant named Asterisk, trained by NoesisLab<|im_end|>
|
| 3 |
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
|
| 4 |
' + message['content'] + '<|im_end|>' + '
|
| 5 |
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
|
config.json
CHANGED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
"architectures": [
|
| 3 |
"AsteriskForCausalLM"
|
| 4 |
],
|
| 5 |
-
"aspp_dropout": 0.
|
| 6 |
"aspp_hidden_dim": 256,
|
|
|
|
| 7 |
"aspp_num_steps": 4,
|
| 8 |
"attention_bias": false,
|
| 9 |
"attention_dropout": 0.0,
|
|
|
|
| 2 |
"architectures": [
|
| 3 |
"AsteriskForCausalLM"
|
| 4 |
],
|
| 5 |
+
"aspp_dropout": 0.1,
|
| 6 |
"aspp_hidden_dim": 256,
|
| 7 |
+
"aspp_num_neighbors": 8,
|
| 8 |
"aspp_num_steps": 4,
|
| 9 |
"attention_bias": false,
|
| 10 |
"attention_dropout": 0.0,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:41550e1413295a2b5e02127758543c905650e9c834c6b53e382238f4021f3668
|
| 3 |
+
size 396858360
|
training_args.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 6353
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:79668c78f13b1c865f88ffab2a80bc10c893e3262c29261ee8d447ec47b717d3
|
| 3 |
size 6353
|