| """ |
| Stage 1: Expert selection from routing dumps. |
| |
| Two metrics (MoE-sparse-friendly): |
| |
| 1. Selection Frequency Diff: |
| Δfreq(l, e) = P(e ∈ TopK | S_plan) - P(e ∈ TopK | S_exec) |
| |
| 2. Log Ratio (sensitive to sparse distributions): |
| LR(l, e) = log((freq_plan + eps) / (freq_exec + eps)) |
| |
| And a cross-dimensional metric: |
| |
| 3. Cross-dim Contrast: |
| Δfreq_cross(l, e) = P(e | S_plan) - P(e | S_mon) |
| |
| This identifies experts that TRULY separate the two dimensions. |
| |
| Top-K experts: ranked by a combined score. |
| """ |
| import numpy as np |
| import torch |
| from typing import Dict, List, Tuple |
| from configs.model import MODEL_CONFIG |
|
|
|
|
def compute_selection_frequency(
    topk_ids_by_layer: Dict[int, torch.Tensor],
    token_indices: List[int],
    num_experts: int,
) -> np.ndarray:
    """
    Compute P(expert e is in topK at layer l) over the given token_indices.

    Args:
        topk_ids_by_layer: {layer_id: (N, top_k) integer tensor of selected
            expert ids}. Layer ids must be < MODEL_CONFIG["num_layers"].
        token_indices: global row indices into each layer tensor; may be empty.
        num_experts: number of expert columns in the output.

    Returns:
        (num_layers, num_experts) float32 array of frequencies in [0, 1].
        Layers absent from topk_ids_by_layer remain all-zero.
    """
    num_layers = MODEL_CONFIG["num_layers"]
    freq = np.zeros((num_layers, num_experts), dtype=np.float32)
    if not token_indices:
        return freq
    n = len(token_indices)

    for li, topk_ids in topk_ids_by_layer.items():
        # .cpu() so dumps captured on GPU don't break .numpy().
        sel = topk_ids[token_indices].cpu().numpy().astype(np.int64)
        counts = np.bincount(sel.ravel(), minlength=num_experts).astype(np.float32)
        # bincount returns a LONGER array than num_experts if a stray id
        # >= num_experts sneaks into the dump; truncate so the row
        # assignment never raises a shape-mismatch error.
        freq[li] = counts[:num_experts] / n

    return freq
|
|
|
|
def compute_gating_weight(
    topk_ids_by_layer: Dict[int, torch.Tensor],
    topk_gates_by_layer: Dict[int, torch.Tensor],
    token_indices: List[int],
    num_experts: int,
) -> np.ndarray:
    """
    Compute E[gating weight of expert e at layer l] over token_indices
    (conditional on e being in topK).

    Args:
        topk_ids_by_layer: {layer_id: (N, top_k) integer tensor of expert ids}.
        topk_gates_by_layer: {layer_id: (N, top_k) float tensor of gate values,
            aligned element-wise with topk_ids_by_layer}.
        token_indices: global row indices to aggregate over; may be empty.
        num_experts: number of expert columns in the output.

    Returns:
        (num_layers, num_experts) float32 array; cells where the expert was
        never selected are 0.
    """
    num_layers = MODEL_CONFIG["num_layers"]
    weight_sum = np.zeros((num_layers, num_experts), dtype=np.float32)
    count = np.zeros((num_layers, num_experts), dtype=np.float32)

    for li in topk_ids_by_layer:
        ids = topk_ids_by_layer[li][token_indices].cpu().numpy().astype(np.int64).ravel()
        gates = topk_gates_by_layer[li][token_indices].cpu().numpy().astype(np.float32).ravel()
        # Vectorized scatter-add replaces the former O(n * top_k) Python
        # double loop; np.add.at accumulates correctly for repeated ids.
        np.add.at(weight_sum[li], ids, gates)
        np.add.at(count[li], ids, 1.0)

    # Conditional mean; np.maximum keeps the division warning-free where
    # count == 0 (those cells are masked to 0 by np.where anyway).
    avg_weight = np.where(count > 0, weight_sum / np.maximum(count, 1), 0.0)
    return avg_weight
