import colorsys
import json
import os

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

# Font used for the grid and the legend; a Pillow default is used when it is
# not available (the path below only exists on macOS).
_FONT_PATH = "/System/Library/Fonts/Supplemental/Courier New.ttf"
_FONT_SIZE = 22

# Number of MoE blocks listed in the legend and the expert indices sampled
# for each of them.
_LEGEND_MOE_COUNT = 6
_LEGEND_EXPERTS = (0, 3, 7)


def _get_color(moe_id, exp_id):
    """Return the RGB fill colour for expert ``exp_id`` of MoE block ``moe_id``.

    The MoE block picks the hue (60 degree steps, wrapping at 360) and the
    expert index picks the lightness (30% plus 7% per expert). Saturation is
    fixed at 0.7.
    """
    hue = (moe_id * 60) % 360
    # Cap at 100%: beyond that hls_to_rgb returns components outside [0, 1].
    lightness = min(30 + exp_id * 7, 100)
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness / 100, 0.7)
    return (int(r * 255), int(g * 255), int(b * 255))


def _count_rows(chars, cols):
    """Return the number of grid rows reserved for ``chars`` at ``cols`` per row.

    A newline always starts a new row, and so does filling a row completely.
    This can reserve one row more than the drawing loop ends up using (for
    example when a newline directly follows a full row), so it is an upper
    bound on the rows actually drawn.
    """
    rows = 1
    col = 0
    for ch in chars:
        if ch == "\n":
            col = 0
            rows += 1
            continue
        col += 1
        if col >= cols:
            col = 0
            rows += 1
    return rows


def _load_font():
    """Load the preferred TrueType font, falling back to Pillow's default."""
    try:
        return ImageFont.truetype(_FONT_PATH, _FONT_SIZE)
    except OSError:
        # Raised by Pillow when the font file cannot be read.
        return ImageFont.load_default()


def visualize_routing():
    """Render the layer-0 routing of a text sample as a colour-coded PNG.

    Reads ``model/sample.txt`` and ``model/routing_log.json``, draws every
    character of the sample in a 60-column grid whose cell colour encodes the
    first MoE block / expert listed for that character in layer 0, adds a
    legend, and writes the result to ``himoe_visual.png``.

    Prints an error and returns without writing anything when either input
    file is missing. Takes no arguments and returns ``None``.
    """
    model_dir = "model"
    sample_file = os.path.join(model_dir, "sample.txt")
    routing_file = os.path.join(model_dir, "routing_log.json")
    output_file = "himoe_visual.png"

    if not (os.path.exists(sample_file) and os.path.exists(routing_file)):
        print(f"Error: Required files missing in {model_dir}")
        return

    # Explicit encoding so the character <-> routing-record alignment does
    # not depend on the platform's default locale.
    # NOTE(review): assumes sample.txt was written as UTF-8 - confirm.
    with open(sample_file, "r", encoding="utf-8") as f:
        text = f.read()
    with open(routing_file, "r", encoding="utf-8") as f:
        routing_log = json.load(f)

    # Use Layer 0 for visualization by default
    layer_idx = 0

    chars = list(text)
    if len(chars) > len(routing_log):
        # Skip the context character
        chars = chars[1:]

    # Keep only the prefix for which both a character and a record exist.
    n = min(len(chars), len(routing_log))
    chars = chars[:n]
    routing_log = routing_log[:n]

    # --- Setup Visuals ---
    char_w, char_h = 24, 36  # Larger for zoom
    cols = 60
    margin = 50
    legend_w = 300

    # Rows depend on both the text length AND the newlines it contains.
    total_rows = _count_rows(chars, cols)

    img_w = cols * char_w + margin * 2 + legend_w
    img_h = max(total_rows * char_h + margin * 3, 1000)

    img = Image.new("RGB", (img_w, img_h), (20, 20, 25))
    draw = ImageDraw.Draw(img)
    font = _load_font()

    # --- Color Mapping ---
    # One representative colour per MoE block (expert 4) for legend headings.
    moe_colors = [_get_color(i, 4) for i in range(_LEGEND_MOE_COUNT)]

    # --- Draw Text ---
    row, col = 0, 0
    for char, record in zip(chars, routing_log):
        # Handle newline or wrap
        if char == "\n" or col >= cols:
            row += 1
            col = 0
            if char == "\n":
                continue  # Skip drawing the newline char itself

        x = margin + col * char_w
        y = margin + row * char_h

        moe_id = record["moe"][layer_idx][0]
        exp_id = record["exp"][layer_idx][0]
        bg_color = _get_color(moe_id, exp_id)

        draw.rectangle([x, y, x + char_w - 1, y + char_h - 1], fill=bg_color)

        if not char.isspace():
            # Weighted RGB brightness decides between white and black text
            # so the glyph stays readable on its background.
            brightness = (
                bg_color[0] * 0.299 + bg_color[1] * 0.587 + bg_color[2] * 0.114
            )
            text_color = (255, 255, 255) if brightness < 128 else (0, 0, 0)
            draw.text((x + 4, y + 4), char, fill=text_color, font=font)

        col += 1

    # --- Draw Legend ---
    lx = margin + cols * char_w + 40
    ly = margin
    draw.text((lx, ly), "HiMoE Routing Legend", fill=(255, 255, 255), font=font)
    ly += 40

    for mi in range(_LEGEND_MOE_COUNT):
        draw.text((lx, ly), f"MoE Block {mi+1}", fill=moe_colors[mi], font=font)
        ly += 25
        # Show a few expert shades
        ex = lx + 20
        for ei in _LEGEND_EXPERTS:
            swatch = _get_color(mi, ei)
            draw.rectangle([ex, ly, ex + 15, ly + 15], fill=swatch)
            draw.text((ex + 25, ly - 2), f"Exp {ei+1}", fill=(200, 200, 200), font=font)
            ly += 20
        ly += 10

    img.save(output_file)
    print(f"Visualization saved to {output_file}")


if __name__ == "__main__":
    visualize_routing()