Spaces:
Running
Running
| import argparse | |
| import json | |
| import os | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def load_telemetry(path: str) -> list[dict]:
    """Read the telemetry JSON file at *path* and return its parsed content.

    The file is opened explicitly as UTF-8 (the standard JSON encoding)
    instead of relying on the platform's locale default, so telemetry that
    contains non-ASCII text loads the same way on every system.

    Args:
        path: Location of the telemetry JSON file.

    Returns:
        Whatever ``json.load`` produces for the file; the plotting functions
        in this script expect a list of per-step dictionaries.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def plot_training_loss(data: list[dict], output_dir: str):
    """Draw the total, LM, balance and z-loss curves against the step number.

    Every record in *data* must provide the keys ``step``, ``total_loss``,
    ``lm_loss``, ``balance_loss`` and ``z_loss``. The figure is written to
    ``loss_curves.png`` inside *output_dir*, then shown and closed.
    """
    # (record key, legend label, line width, opacity) for each curve, in draw order.
    curve_specs = (
        ("total_loss", "Total Loss", 1.5, 0.9),
        ("lm_loss", "LM Loss", 1.5, 0.9),
        ("balance_loss", "Balance Loss", 1, 0.7),
        ("z_loss", "Z-Loss", 1, 0.7),
    )

    # Pull every series out before creating the figure, so a missing key
    # raises before any matplotlib object exists.
    x_values = [record["step"] for record in data]
    curves = [
        ([record[key] for record in data], label, width, opacity)
        for key, label, width, opacity in curve_specs
    ]

    figure, axis = plt.subplots(figsize=(10, 5))
    for y_values, label, width, opacity in curves:
        axis.plot(x_values, y_values, label=label, linewidth=width, alpha=opacity)

    axis.set(xlabel="Step", ylabel="Loss", title="Training Loss Curves")
    axis.legend()
    axis.grid(True, alpha=0.3)

    figure.tight_layout()
    destination = os.path.join(output_dir, "loss_curves.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def plot_perplexity(data: list[dict], output_dir: str):
    """Draw perplexity against the step number for records that report it.

    Only records containing a ``perplexity`` key are used. When none do, a
    notice is printed and nothing is drawn. Otherwise the figure is written
    to ``perplexity.png`` inside *output_dir*, then shown and closed.
    """
    reported = [record for record in data if "perplexity" in record]
    if len(reported) == 0:
        print("No perplexity data found, skipping plot.")
        return

    step_axis = tuple(record["step"] for record in reported)
    ppl_values = tuple(record["perplexity"] for record in reported)

    figure, axis = plt.subplots(figsize=(10, 5))
    axis.plot(
        step_axis,
        ppl_values,
        marker="o",
        markersize=4,
        linewidth=1.5,
        color="tab:red",
    )
    axis.set(xlabel="Step", ylabel="Perplexity", title="Perplexity Over Training")
    axis.grid(True, alpha=0.3)

    # Switch to a log axis once the values span more than one order of magnitude.
    highest = max(ppl_values)
    lowest = min(ppl_values)
    if highest > 10 * lowest:
        axis.set_yscale("log")

    figure.tight_layout()
    destination = os.path.join(output_dir, "perplexity.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def plot_expert_heatmap(data: list[dict], output_dir: str):
    """Render per-expert counts as a heatmap with experts on y and steps on x.

    Records whose ``expert_counts`` entry is missing or empty are ignored.
    When none remain, a notice is printed and nothing is drawn. Otherwise the
    figure is written to ``expert_heatmap.png`` inside *output_dir*, then
    shown and closed.
    """
    usable = [record for record in data if record.get("expert_counts")]
    if len(usable) == 0:
        print("No expert count data found, skipping heatmap.")
        return

    step_labels = [record["step"] for record in usable]
    count_rows = [record["expert_counts"] for record in usable]
    expert_total = len(count_rows[0])
    grid = np.array(count_rows)  # one row per record, one column per expert

    figure, axis = plt.subplots(figsize=(10, 6))
    # Transpose so that each expert becomes a horizontal band.
    image = axis.imshow(
        np.transpose(grid),
        aspect="auto",
        cmap="YlOrRd",
        interpolation="nearest",
    )
    axis.set_xlabel("Step Index")
    axis.set_ylabel("Expert")
    expert_ids = range(expert_total)
    axis.set_yticks(expert_ids)
    axis.set_yticklabels([f"E{idx}" for idx in expert_ids])
    axis.set_title("Expert Utilization Heatmap")

    # Label at most ten evenly spaced columns with their real step numbers.
    column_total = len(step_labels)
    tick_total = min(10, column_total)
    tick_columns = np.linspace(0, column_total - 1, tick_total, dtype=int)
    axis.set_xticks(tick_columns)
    axis.set_xticklabels([str(step_labels[col]) for col in tick_columns])

    figure.colorbar(image, ax=axis, label="Token Count")
    figure.tight_layout()
    destination = os.path.join(output_dir, "expert_heatmap.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def plot_null_routing(data: list[dict], output_dir: str, target_ratio: float = 0.5):
    """Line plot of the null routing ratio over time, with a target reference.

    Each record's ``null_ratio`` is treated as a fraction and drawn as a
    percentage on a fixed 0-100 axis. A dashed horizontal line marks the
    target. The figure is written to ``null_routing.png`` inside
    *output_dir*, then shown and closed.

    Args:
        data: Telemetry records; each must provide ``step`` and ``null_ratio``.
        output_dir: Directory that receives ``null_routing.png``.
        target_ratio: Target null-routing fraction for the reference line.
            The default of 0.5 reproduces the previously hard-coded 50% line
            and its "Target ρ=0.5" legend entry.

    Raises:
        KeyError: If a record lacks ``step`` or ``null_ratio``.
    """
    steps = [d["step"] for d in data]
    null_ratios = [d["null_ratio"] for d in data]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, [r * 100 for r in null_ratios], linewidth=1.5, color="tab:purple")
    # The reference line and its legend text both follow the configured target,
    # so they cannot drift apart when a different target is used.
    ax.axhline(
        y=target_ratio * 100,
        color="gray",
        linestyle="--",
        alpha=0.5,
        label=f"Target ρ={target_ratio:g}",
    )
    ax.set_xlabel("Step")
    ax.set_ylabel("Null Routing (%)")
    ax.set_title("Null Routing Ratio Over Training")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 100)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "null_routing.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_expert_token_distribution(data: list[dict], output_dir: str):
    """Draw a bar chart of each expert's counts summed over all records.

    Records whose ``expert_counts`` entry is missing or empty are ignored.
    When none remain, a notice is printed and nothing is drawn. Otherwise the
    figure is written to ``expert_token_dist.png`` inside *output_dir*, then
    shown and closed.
    """
    count_rows = [record["expert_counts"] for record in data if record.get("expert_counts")]
    if len(count_rows) == 0:
        print("No expert count data found, skipping bar chart.")
        return

    # Collapse the record axis, leaving one running total per expert.
    per_expert = np.sum(np.array(count_rows), axis=0)
    expert_ids = range(len(per_expert))

    figure, axis = plt.subplots(figsize=(8, 5))
    rectangles = axis.bar(expert_ids, per_expert, color="tab:blue", alpha=0.8)
    axis.set(
        xlabel="Expert",
        ylabel="Total Tokens Processed",
        title="Per-Expert Token Distribution (Cumulative)",
    )
    axis.set_xticks(expert_ids)
    axis.set_xticklabels([f"E{idx}" for idx in expert_ids])
    axis.grid(True, alpha=0.3, axis="y")

    # Write each total just above the centre of its bar.
    for rect, amount in zip(rectangles, per_expert):
        centre_x = rect.get_x() + rect.get_width() / 2
        axis.text(
            centre_x,
            rect.get_height(),
            f"{amount:,.0f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    figure.tight_layout()
    destination = os.path.join(output_dir, "expert_token_dist.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def plot_zero_compute(data: list[dict], output_dir: str):
    """Draw the zero-compute token ratio, as a percentage, against the step.

    Every record in *data* must provide ``step`` and ``zero_compute_ratio``.
    The figure is written to ``zero_compute.png`` inside *output_dir*, then
    shown and closed.
    """
    x_values = [record["step"] for record in data]
    fractions = [record["zero_compute_ratio"] for record in data]
    percentages = [fraction * 100 for fraction in fractions]

    figure, axis = plt.subplots(figsize=(10, 5))
    axis.plot(x_values, percentages, linewidth=1.5, color="tab:orange")
    axis.set(
        xlabel="Step",
        ylabel="Zero-Compute Tokens (%)",
        title="Zero-Compute Token Ratio Over Training",
    )
    axis.grid(True, alpha=0.3)
    # Anchor the y axis at zero; the upper limit is left to matplotlib.
    axis.set_ylim(bottom=0)

    figure.tight_layout()
    destination = os.path.join(output_dir, "zero_compute.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def plot_gate_weights(data: list[dict], output_dir: str):
    """Draw gate weights per expert: the last snapshot and the full history.

    Records whose ``gate_weights`` entry is missing or empty are ignored.
    When none remain, a notice is printed and nothing is drawn. Otherwise a
    two-panel figure (bar chart of the last usable record on the left, one
    line per expert over all usable records on the right) is written to
    ``gate_weights.png`` inside *output_dir*, then shown and closed.
    """
    usable = [record for record in data if record.get("gate_weights")]
    if len(usable) == 0:
        print("No gate weight data found, skipping plot.")
        return

    step_axis = [record["step"] for record in usable]
    weight_rows = [record["gate_weights"] for record in usable]
    expert_ids = range(len(weight_rows[0]))
    tick_names = [f"E{idx}" for idx in expert_ids]
    latest_step = step_axis[-1]
    latest_weights = weight_rows[-1]

    figure, panels = plt.subplots(1, 2, figsize=(14, 5))
    snapshot_panel, history_panel = panels

    # Left panel: bar chart of the most recent usable record.
    snapshot_panel.bar(expert_ids, latest_weights, color="tab:green", alpha=0.8)
    snapshot_panel.set_xlabel("Expert")
    snapshot_panel.set_ylabel("Average Gate Weight")
    snapshot_panel.set_title(f"Gate Weights at Step {latest_step}")
    snapshot_panel.set_xticks(expert_ids)
    snapshot_panel.set_xticklabels(tick_names)
    snapshot_panel.grid(True, alpha=0.3, axis="y")

    # Right panel: one line per expert across all usable records.
    history = np.array(weight_rows)  # one row per record, one column per expert
    for idx in expert_ids:
        history_panel.plot(
            step_axis,
            history[:, idx],
            label=f"E{idx}",
            linewidth=1,
            alpha=0.8,
        )
    history_panel.set_xlabel("Step")
    history_panel.set_ylabel("Average Gate Weight")
    history_panel.set_title("Gate Weights Over Training")
    history_panel.legend(fontsize=7, ncol=2)
    history_panel.grid(True, alpha=0.3)

    figure.tight_layout()
    destination = os.path.join(output_dir, "gate_weights.png")
    figure.savefig(destination, dpi=150)
    plt.show()
    plt.close(figure)
def main():
    """Parse the command line, load the telemetry file and produce every plot.

    Options: ``--input`` (telemetry JSON path, default ``telemetry.json``)
    and ``--output_dir`` (PNG destination, default ``.``; created if it
    does not exist). Progress is reported on standard output.
    """
    cli = argparse.ArgumentParser(description="Visualize MoE Null Expert telemetry")
    cli.add_argument(
        "--input",
        type=str,
        default="telemetry.json",
        help="Path to telemetry JSON file",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Directory to save plot PNGs",
    )
    options = cli.parse_args()
    source_path = options.input
    target_dir = options.output_dir

    print(f"Loading telemetry from {source_path}...")
    records = load_telemetry(source_path)
    print(f"Loaded {len(records)} steps of telemetry data.")

    os.makedirs(target_dir, exist_ok=True)

    # (progress message, plotting function) pairs, run in this order.
    stages = (
        ("Plotting training loss curves...", plot_training_loss),
        ("Plotting perplexity...", plot_perplexity),
        ("Plotting expert utilization heatmap...", plot_expert_heatmap),
        ("Plotting null routing ratio...", plot_null_routing),
        ("Plotting expert token distribution...", plot_expert_token_distribution),
        ("Plotting zero-compute token ratio...", plot_zero_compute),
        ("Plotting gate weight distributions...", plot_gate_weights),
    )
    for announcement, render in stages:
        print(announcement)
        render(records, target_dir)

    print(f"\nAll plots saved to {target_dir}/")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()