|
|
|
|
def rank_experts_global(score_matrix: np.ndarray, top_k: int) -> List[Tuple[int, int]]:
    """
    Rank (layer, expert) cells of score_matrix (L, E) globally by score.

    Returns the top_k highest-scoring pairs as [(layer_id, expert_id), ...],
    in descending score order.
    """
    _, num_experts = score_matrix.shape
    scores = score_matrix.reshape(-1)
    best = np.argsort(-scores)[:top_k]
    # Flat index i maps back to (layer, expert) = (i // E, i % E).
    return [divmod(int(idx), num_experts) for idx in best]
|
|
|
|
def select_top_experts(
    routing_data: Dict,
    plan_tis: List[int],
    mon_tis: List[int],
    exec_tis: List[int],
    top_k: int = 32,
    eps: float = 1e-4,
) -> Dict:
    """
    Compute all 3 metrics and return structured results.

    Args:
        routing_data: {
            "topk_ids": {layer_id: concatenated (N_total, top_k) tensor across all CoTs},
            "topk_gates": {layer_id: concatenated (N_total, top_k) tensor},
            "sample_boundaries": [...]  # cumulative token offsets
        }
        plan_tis / mon_tis / exec_tis: GLOBAL token indices into the
            concatenated tensors for the planning / monitoring / execution sets.
        top_k: number of (layer, expert) pairs to return per dimension.
        eps: smoothing constant so the log ratio stays finite on cells with
            zero selection frequency.

    Returns:
        {
            "freq_plan": (L, E),
            "freq_mon": (L, E),
            "freq_exec": (L, E),
            "delta_plan_vs_exec": (L, E),
            "delta_mon_vs_exec": (L, E),
            "logratio_plan_vs_exec": (L, E),
            "logratio_mon_vs_exec": (L, E),
            "delta_plan_vs_mon": (L, E),  # cross-dim contrast
            "top_experts_planning": [(l, e), ...],   # top_k by combined score
            "top_experts_monitoring": [(l, e), ...],
        }
    """
    num_experts = MODEL_CONFIG["num_experts"]
    # NOTE(review): only "topk_ids" is needed here; the former unused
    # `topk_gates = routing_data["topk_gates"]` local has been removed.
    topk_ids = routing_data["topk_ids"]

    # Per-condition selection frequencies, each (L, E).
    freq_plan = compute_selection_frequency(topk_ids, plan_tis, num_experts)
    freq_mon = compute_selection_frequency(topk_ids, mon_tis, num_experts)
    freq_exec = compute_selection_frequency(topk_ids, exec_tis, num_experts)

    # Metric 1: raw frequency differences (plus the cross-dim contrast).
    delta_plan_exec = freq_plan - freq_exec
    delta_mon_exec = freq_mon - freq_exec
    delta_plan_mon = freq_plan - freq_mon

    # Metric 2: smoothed log ratios — sensitive when frequencies are tiny.
    lr_plan_exec = np.log((freq_plan + eps) / (freq_exec + eps))
    lr_mon_exec = np.log((freq_mon + eps) / (freq_exec + eps))

    def rank_norm(mat: np.ndarray) -> np.ndarray:
        # Map each score to rank/N in [0, 1) so the two metrics are on a
        # common scale before averaging (double argsort yields 0..N-1 ranks).
        flat = mat.flatten()
        ranks = np.argsort(np.argsort(flat)).astype(np.float32) / len(flat)
        return ranks.reshape(mat.shape)

    # Combined score: equal-weight average of the two rank-normalized metrics.
    combined_plan = 0.5 * rank_norm(delta_plan_exec) + 0.5 * rank_norm(lr_plan_exec)
    combined_mon = 0.5 * rank_norm(delta_mon_exec) + 0.5 * rank_norm(lr_mon_exec)

    top_plan = rank_experts_global(combined_plan, top_k)
    top_mon = rank_experts_global(combined_mon, top_k)

    return {
        "freq_plan": freq_plan,
        "freq_mon": freq_mon,
        "freq_exec": freq_exec,
        "delta_plan_vs_exec": delta_plan_exec,
        "delta_mon_vs_exec": delta_mon_exec,
        "logratio_plan_vs_exec": lr_plan_exec,
        "logratio_mon_vs_exec": lr_mon_exec,
        "delta_plan_vs_mon": delta_plan_mon,
        "top_experts_planning": top_plan,
        "top_experts_monitoring": top_mon,
    }
|
|
|
|
def get_target_layers(top_experts: List[Tuple[int, int]]) -> List[int]:
    """Return the sorted, de-duplicated layer ids from (layer, expert) pairs."""
    unique_layers = {layer for layer, _expert in top_experts}
    return sorted(unique_layers)
|
|