""" Stage 1: Expert selection from routing dumps. Two metrics (MoE-sparse-friendly): 1. Selection Frequency Diff: Δfreq(l, e) = P(e ∈ TopK | S_plan) - P(e ∈ TopK | S_exec) 2. Log Ratio (sensitive to sparse distributions): LR(l, e) = log((freq_plan + eps) / (freq_exec + eps)) And a cross-dimensional metric: 3. Cross-dim Contrast: Δfreq_cross(l, e) = P(e | S_plan) - P(e | S_mon) This identifies experts that TRULY separate the two dimensions. Top-K experts: ranked by a combined score. """ import numpy as np import torch from typing import Dict, List, Tuple from configs.model import MODEL_CONFIG def compute_selection_frequency( topk_ids_by_layer: Dict[int, torch.Tensor], # layer_id -> (S, top_k) int16 token_indices: List[int], num_experts: int, ) -> np.ndarray: """ Compute P(expert e is in topK at layer l) over the given token_indices. Returns (num_layers, num_experts) float32 array of frequencies in [0, 1]. """ num_layers = MODEL_CONFIG["num_layers"] freq = np.zeros((num_layers, num_experts), dtype=np.float32) if not token_indices: return freq n = len(token_indices) for li, topk_ids in topk_ids_by_layer.items(): # topk_ids: (S, top_k) # Select rows at token_indices sel = topk_ids[token_indices].numpy().astype(np.int64) # (n, top_k) # Count occurrences per expert flat = sel.flatten() bincount = np.bincount(flat, minlength=num_experts).astype(np.float32) freq[li] = bincount / n return freq def compute_gating_weight( topk_ids_by_layer: Dict[int, torch.Tensor], topk_gates_by_layer: Dict[int, torch.Tensor], token_indices: List[int], num_experts: int, ) -> np.ndarray: """ Compute E[gating weight of expert e at layer l] over token_indices (conditional on e being in topK). Returns (num_layers, num_experts) float32. """ num_layers = MODEL_CONFIG["num_layers"] weight_sum = np.zeros((num_layers, num_experts), dtype=np.float32) count = np.zeros((num_layers, num_experts), dtype=np.float32) for li in topk_ids_by_layer: topk_ids = topk_ids_by_layer[li][token_indices].numpy().astype(np.int64) # (n, top_k) topk_gates = topk_gates_by_layer[li][token_indices].numpy().astype(np.float32) for row_ids, row_gates in zip(topk_ids, topk_gates): for e, g in zip(row_ids, row_gates): weight_sum[li, e] += g count[li, e] += 1 avg_weight = np.where(count > 0, weight_sum / np.maximum(count, 1), 0.0) return avg_weight def rank_experts_global(score_matrix: np.ndarray, top_k: int) -> List[Tuple[int, int]]: """ Rank experts globally across all layers by score_matrix (L, E). Returns top_k [(layer_id, expert_id), ...] in descending order. """ L, E = score_matrix.shape flat = score_matrix.flatten() # Descending order top_idx = np.argsort(-flat)[:top_k] return [(int(i // E), int(i % E)) for i in top_idx] def select_top_experts( routing_data: Dict, # loaded from routing shards: see below plan_tis: List[int], mon_tis: List[int], exec_tis: List[int], top_k: int = 32, eps: float = 1e-4, ) -> Dict: """ Compute all 3 metrics and return structured results. routing_data format: { "topk_ids": {layer_id: concatenated (N_total, top_k) tensor across all CoTs}, "topk_gates": {layer_id: concatenated (N_total, top_k) tensor}, "sample_boundaries": [...] # cumulative token offsets } token_indices are GLOBAL indices into the concatenated tensor. Returns: { "freq_plan": (L, E), "freq_mon": (L, E), "freq_exec": (L, E), "delta_plan_vs_exec": (L, E), "delta_mon_vs_exec": (L, E), "logratio_plan_vs_exec": (L, E), "logratio_mon_vs_exec": (L, E), "delta_plan_vs_mon": (L, E), # cross-dim contrast "top_experts_planning": [(l, e), ...] 
top_k based on combined score, "top_experts_monitoring": [(l, e), ...], } """ num_experts = MODEL_CONFIG["num_experts"] topk_ids = routing_data["topk_ids"] topk_gates = routing_data["topk_gates"] # Compute frequency distributions freq_plan = compute_selection_frequency(topk_ids, plan_tis, num_experts) freq_mon = compute_selection_frequency(topk_ids, mon_tis, num_experts) freq_exec = compute_selection_frequency(topk_ids, exec_tis, num_experts) # Differentials delta_plan_exec = freq_plan - freq_exec delta_mon_exec = freq_mon - freq_exec delta_plan_mon = freq_plan - freq_mon # Log ratios lr_plan_exec = np.log((freq_plan + eps) / (freq_exec + eps)) lr_mon_exec = np.log((freq_mon + eps) / (freq_exec + eps)) # Combined score: rank-normalize both metrics, average def rank_norm(mat): flat = mat.flatten() ranks = np.argsort(np.argsort(flat)).astype(np.float32) / len(flat) return ranks.reshape(mat.shape) combined_plan = 0.5 * rank_norm(delta_plan_exec) + 0.5 * rank_norm(lr_plan_exec) combined_mon = 0.5 * rank_norm(delta_mon_exec) + 0.5 * rank_norm(lr_mon_exec) top_plan = rank_experts_global(combined_plan, top_k) top_mon = rank_experts_global(combined_mon, top_k) return { "freq_plan": freq_plan, "freq_mon": freq_mon, "freq_exec": freq_exec, "delta_plan_vs_exec": delta_plan_exec, "delta_mon_vs_exec": delta_mon_exec, "logratio_plan_vs_exec": lr_plan_exec, "logratio_mon_vs_exec": lr_mon_exec, "delta_plan_vs_mon": delta_plan_mon, "top_experts_planning": top_plan, "top_experts_monitoring": top_mon, } def get_target_layers(top_experts: List[Tuple[int, int]]) -> List[int]: """From a list of (layer, expert) pairs, return sorted unique layers.""" return sorted(set(l for l, _ in top_experts))
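

if __name__ == "__main__":
    # Minimal smoke test on synthetic routing data -- a sketch, not part of the
    # pipeline. It assumes MODEL_CONFIG exposes "num_layers" and "num_experts";
    # the per-token router top_k of 8, the 300-token corpus, and the disjoint
    # index ranges below are hypothetical stand-ins for real planning /
    # monitoring / execution token spans.
    L = MODEL_CONFIG["num_layers"]
    E = MODEL_CONFIG["num_experts"]
    n_tokens, router_top_k = 300, 8

    rng = np.random.default_rng(0)
    routing_data = {
        "topk_ids": {
            li: torch.tensor(
                rng.integers(0, E, size=(n_tokens, router_top_k)), dtype=torch.int16
            )
            for li in range(L)
        },
        "topk_gates": {li: torch.rand(n_tokens, router_top_k) for li in range(L)},
        "sample_boundaries": [0, n_tokens],  # single synthetic CoT
    }
    plan_tis = list(range(0, 100))
    mon_tis = list(range(100, 200))
    exec_tis = list(range(200, 300))

    results = select_top_experts(routing_data, plan_tis, mon_tis, exec_tis, top_k=8)
    print("top planning experts:", results["top_experts_planning"])
    print("target layers:", get_target_layers(results["top_experts_planning"]))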