# HiMoE / visualizer.py
# AGofficial's picture
# Upload 4 files
# 5404f1c verified
import os
import json
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
# Monospace fonts tried in order. The first entry is the macOS path the
# script originally relied on; the bare file names are resolved by Pillow
# against the platform's font directories.
_FONT_CANDIDATES = (
    "/System/Library/Fonts/Supplemental/Courier New.ttf",
    "DejaVuSansMono.ttf",
    "LiberationMono-Regular.ttf",
    "cour.ttf",
)


def _routing_color(moe_id, exp_id):
    """Return the RGB background colour for a (MoE block, expert) pair.

    The block index selects the hue in 60-degree steps (wrapping at 360, so
    block 6 shares block 0's hue). The expert index selects the lightness:
    30% plus 7% per expert, i.e. 30%-79% for experts 0-7. Saturation is
    fixed at 0.7.
    """
    import colorsys  # local import: not among the file's top-level imports

    hue = (moe_id * 60) % 360
    # Clamp so an unexpectedly large or negative expert id still produces
    # channel values inside 0-255 rather than an out-of-range tuple.
    lightness = min(max(30 + exp_id * 7, 0), 100)
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness / 100, 0.7)
    return (int(r * 255), int(g * 255), int(b * 255))


def _count_rows(chars, cols):
    """Return the number of grid rows to reserve for *chars* at *cols* per row.

    A newline starts a new row, and so does filling a row completely. This
    can reserve a row more than the drawing loop ends up using (e.g. a
    newline right after a full row), which only adds blank space.
    """
    rows = 1
    col = 0
    for char in chars:
        if char == "\n":
            col = 0
            rows += 1
            continue
        col += 1
        if col >= cols:
            col = 0
            rows += 1
    return rows


def _load_font(size):
    """Return the first loadable monospace font, else Pillow's default font."""
    for candidate in _FONT_CANDIDATES:
        try:
            return ImageFont.truetype(candidate, size)
        except (OSError, ImportError):
            # OSError: font file not found/readable. ImportError: Pillow
            # built without FreeType. Either way, try the next option.
            continue
    return ImageFont.load_default()


def visualize_routing(
    model_dir: str = "model",
    output_file: str = "himoe_visual.png",
    layer_idx: int = 0,
    cols: int = 60,
) -> None:
    """Render per-character routing decisions as a colour-coded text grid.

    Reads ``sample.txt`` and ``routing_log.json`` from *model_dir*, draws one
    cell per character whose background encodes the routing entry for that
    character, adds a legend on the right and saves the image.

    Each routing-log entry is read as ``entry["moe"][layer_idx][0]`` and
    ``entry["exp"][layer_idx][0]``.
    NOTE(review): presumably these are the top-1 MoE block and expert chosen
    at that layer -- confirm against the code that writes the log.

    Args:
        model_dir: Directory containing the two input files.
        output_file: Path of the PNG to write.
        layer_idx: Index of the layer whose routing is visualised.
        cols: Number of character cells per row before wrapping.

    Raises:
        ValueError: If *cols* is smaller than 1.

    If either input file is missing, an error is printed and the function
    returns without writing anything.
    """
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")

    sample_file = os.path.join(model_dir, "sample.txt")
    routing_file = os.path.join(model_dir, "routing_log.json")

    if not (os.path.exists(sample_file) and os.path.exists(routing_file)):
        print(f"Error: Required files missing in {model_dir}")
        return

    # Decode explicitly so the character count (and therefore the alignment
    # with the routing log) does not depend on the platform locale.
    with open(sample_file, "r", encoding="utf-8", errors="replace") as f:
        text = f.read()
    with open(routing_file, "r", encoding="utf-8") as f:
        routing_log = json.load(f)

    chars = list(text)
    if len(chars) > len(routing_log):
        # More characters than log entries: drop the first character, which
        # the original author labels the "context character".
        chars = chars[1:]

    # Truncate both sequences to the common length so they pair up 1:1.
    n = min(len(chars), len(routing_log))
    chars = chars[:n]
    routing_log = routing_log[:n]

    # --- Setup Visuals ---
    char_w, char_h = 24, 36  # Larger for zoom
    margin = 50
    legend_w = 300

    total_rows = _count_rows(chars, cols)
    img_w = cols * char_w + margin * 2 + legend_w
    # The 1000 px floor keeps enough height for the legend on short samples.
    img_h = max(total_rows * char_h + margin * 3, 1000)

    img = Image.new("RGB", (img_w, img_h), (20, 20, 25))
    draw = ImageDraw.Draw(img)
    font = _load_font(22)

    # --- Draw Text ---
    row, col = 0, 0
    for char, entry in zip(chars, routing_log):
        # Handle newline or wrap
        if char == "\n" or col >= cols:
            row += 1
            col = 0
            if char == "\n":
                continue  # Skip drawing the newline char itself

        x = margin + col * char_w
        y = margin + row * char_h

        moe_id = entry["moe"][layer_idx][0]
        exp_id = entry["exp"][layer_idx][0]
        bg_color = _routing_color(moe_id, exp_id)

        draw.rectangle([x, y, x + char_w - 1, y + char_h - 1], fill=bg_color)

        if not char.isspace():
            # Perceived brightness decides between white and black glyphs.
            luminance = (
                bg_color[0] * 0.299 + bg_color[1] * 0.587 + bg_color[2] * 0.114
            )
            text_color = (255, 255, 255) if luminance < 128 else (0, 0, 0)
            draw.text((x + 4, y + 4), char, fill=text_color, font=font)

        col += 1

    # --- Draw Legend ---
    # Block titles use the mid-range expert shade (expert index 4).
    moe_colors = [_routing_color(i, 4) for i in range(6)]

    lx = margin + cols * char_w + 40
    ly = margin
    draw.text((lx, ly), "HiMoE Routing Legend", fill=(255, 255, 255), font=font)
    ly += 40

    for mi in range(6):
        draw.text((lx, ly), f"MoE Block {mi+1}", fill=moe_colors[mi], font=font)
        ly += 25
        # Show a few expert shades
        for ei in (0, 3, 7):
            ex = lx + 20
            swatch = _routing_color(mi, ei)
            draw.rectangle([ex, ly, ex + 15, ly + 15], fill=swatch)
            draw.text((ex + 25, ly - 2), f"Exp {ei+1}", fill=(200, 200, 200), font=font)
            ly += 20
        ly += 10

    img.save(output_file)
    print(f"Visualization saved to {output_file}")
# Script entry point: render the visualization only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    visualize_routing()