# v2/src/expert_select.py
# JulianHJR: "Upload folder using huggingface_hub" (commit e53f10b, verified)
"""
Stage 1: Expert selection from routing dumps.
Two metrics (MoE-sparse-friendly):
1. Selection Frequency Diff:
Δfreq(l, e) = P(e ∈ TopK | S_plan) - P(e ∈ TopK | S_exec)
2. Log Ratio (sensitive to sparse distributions):
LR(l, e) = log((freq_plan + eps) / (freq_exec + eps))
And a cross-dimensional metric:
3. Cross-dim Contrast:
Δfreq_cross(l, e) = P(e | S_plan) - P(e | S_mon)
This identifies experts that TRULY separate the two dimensions.
Top-K experts: ranked by a combined score.
"""
import numpy as np
import torch
from typing import Dict, List, Tuple
from configs.model import MODEL_CONFIG
def compute_selection_frequency(
    topk_ids_by_layer: Dict[int, torch.Tensor],  # layer_id -> (S, top_k) int16
    token_indices: List[int],
    num_experts: int,
) -> np.ndarray:
    """
    Per-layer frequency with which each expert appears in the top-K routing set.

    Only the token rows listed in ``token_indices`` are considered. For every
    layer present in ``topk_ids_by_layer`` the number of occurrences of each
    expert id is divided by ``len(token_indices)``.

    Returns a (num_layers, num_experts) float32 array, where ``num_layers``
    comes from ``MODEL_CONFIG``. Layers missing from the dict, and the whole
    array when ``token_indices`` is empty, stay at zero. Values lie in [0, 1]
    provided an expert id occurs at most once per token row.
    """
    layer_total = MODEL_CONFIG["num_layers"]
    frequencies = np.zeros((layer_total, num_experts), dtype=np.float32)

    if token_indices:
        token_total = len(token_indices)
        for layer_id, layer_topk in topk_ids_by_layer.items():
            # Rows for the requested tokens, widened so bincount accepts them.
            picked = layer_topk[token_indices].numpy().astype(np.int64)  # (n, top_k)
            hits = np.bincount(picked.reshape(-1), minlength=num_experts)
            # Convert to float32 before dividing so the division is done in
            # float32, matching the dtype of the output array.
            frequencies[layer_id] = hits.astype(np.float32) / token_total

    return frequencies
def compute_gating_weight(
    topk_ids_by_layer: Dict[int, torch.Tensor],
    topk_gates_by_layer: Dict[int, torch.Tensor],
    token_indices: List[int],
    num_experts: int,
) -> np.ndarray:
    """
    Compute E[gating weight of expert e at layer l] over token_indices
    (conditional on e being in topK).

    For each layer in ``topk_ids_by_layer`` the gate values of the selected
    token rows are summed per expert id and divided by the number of times
    that expert was selected. Experts never selected (and layers absent from
    the dict) get 0.

    Sums are accumulated in float64 and counts in int64, so the result stays
    accurate for large token sets; the returned array is float32.

    Returns (num_layers, num_experts) float32.

    Raises:
        KeyError: a layer in ``topk_ids_by_layer`` has no entry in
            ``topk_gates_by_layer``.
        IndexError: an expert id is >= ``num_experts``.
        ValueError: an expert id is negative, or the id and gate tensors of a
            layer do not have the same number of elements (from np.bincount).
    """
    num_layers = MODEL_CONFIG["num_layers"]
    avg_weight = np.zeros((num_layers, num_experts), dtype=np.float32)
    # Same early exit as compute_selection_frequency: nothing to average.
    if not token_indices:
        return avg_weight

    for li, layer_ids in topk_ids_by_layer.items():
        ids = layer_ids[token_indices].numpy().astype(np.int64).reshape(-1)
        gates = (
            topk_gates_by_layer[li][token_indices]
            .numpy()
            .astype(np.float32)
            .reshape(-1)
        )

        # Vectorised replacement for a per-(token, slot) Python loop.
        counts = np.bincount(ids, minlength=num_experts)
        if counts.shape[0] != num_experts:
            # bincount grows its output past minlength only when it sees an id
            # beyond the requested range.
            raise IndexError(
                f"expert id {int(ids.max())} out of range for "
                f"num_experts={num_experts} at layer {li}"
            )
        sums = np.bincount(ids, weights=gates, minlength=num_experts)

        selected = counts > 0
        # Divide only where an expert was actually selected; the rest stay 0.
        avg_weight[li, selected] = sums[selected] / counts[selected]

    return avg_weight
