import colorsys
import json
import os

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
|
|
_LEGEND_MOE_BLOCKS = 6  # number of MoE blocks listed in the legend
_LEGEND_EXPERTS = (0, 3, 7)  # sample expert indices shown per block in the legend
_LEGEND_TITLE_EXPERT = 4  # expert index whose colour is used for each block's title
_FONT_PATH = "/System/Library/Fonts/Supplemental/Courier New.ttf"


def _routing_color(moe_id, exp_id):
    """Return an RGB colour for a (MoE block, expert) routing decision.

    Hue encodes the MoE block (60 degrees apart, wrapping every 6 blocks) and
    lightness encodes the expert index (30% + 7% per expert). Lightness is
    clamped to [0%, 100%] so large expert indices cannot push ``colorsys``
    outside the valid 0-255 channel range.

    Args:
        moe_id: Index of the MoE block the character was routed to.
        exp_id: Index of the expert within that block.

    Returns:
        An ``(r, g, b)`` tuple of ints in ``[0, 255]``.
    """
    hue = (moe_id * 60) % 360
    lightness = min(max(30 + exp_id * 7, 0), 100)
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness / 100, 0.7)
    return (int(r * 255), int(g * 255), int(b * 255))


def _layout_cells(chars, cols):
    """Assign each character a (row, col) grid cell, wrapping at ``cols``.

    A newline occupies no cell (its entry is ``None``) and starts a new row.
    A line longer than ``cols`` wraps onto the next row. Positions and the
    row count come from the same pass, so the canvas size always matches
    where characters are drawn.

    Args:
        chars: Sequence of single-character strings.
        cols: Number of cells per row (must be >= 1).

    Returns:
        ``(cells, n_rows)``. ``cells[i]`` is ``(row, col)`` for a drawable
        character and ``None`` for ``"\\n"``. ``n_rows`` is the number of rows
        touched (at least 1).
    """
    cells = []
    row, col = 0, 0
    for ch in chars:
        if ch == "\n" or col >= cols:
            row += 1
            col = 0
        if ch == "\n":
            cells.append(None)
            continue
        cells.append((row, col))
        col += 1
    return cells, row + 1


def _draw_legend(draw, font, lx, ly):
    """Draw the MoE-block / expert colour legend with its top-left at (lx, ly)."""
    draw.text((lx, ly), "HiMoE Routing Legend", fill=(255, 255, 255), font=font)
    ly += 40
    for mi in range(_LEGEND_MOE_BLOCKS):
        title_color = _routing_color(mi, _LEGEND_TITLE_EXPERT)
        draw.text((lx, ly), f"MoE Block {mi+1}", fill=title_color, font=font)
        ly += 25
        ex = lx + 20
        for ei in _LEGEND_EXPERTS:
            draw.rectangle([ex, ly, ex + 15, ly + 15], fill=_routing_color(mi, ei))
            draw.text((ex + 25, ly - 2), f"Exp {ei+1}", fill=(200, 200, 200), font=font)
            ly += 20
        ly += 10


def visualize_routing(model_dir="model", output_file="himoe_visual.png", layer_idx=0, cols=60):
    """Render per-character routing decisions as a colour-coded PNG.

    Reads ``sample.txt`` and ``routing_log.json`` from *model_dir*. Each
    character cell is coloured by the (MoE block, expert) recorded for it at
    *layer_idx*. A legend is drawn on the right and the image is saved to
    *output_file*.

    Args:
        model_dir: Directory containing ``sample.txt`` and ``routing_log.json``.
        output_file: Path of the PNG to write.
        layer_idx: Index into each record's ``"moe"`` / ``"exp"`` lists.
        cols: Characters per row before wrapping.

    Raises:
        ValueError: If ``cols`` is less than 1.

    If either input file is missing, an error is printed and the function
    returns without writing anything.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    sample_file = os.path.join(model_dir, "sample.txt")
    routing_file = os.path.join(model_dir, "routing_log.json")

    if not os.path.exists(sample_file) or not os.path.exists(routing_file):
        print(f"Error: Required files missing in {model_dir}")
        return

    with open(sample_file, "r", encoding="utf-8") as f:
        text = f.read()
    with open(routing_file, "r", encoding="utf-8") as f:
        routing_log = json.load(f)

    chars = list(text)
    # NOTE(review): when the text is longer than the log, the first character
    # is dropped -- presumably a seed/prompt token with no routing record.
    # Confirm against the code that writes routing_log.json.
    if len(chars) > len(routing_log):
        chars = chars[1:]
    n = min(len(chars), len(routing_log))
    chars = chars[:n]
    routing_log = routing_log[:n]

    char_w, char_h = 24, 36
    margin = 50
    legend_w = 300

    cells, n_rows = _layout_cells(chars, cols)

    img_w = cols * char_w + margin * 2 + legend_w
    img_h = max(n_rows * char_h + margin * 3, 1000)
    img = Image.new("RGB", (img_w, img_h), (20, 20, 25))
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype(_FONT_PATH, 22)
    except (OSError, ImportError):
        # Font file absent (e.g. not macOS) or FreeType unavailable:
        # fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    # Newlines still consume a routing record but are not drawn.
    for ch, cell, record in zip(chars, cells, routing_log):
        if cell is None:
            continue
        row, col = cell
        x = margin + col * char_w
        y = margin + row * char_h

        # assumes record["moe"][layer] / record["exp"][layer] are sequences and
        # the first entry is the choice to display -- TODO confirm with logger.
        moe_id = record["moe"][layer_idx][0]
        exp_id = record["exp"][layer_idx][0]

        bg_color = _routing_color(moe_id, exp_id)
        draw.rectangle([x, y, x + char_w - 1, y + char_h - 1], fill=bg_color)

        if not ch.isspace():
            # Perceived luminance picks a readable text colour for the cell.
            luminance = bg_color[0] * 0.299 + bg_color[1] * 0.587 + bg_color[2] * 0.114
            text_color = (255, 255, 255) if luminance < 128 else (0, 0, 0)
            draw.text((x + 4, y + 4), ch, fill=text_color, font=font)

    _draw_legend(draw, font, margin + cols * char_w + 40, margin)

    img.save(output_file)
    print(f"Visualization saved to {output_file}")
|
|
if __name__ == "__main__":
    # Script entry point: render the routing visualization with the
    # function's default settings.
    visualize_routing()
|
|