File size: 6,057 Bytes
e53f10b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | """
Stage 1: Expert selection from routing dumps.
Two metrics (MoE-sparse-friendly):
1. Selection Frequency Diff:
Δfreq(l, e) = P(e ∈ TopK | S_plan) - P(e ∈ TopK | S_exec)
2. Log Ratio (sensitive to sparse distributions):
LR(l, e) = log((freq_plan + eps) / (freq_exec + eps))
And a cross-dimensional metric:
3. Cross-dim Contrast:
Δfreq_cross(l, e) = P(e | S_plan) - P(e | S_mon)
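This identifies experts that separate planning tokens from monitoring tokens
directly, rather than each merely differing from execution.
Top-K experts are ranked by a combined score: the rank-normalized Δfreq and
log-ratio metrics (each measured against S_exec) are averaged. The cross-dim
contrast is reported but does not enter the ranking.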
"""
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from configs.model import MODEL_CONFIG
def compute_selection_frequency(
    topk_ids_by_layer: Dict[int, torch.Tensor],  # layer_id -> (S, top_k) int16
    token_indices: List[int],
    num_experts: int,
    num_layers: Optional[int] = None,
) -> np.ndarray:
    """
    Compute P(expert e is in topK at layer l) over the given token_indices.

    Args:
        topk_ids_by_layer: layer_id -> routed expert ids. Rows are selected
            with token_indices and then flattened, so each value is expected
            to be (S, top_k) as the inline annotation states (dtype int16 per
            that annotation -- confirm against your dumps).
        token_indices: Row indices (into every layer tensor) of the tokens to
            count. Any sequence accepted by tensor indexing works (list,
            NumPy array, tensor).
        num_experts: Number of experts per layer (output width).
        num_layers: Number of layers (output height). Defaults to
            MODEL_CONFIG["num_layers"] when omitted.

    Returns:
        (num_layers, num_experts) float32 array. Entry [l, e] is the number of
        times expert e occurs in the selected rows of layer l divided by the
        number of selected tokens, i.e. a frequency in [0, 1] when each row
        lists distinct experts. Layers absent from topk_ids_by_layer stay 0;
        an all-zero array is returned when token_indices is empty.

    Raises:
        ValueError: if a selected expert id lies outside [0, num_experts).
    """
    if num_layers is None:
        num_layers = MODEL_CONFIG["num_layers"]
    freq = np.zeros((num_layers, num_experts), dtype=np.float32)

    # len() instead of truthiness so NumPy arrays / tensors of indices also work.
    n = len(token_indices)
    if n == 0:
        return freq

    for li, topk_ids in topk_ids_by_layer.items():
        # Gather rows in torch, then move to CPU: .numpy() refuses CUDA tensors.
        flat = topk_ids[token_indices].cpu().numpy().astype(np.int64).ravel()
        # bincount rejects negatives and grows past num_experts for large ids,
        # which would surface as an opaque broadcast error below.
        if flat.size and (flat.min() < 0 or flat.max() >= num_experts):
            raise ValueError(
                f"layer {li}: expert ids must lie in [0, {num_experts}), "
                f"got range [{flat.min()}, {flat.max()}]"
            )
        counts = np.bincount(flat, minlength=num_experts)
        freq[li] = counts.astype(np.float32) / n
    return freq
def compute_gating_weight(
    topk_ids_by_layer: Dict[int, torch.Tensor],
    topk_gates_by_layer: Dict[int, torch.Tensor],
    token_indices: List[int],
    num_experts: int,
    num_layers: Optional[int] = None,
) -> np.ndarray:
    """
    Compute E[gating weight of expert e at layer l] over token_indices
    (conditional on e being in topK).

    Args:
        topk_ids_by_layer: layer_id -> routed expert ids; presumably
            (S, top_k), row-aligned with the gate tensor -- confirm.
        topk_gates_by_layer: layer_id -> gate values paired element-wise with
            the ids. Must contain every layer present in topk_ids_by_layer.
        token_indices: Row indices of the tokens to aggregate over.
        num_experts: Number of experts per layer (output width).
        num_layers: Number of layers (output height). Defaults to
            MODEL_CONFIG["num_layers"] when omitted.

    Returns:
        (num_layers, num_experts) float32 array. Entry [l, e] is the mean gate
        value of expert e over the selected rows in which it appears, or 0.0
        if it never appears (or the layer / token set is empty).

    Raises:
        KeyError: if a layer in topk_ids_by_layer has no gate tensor.
        ValueError: if an expert id lies outside [0, num_experts), or ids and
            gates differ in size.
    """
    if num_layers is None:
        num_layers = MODEL_CONFIG["num_layers"]
    avg_weight = np.zeros((num_layers, num_experts), dtype=np.float32)
    if len(token_indices) == 0:
        return avg_weight

    for li, layer_ids in topk_ids_by_layer.items():
        ids = layer_ids[token_indices].cpu().numpy().astype(np.int64).ravel()
        # Convert to float32 in torch first: .numpy() rejects bfloat16 and CUDA tensors.
        gates = (
            topk_gates_by_layer[li][token_indices]
            .detach()
            .to(torch.float32)
            .cpu()
            .numpy()
            .ravel()
        )
        # Negative ids would otherwise be rejected by bincount with a vague
        # message; oversized ids would overflow the row.
        if ids.size and (ids.min() < 0 or ids.max() >= num_experts):
            raise ValueError(
                f"layer {li}: expert ids must lie in [0, {num_experts}), "
                f"got range [{ids.min()}, {ids.max()}]"
            )
        # Vectorized per-expert gate sums and occurrence counts.
        weight_sum = np.bincount(ids, weights=gates, minlength=num_experts)
        count = np.bincount(ids, minlength=num_experts)
        avg_weight[li] = np.where(count > 0, weight_sum / np.maximum(count, 1), 0.0)
    return avg_weight
def rank_experts_global(score_matrix: np.ndarray, top_k: int) -> List[Tuple[int, int]]:
    """
    Rank (layer, expert) cells of score_matrix globally, highest score first.

    Args:
        score_matrix: 2-D (L, E) array of scores; row = layer, column = expert.
        top_k: Maximum number of pairs to return. Values <= 0 yield [];
            values larger than L * E yield every cell.

    Returns:
        Up to top_k (layer_id, expert_id) tuples of Python ints in descending
        score order. Equal scores are ordered row-major (lower layer first,
        then lower expert id), so results are deterministic.
    """
    # Unpacking .shape also enforces that the input is 2-D.
    _, num_experts = score_matrix.shape
    flat = score_matrix.ravel()
    # Stable sort on negated scores gives descending order with ties kept in
    # index order; the default quicksort leaves tie order unspecified.
    # Clamp top_k so a negative value does not slice from the end.
    order = np.argsort(-flat, kind="stable")[: max(top_k, 0)]
    return [(int(i // num_experts), int(i % num_experts)) for i in order]
def select_top_experts(
    routing_data: Dict,  # loaded from routing shards: see below
    plan_tis: List[int],
    mon_tis: List[int],
    exec_tis: List[int],
    top_k: int = 32,
    eps: float = 1e-4,
) -> Dict:
    """
    Compute all 3 metrics and return structured results.

    routing_data format:
    {
        "topk_ids": {layer_id: concatenated (N_total, top_k) tensor across all CoTs},
        "topk_gates": {layer_id: concatenated (N_total, top_k) tensor},
        "sample_boundaries": [...]  # cumulative token offsets
    }
    Only "topk_ids" is read by this function.
    token_indices (plan_tis / mon_tis / exec_tis) are GLOBAL indices into the
    concatenated tensors.

    Args:
        routing_data: Routing dump in the format above.
        plan_tis: Token indices of the planning set (S_plan).
        mon_tis: Token indices of the monitoring set (S_mon).
        exec_tis: Token indices of the execution set (S_exec).
        top_k: Number of (layer, expert) pairs returned per ranking.
        eps: Smoothing constant added to both frequencies inside the log
            ratio, so an expert absent from one set still gets a finite value.

    Returns:
        {
            "freq_plan": (L, E),
            "freq_mon": (L, E),
            "freq_exec": (L, E),
            "delta_plan_vs_exec": (L, E),
            "delta_mon_vs_exec": (L, E),
            "logratio_plan_vs_exec": (L, E),
            "logratio_mon_vs_exec": (L, E),
            "delta_plan_vs_mon": (L, E),  # cross-dim contrast (not used for ranking)
            "top_experts_planning": [(l, e), ...] top_k based on combined score,
            "top_experts_monitoring": [(l, e), ...],
        }

    The combined score for each dimension is the mean of the rank-normalized
    Δfreq and log-ratio matrices (both measured against S_exec); tied values
    share their average rank.
    """

    def rank_norm(mat: np.ndarray) -> np.ndarray:
        # Normalized average rank: each value maps to r / N, where r is its
        # 0-based rank and tied values share the mean of their positions.
        # Ties are pervasive here (an expert never chosen by either token set
        # has delta == 0 and log-ratio == log(eps / eps) == 0); a plain
        # argsort(argsort(...)) would rank such ties by array position.
        flat = mat.ravel()
        _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
        group_start = np.cumsum(counts) - counts
        avg_rank = group_start + (counts - 1) / 2.0
        ranks = avg_rank[inverse.ravel()] / flat.size
        return ranks.astype(np.float32).reshape(mat.shape)

    num_experts = MODEL_CONFIG["num_experts"]
    topk_ids = routing_data["topk_ids"]

    # Frequency distributions per token set.
    freq_plan = compute_selection_frequency(topk_ids, plan_tis, num_experts)
    freq_mon = compute_selection_frequency(topk_ids, mon_tis, num_experts)
    freq_exec = compute_selection_frequency(topk_ids, exec_tis, num_experts)

    # Differentials.
    delta_plan_exec = freq_plan - freq_exec
    delta_mon_exec = freq_mon - freq_exec
    delta_plan_mon = freq_plan - freq_mon

    # Smoothed log ratios.
    lr_plan_exec = np.log((freq_plan + eps) / (freq_exec + eps))
    lr_mon_exec = np.log((freq_mon + eps) / (freq_exec + eps))

    # Combined score: rank-normalize both metrics, then average.
    combined_plan = 0.5 * rank_norm(delta_plan_exec) + 0.5 * rank_norm(lr_plan_exec)
    combined_mon = 0.5 * rank_norm(delta_mon_exec) + 0.5 * rank_norm(lr_mon_exec)

    top_plan = rank_experts_global(combined_plan, top_k)
    top_mon = rank_experts_global(combined_mon, top_k)

    return {
        "freq_plan": freq_plan,
        "freq_mon": freq_mon,
        "freq_exec": freq_exec,
        "delta_plan_vs_exec": delta_plan_exec,
        "delta_mon_vs_exec": delta_mon_exec,
        "logratio_plan_vs_exec": lr_plan_exec,
        "logratio_mon_vs_exec": lr_mon_exec,
        "delta_plan_vs_mon": delta_plan_mon,
        "top_experts_planning": top_plan,
        "top_experts_monitoring": top_mon,
    }
def get_target_layers(top_experts: List[Tuple[int, int]]) -> List[int]:
    """Collect the distinct layer ids from (layer, expert) pairs, in ascending order."""
    unique_layers = {layer_id for layer_id, _expert_id in top_experts}
    return sorted(unique_layers)
|