# Shakespeare_MoE / visualize.py
# Author: haemant
# Commit aad4104 (verified): "Upload folder using huggingface_hub"
import argparse
import json
import os
import matplotlib.pyplot as plt
import numpy as np
def load_telemetry(path: str) -> list[dict]:
    """Load the per-step telemetry records from a JSON file.

    Args:
        path: Path to a JSON file whose top-level value is a list of
            per-step record dicts.

    Returns:
        The decoded list of records.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level JSON value is not a list.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Every plotting routine iterates records as dicts; fail early and clearly
    # instead of with a confusing TypeError deep inside a plot.
    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of telemetry records in {path!r}, "
            f"got {type(data).__name__}"
        )
    return data
def plot_training_loss(data: list[dict], output_dir: str):
    """Draw the four training-loss series on one set of axes and save them.

    Each record must provide ``step``, ``total_loss``, ``lm_loss``,
    ``balance_loss`` and ``z_loss``. All series are pulled out before any
    figure is created, so a missing key raises ``KeyError`` before anything is
    drawn. The figure is written to ``loss_curves.png`` in *output_dir*, then
    displayed with ``plt.show()`` and closed.
    """
    columns = {
        key: [rec[key] for rec in data]
        for key in ("step", "total_loss", "lm_loss", "balance_loss", "z_loss")
    }
    # (series key, legend label, line width, opacity), drawn in this order so
    # the legend order and colour-cycle assignment stay fixed.
    series = (
        ("total_loss", "Total Loss", 1.5, 0.9),
        ("lm_loss", "LM Loss", 1.5, 0.9),
        ("balance_loss", "Balance Loss", 1, 0.7),
        ("z_loss", "Z-Loss", 1, 0.7),
    )

    fig, ax = plt.subplots(figsize=(10, 5))
    for key, label, width, opacity in series:
        ax.plot(columns["step"], columns[key], label=label, linewidth=width, alpha=opacity)
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Curves")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "loss_curves.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_perplexity(data: list[dict], output_dir: str):
    """Plot perplexity over training steps and save ``perplexity.png``.

    Records that lack a ``perplexity`` value, or carry JSON ``null`` for it,
    are skipped. If no usable record remains, a notice is printed and the
    function returns without drawing. The y-axis switches to a log scale when
    all values are positive and the largest exceeds 10x the smallest.

    Args:
        data: Telemetry records; each used record needs a ``step`` key.
        output_dir: Directory the PNG is written into.
    """
    # JSON null loads as None; keeping it would crash max()/min() below.
    ppl_data = [
        (d["step"], d["perplexity"]) for d in data if d.get("perplexity") is not None
    ]
    if not ppl_data:
        print("No perplexity data found, skipping plot.")
        return
    steps, ppls = zip(*ppl_data)
    lo, hi = min(ppls), max(ppls)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, ppls, marker="o", markersize=4, linewidth=1.5, color="tab:red")
    ax.set_xlabel("Step")
    ax.set_ylabel("Perplexity")
    ax.set_title("Perplexity Over Training")
    ax.grid(True, alpha=0.3)
    # A log axis is only meaningful for strictly positive data.
    if lo > 0 and hi > 10 * lo:
        ax.set_yscale("log")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "perplexity.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_expert_heatmap(data: list[dict], output_dir: str):
    """Render routed-token counts as a heatmap of logged steps against experts.

    Only records whose ``expert_counts`` is present and non-empty are used.
    If there are none, a notice is printed and nothing is drawn. Otherwise the
    image is saved as ``expert_heatmap.png`` in *output_dir*, displayed with
    ``plt.show()`` and closed.
    """
    samples = [(rec["step"], rec["expert_counts"]) for rec in data if rec.get("expert_counts")]
    if len(samples) == 0:
        print("No expert count data found, skipping heatmap.")
        return

    step_ids = [step for step, _ in samples]
    num_experts = len(samples[0][1])
    grid = np.array([counts for _, counts in samples])  # rows: logged steps, cols: experts
    expert_labels = [f"E{i}" for i in range(num_experts)]

    fig, ax = plt.subplots(figsize=(10, 6))
    # Transposed so experts run down the y-axis and time runs left to right.
    im = ax.imshow(grid.T, aspect="auto", cmap="YlOrRd", interpolation="nearest")
    ax.set_xlabel("Step Index")
    ax.set_ylabel("Expert")
    ax.set_yticks(range(num_experts))
    ax.set_yticklabels(expert_labels)
    ax.set_title("Expert Utilization Heatmap")

    # Heatmap columns are sample indices; relabel up to 10 evenly spaced
    # columns with the real step numbers.
    positions = np.linspace(0, len(step_ids) - 1, min(10, len(step_ids)), dtype=int)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(step_ids[p]) for p in positions])

    fig.colorbar(im, ax=ax, label="Token Count")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "expert_heatmap.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_null_routing(data: list[dict], output_dir: str):
    """Line plot of the null routing ratio over training, saved as ``null_routing.png``.

    Records without a non-null ``null_ratio`` are skipped. If none remain, a
    notice is printed and nothing is drawn, matching the other plotters in
    this file. Ratios are multiplied by 100 and shown on a fixed 0-100 %
    axis, which assumes they lie in [0, 1] (NOTE(review): confirm against
    the telemetry writer). The dashed 50 % reference line is hard-coded here
    and is not read from the telemetry.

    Args:
        data: Telemetry records; each used record needs a ``step`` key.
        output_dir: Directory the PNG is written into.
    """
    # Some records may omit the ratio (or log JSON null); skip them instead of
    # raising KeyError/TypeError.
    entries = [(d["step"], d["null_ratio"]) for d in data if d.get("null_ratio") is not None]
    if not entries:
        print("No null routing data found, skipping plot.")
        return
    steps, null_ratios = zip(*entries)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, [r * 100 for r in null_ratios], linewidth=1.5, color="tab:purple")
    ax.axhline(y=50, color="gray", linestyle="--", alpha=0.5, label="Target ρ=0.5")
    ax.set_xlabel("Step")
    ax.set_ylabel("Null Routing (%)")
    ax.set_title("Null Routing Ratio Over Training")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 100)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "null_routing.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_expert_token_distribution(data: list[dict], output_dir: str):
    """Bar chart of each expert's token count summed over all logged steps.

    Only records with a present, non-empty ``expert_counts`` contribute. If
    there are none, a notice is printed and nothing is drawn. Each bar is
    annotated with its comma-grouped total. The chart is saved as
    ``expert_token_dist.png`` in *output_dir*, displayed with ``plt.show()``
    and closed.
    """
    per_step = [rec["expert_counts"] for rec in data if rec.get("expert_counts")]
    if not per_step:
        print("No expert count data found, skipping bar chart.")
        return

    cumulative = np.sum(per_step, axis=0)  # one total per expert
    expert_ids = range(len(cumulative))

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(expert_ids, cumulative, color="tab:blue", alpha=0.8)
    ax.set_xlabel("Expert")
    ax.set_ylabel("Total Tokens Processed")
    ax.set_title("Per-Expert Token Distribution (Cumulative)")
    ax.set_xticks(expert_ids)
    ax.set_xticklabels([f"E{i}" for i in expert_ids])
    ax.grid(True, alpha=0.3, axis="y")

    # Write each bar's total just above its top edge, horizontally centered.
    for rect, total in zip(bars, cumulative):
        x_center = rect.get_x() + rect.get_width() / 2
        ax.text(x_center, rect.get_height(), f"{total:,.0f}", ha="center", va="bottom", fontsize=8)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "expert_token_dist.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_zero_compute(data: list[dict], output_dir: str):
    """Plot the fraction of zero-compute tokens over training.

    Records without a non-null ``zero_compute_ratio`` are skipped. If none
    remain, a notice is printed and nothing is drawn, matching the other
    plotters in this file. Ratios are multiplied by 100 for a percentage axis,
    which assumes they lie in [0, 1] (NOTE(review): confirm against the
    telemetry writer). Output: ``zero_compute.png`` in *output_dir*.

    Args:
        data: Telemetry records; each used record needs a ``step`` key.
        output_dir: Directory the PNG is written into.
    """
    # Tolerate records that omit the ratio (or log JSON null) instead of
    # raising KeyError/TypeError.
    entries = [
        (d["step"], d["zero_compute_ratio"])
        for d in data
        if d.get("zero_compute_ratio") is not None
    ]
    if not entries:
        print("No zero-compute data found, skipping plot.")
        return
    steps, zc = zip(*entries)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, [r * 100 for r in zc], linewidth=1.5, color="tab:orange")
    ax.set_xlabel("Step")
    ax.set_ylabel("Zero-Compute Tokens (%)")
    ax.set_title("Zero-Compute Token Ratio Over Training")
    ax.grid(True, alpha=0.3)
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "zero_compute.png"), dpi=150)
    plt.show()
    plt.close(fig)
def plot_gate_weights(data: list[dict], output_dir: str):
    """Two-panel view of per-expert average gate weights.

    The left panel is a bar chart of the most recent logged snapshot. The
    right panel shows one line per expert across every logged step. Only
    records with a present, non-empty ``gate_weights`` are used. If there are
    none, a notice is printed and nothing is drawn. Output: ``gate_weights.png``
    in *output_dir*, displayed with ``plt.show()`` and then closed.
    """
    snapshots = [(rec["step"], rec["gate_weights"]) for rec in data if rec.get("gate_weights")]
    if not snapshots:
        print("No gate weight data found, skipping plot.")
        return

    num_experts = len(snapshots[0][1])
    last_step, last_weights = snapshots[-1]
    tick_labels = [f"E{i}" for i in range(num_experts)]

    fig, (left, right) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: snapshot from the latest logged step.
    left.bar(range(num_experts), last_weights, color="tab:green", alpha=0.8)
    left.set_xlabel("Expert")
    left.set_ylabel("Average Gate Weight")
    left.set_title(f"Gate Weights at Step {last_step}")
    left.set_xticks(range(num_experts))
    left.set_xticklabels(tick_labels)
    left.grid(True, alpha=0.3, axis="y")

    # Right panel: trajectory of each expert's weight over the logged steps.
    history = np.array([weights for _, weights in snapshots])  # rows: steps, cols: experts
    step_axis = [step for step, _ in snapshots]
    for e in range(num_experts):
        right.plot(step_axis, history[:, e], label=f"E{e}", linewidth=1, alpha=0.8)
    right.set_xlabel("Step")
    right.set_ylabel("Average Gate Weight")
    right.set_title("Gate Weights Over Training")
    right.legend(fontsize=7, ncol=2)
    right.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "gate_weights.png"), dpi=150)
    plt.show()
    plt.close(fig)
def main():
    """Command-line entry point: load a telemetry JSON file and write every plot.

    Options: ``--input`` (default ``telemetry.json``) and ``--output_dir``
    (default ``.``, created if missing). Each plotting routine runs in a fixed
    order, and a progress line is printed before each one starts.
    """
    cli = argparse.ArgumentParser(description="Visualize MoE Null Expert telemetry")
    cli.add_argument("--input", type=str, default="telemetry.json",
                     help="Path to telemetry JSON file")
    cli.add_argument("--output_dir", type=str, default=".",
                     help="Directory to save plot PNGs")
    opts = cli.parse_args()

    print(f"Loading telemetry from {opts.input}...")
    records = load_telemetry(opts.input)
    print(f"Loaded {len(records)} steps of telemetry data.")

    out_dir = opts.output_dir
    os.makedirs(out_dir, exist_ok=True)

    # (progress message, plotting routine) pairs, executed in this order.
    pipeline = (
        ("Plotting training loss curves...", plot_training_loss),
        ("Plotting perplexity...", plot_perplexity),
        ("Plotting expert utilization heatmap...", plot_expert_heatmap),
        ("Plotting null routing ratio...", plot_null_routing),
        ("Plotting expert token distribution...", plot_expert_token_distribution),
        ("Plotting zero-compute token ratio...", plot_zero_compute),
        ("Plotting gate weight distributions...", plot_gate_weights),
    )
    for banner, render in pipeline:
        print(banner)
        render(records, out_dir)

    print(f"\nAll plots saved to {out_dir}/")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()