def rank_experts_global(score_matrix: np.ndarray, top_k: int) -> List[Tuple[int, int]]:
    """
    Rank experts globally across all layers by score_matrix (L, E).

    Returns up to ``top_k`` pairs [(layer_id, expert_id), ...] in descending
    score order. Ties are broken deterministically in favour of the lower
    flat index (lower layer first, then lower expert id), so repeated runs
    select the same experts. A non-positive ``top_k`` yields an empty list;
    a ``top_k`` larger than L * E yields every entry.
    """
    if top_k <= 0:
        return []

    # Unpacking also enforces that the input is 2-D.
    _, experts_per_layer = score_matrix.shape
    flat = score_matrix.reshape(-1)

    # A stable sort on the negated scores gives descending order with
    # reproducible tie-breaking; the default sort kind leaves tie order
    # unspecified.
    order = np.argsort(-flat, kind="stable")[:top_k]

    return [
        (int(idx // experts_per_layer), int(idx % experts_per_layer))
        for idx in order
    ]
def _rank_normalize(mat: np.ndarray) -> np.ndarray:
    """Map each entry of ``mat`` to its zero-based rank / mat.size, averaging ranks over ties."""
    flat = mat.reshape(-1)
    _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)
    # Ordinal rank of the first member of each tie group ...
    group_start = np.cumsum(counts) - counts
    # ... shifted to the middle of the group, so equal values get equal ranks.
    # For a group of size 1 this is exactly the ordinal rank.
    group_rank = group_start + (counts - 1) / 2.0
    ranks = group_rank[inverse].astype(np.float32) / len(flat)
    return ranks.reshape(mat.shape)


def select_top_experts(
    routing_data: Dict,  # loaded from routing shards: see below
    plan_tis: List[int],
    mon_tis: List[int],
    exec_tis: List[int],
    top_k: int = 32,
    eps: float = 1e-4,
) -> Dict:
    """
    Compute all 3 metrics and return structured results.

    routing_data format:
        {
            "topk_ids": {layer_id: concatenated (N_total, top_k) tensor across all CoTs},
            "topk_gates": {layer_id: concatenated (N_total, top_k) tensor},
            "sample_boundaries": [...]  # cumulative token offsets
        }
    Only "topk_ids" is read by this function.

    token_indices (plan_tis / mon_tis / exec_tis) are GLOBAL indices into the
    concatenated tensor.

    The combined score for a dimension is the mean of the rank-normalized
    frequency difference and the rank-normalized log ratio, both taken against
    the execution tokens. Equal metric values receive equal (averaged) ranks,
    so large groups of tied entries (e.g. experts never selected in either
    token set) do not get arbitrary score differences.

    Returns:
        {
            "freq_plan": (L, E),
            "freq_mon": (L, E),
            "freq_exec": (L, E),
            "delta_plan_vs_exec": (L, E),
            "delta_mon_vs_exec": (L, E),
            "logratio_plan_vs_exec": (L, E),
            "logratio_mon_vs_exec": (L, E),
            "delta_plan_vs_mon": (L, E),  # cross-dim contrast
            "top_experts_planning": [(l, e), ...] top_k based on combined score,
            "top_experts_monitoring": [(l, e), ...],
        }
    """
    num_experts = MODEL_CONFIG["num_experts"]
    topk_ids = routing_data["topk_ids"]

    # Compute frequency distributions
    freq_plan = compute_selection_frequency(topk_ids, plan_tis, num_experts)
    freq_mon = compute_selection_frequency(topk_ids, mon_tis, num_experts)
    freq_exec = compute_selection_frequency(topk_ids, exec_tis, num_experts)

    # Differentials
    delta_plan_exec = freq_plan - freq_exec
    delta_mon_exec = freq_mon - freq_exec
    delta_plan_mon = freq_plan - freq_mon

    # Log ratios; eps keeps the ratio finite when a frequency is zero.
    lr_plan_exec = np.log((freq_plan + eps) / (freq_exec + eps))
    lr_mon_exec = np.log((freq_mon + eps) / (freq_exec + eps))

    # Combined score: rank-normalize both metrics, average
    combined_plan = (
        0.5 * _rank_normalize(delta_plan_exec) + 0.5 * _rank_normalize(lr_plan_exec)
    )
    combined_mon = (
        0.5 * _rank_normalize(delta_mon_exec) + 0.5 * _rank_normalize(lr_mon_exec)
    )

    top_plan = rank_experts_global(combined_plan, top_k)
    top_mon = rank_experts_global(combined_mon, top_k)

    return {
        "freq_plan": freq_plan,
        "freq_mon": freq_mon,
        "freq_exec": freq_exec,
        "delta_plan_vs_exec": delta_plan_exec,
        "delta_mon_vs_exec": delta_mon_exec,
        "logratio_plan_vs_exec": lr_plan_exec,
        "logratio_mon_vs_exec": lr_mon_exec,
        "delta_plan_vs_mon": delta_plan_mon,
        "top_experts_planning": top_plan,
        "top_experts_monitoring": top_mon,
    }
def get_target_layers(top_experts: List[Tuple[int, int]]) -> List[int]:
    """Return the distinct layer ids found in (layer, expert) pairs, in ascending order."""
    distinct_layers = {layer_id for layer_id, _expert_id in top_experts}
    ordered = list(distinct_layers)
    ordered.sort()
    return ordered