File size: 3,722 Bytes
5404f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import colorsys
import json
import os

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

def _expert_color(moe_id, exp_id):
    """Return an (R, G, B) colour for one (MoE block, expert) routing choice.

    Hue encodes the MoE block, stepping 60 degrees per block, so there are
    6 distinct hues before they repeat. Lightness encodes the expert: 30%
    for expert 0, plus 7% per expert. Lightness is clamped to [0, 100]% so
    large expert indices cannot push the channels outside 0-255.
    """
    hue = (moe_id * 60) % 360
    lightness = min(max(30 + exp_id * 7, 0), 100)
    r, g, b = colorsys.hls_to_rgb(hue / 360, lightness / 100, 0.7)
    return (int(r * 255), int(g * 255), int(b * 255))


def _layout_cells(chars, cols):
    """Assign a (row, col) grid cell to every character.

    A newline moves to the start of the next row and gets no cell (None).
    Other characters fill rows left to right, wrapping before the cell that
    would exceed ``cols``.

    Returns:
        ``(cells, n_rows)``. ``cells[i]`` is the ``(row, col)`` of
        ``chars[i]``, or None for a newline. ``n_rows`` is the number of rows
        the layout spans (at least 1).
    """
    cells = []
    row = col = 0
    for ch in chars:
        if ch == "\n" or col >= cols:
            row += 1
            col = 0
            if ch == "\n":
                cells.append(None)
                continue
        cells.append((row, col))
        col += 1
    return cells, row + 1


def _draw_legend(draw, font, x, y, n_moe=6, shown_experts=(0, 3, 7)):
    """Draw the colour legend with its top-left corner at (x, y).

    Each MoE block gets a heading drawn in its expert-4 colour, followed by
    swatches for a few sample expert shades.
    """
    draw.text((x, y), "HiMoE Routing Legend", fill=(255, 255, 255), font=font)
    y += 40
    for mi in range(n_moe):
        draw.text((x, y), f"MoE Block {mi+1}", fill=_expert_color(mi, 4), font=font)
        y += 25
        swatch_x = x + 20
        for ei in shown_experts:
            draw.rectangle([swatch_x, y, swatch_x + 15, y + 15], fill=_expert_color(mi, ei))
            draw.text((swatch_x + 25, y - 2), f"Exp {ei+1}", fill=(200, 200, 200), font=font)
            y += 20
        y += 10


def visualize_routing(*, model_dir="model", output_file="himoe_visual.png", layer_idx=0):
    """Render per-character routing decisions as a colour-coded PNG.

    Reads ``sample.txt`` and ``routing_log.json`` from ``model_dir``. It pairs
    the i-th character with the i-th routing record, dropping the first
    character when the text is longer than the log. Each character's cell is
    painted with the colour of the MoE block and expert it was routed to at
    ``layer_idx``. A legend is added and the image is saved to
    ``output_file``.

    Each record is read as ``record["moe"][layer_idx][0]`` and
    ``record["exp"][layer_idx][0]``. NOTE(review): index 0 is presumably the
    top-1 choice; confirm against the code that writes the log.

    Args:
        model_dir: Directory containing ``sample.txt`` and ``routing_log.json``.
        output_file: Path of the PNG to write.
        layer_idx: Which layer's routing decision to visualise.

    Prints an error and returns early if either input file is missing.
    """
    sample_file = os.path.join(model_dir, "sample.txt")
    routing_file = os.path.join(model_dir, "routing_log.json")

    if not os.path.exists(sample_file) or not os.path.exists(routing_file):
        print(f"Error: Required files missing in {model_dir}")
        return

    # Explicit UTF-8 so non-ASCII samples decode identically on every
    # platform instead of depending on the locale's default encoding.
    with open(sample_file, "r", encoding="utf-8") as f:
        text = f.read()
    with open(routing_file, "r", encoding="utf-8") as f:
        routing_log = json.load(f)

    chars = list(text)
    if len(chars) > len(routing_log):
        # Skip the context character
        chars = chars[1:]
    n = min(len(chars), len(routing_log))
    chars = chars[:n]
    routing_log = routing_log[:n]

    # --- Setup Visuals ---
    char_w, char_h = 24, 36  # Larger for zoom
    cols = 60
    margin = 50
    legend_w = 300

    # Size the canvas from the same layout used for drawing, so the height
    # matches what is actually drawn (no spurious blank rows).
    cells, total_rows = _layout_cells(chars, cols)

    img_w = cols * char_w + margin * 2 + legend_w
    img_h = max(total_rows * char_h + margin * 3, 1000)
    img = Image.new("RGB", (img_w, img_h), (20, 20, 25))
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/System/Library/Fonts/Supplemental/Courier New.ttf", 22)
    except (OSError, ImportError):
        # Font file absent (e.g. not macOS) or FreeType support unavailable.
        font = ImageFont.load_default()

    # --- Draw Text ---
    for char, cell, record in zip(chars, cells, routing_log):
        if cell is None:  # newline: occupies no cell, nothing to draw
            continue
        row, col = cell
        x = margin + col * char_w
        y = margin + row * char_h

        bg_color = _expert_color(record["moe"][layer_idx][0], record["exp"][layer_idx][0])
        draw.rectangle([x, y, x + char_w - 1, y + char_h - 1], fill=bg_color)

        if not char.isspace():
            # Rec. 601 luma: white text on dark cells, black on light ones.
            luma = bg_color[0] * 0.299 + bg_color[1] * 0.587 + bg_color[2] * 0.114
            text_color = (255, 255, 255) if luma < 128 else (0, 0, 0)
            draw.text((x + 4, y + 4), char, fill=text_color, font=font)

    # --- Draw Legend ---
    _draw_legend(draw, font, margin + cols * char_w + 40, margin)

    img.save(output_file)
    print(f"Visualization saved to {output_file}")

if __name__ == "__main__":
    # Script entry point: render the routing visualization with the defaults.
    visualize_routing()