AGofficial committed on
Commit
5404f1c
·
verified ·
1 Parent(s): f7ce35d

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +2 -0
  2. himoe_visual.png +0 -0
  3. train.py +20 -3
  4. visualizer.py +122 -0
README.md CHANGED
@@ -3,6 +3,8 @@ license: mit
3
  language:
4
  - en
5
  ---
 
 
6
  # HiMoE — Hierarchical Mixture of Experts
7
 
8
  > *A Matryoshka-inspired two-level routing architecture for efficient large-scale language modelling.*
 
3
  language:
4
  - en
5
  ---
6
+ <img src="himoe_visual.png">
7
+
8
  # HiMoE — Hierarchical Mixture of Experts
9
 
10
  > *A Matryoshka-inspired two-level routing architecture for efficient large-scale language modelling.*
himoe_visual.png ADDED
train.py CHANGED
@@ -47,7 +47,7 @@ class HiMoEConfig:
47
  num_experts: int = 8 # Level-2 choices per MoE
48
  # Training
49
  batch_size: int = 32
50
- max_iters: int = 3000
51
  eval_interval:int = 50
52
  eval_iters: int = 20
53
  lr: float = 3e-4
@@ -377,7 +377,7 @@ def load_model(model_dir: str, device: str) -> tuple:
377
  cfg.model_dir = model_dir
378
  vocab_size = meta["vocab_size"]
379
  stoi = meta["stoi"]
380
- itos = meta["itos"]
381
  step = meta["step"]
382
 
383
  model = HiMoEModel(cfg, vocab_size).to(device)
@@ -553,6 +553,23 @@ def train(cfg: HiMoEConfig, resume: bool = False):
553
  f"lr {lr_now:.2e} | "
554
  f"ETA {eta/60:.1f}m")
555
  save_model(model, cfg, vocab_size, stoi, itos, step)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  # forward + backward
558
  x, y = get_batch(train_data, cfg.block_size,
@@ -587,7 +604,7 @@ def train(cfg: HiMoEConfig, resume: bool = False):
587
  with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
588
  f.write(sample)
589
  with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
590
- json.dump(routing_log[:50], f, indent=2) # first 50 tokens
591
 
592
  print(f"\n[himoe] Sample + routing log saved to '{cfg.model_dir}/'")
593
 
 
47
  num_experts: int = 8 # Level-2 choices per MoE
48
  # Training
49
  batch_size: int = 32
50
+ max_iters: int = 750 # for testing, increase to 3000 for actual training
51
  eval_interval:int = 50
52
  eval_iters: int = 20
53
  lr: float = 3e-4
 
377
  cfg.model_dir = model_dir
378
  vocab_size = meta["vocab_size"]
379
  stoi = meta["stoi"]
380
+ itos = {int(k): v for k, v in meta["itos"].items()}
381
  step = meta["step"]
382
 
383
  model = HiMoEModel(cfg, vocab_size).to(device)
 
553
  f"lr {lr_now:.2e} | "
554
  f"ETA {eta/60:.1f}m")
555
  save_model(model, cfg, vocab_size, stoi, itos, step)
556
+
557
+ # Generate sample and save routing log periodically for visualization
558
+ model.eval()
559
+ with torch.no_grad():
560
+ # Workaround for MPS generation hangs: move to CPU for sampling
561
+ original_device = next(model.parameters()).device
562
+ model.to("cpu")
563
+ context = torch.zeros((1, 1), dtype=torch.long, device="cpu")
564
+ gen_ids, r_log = model.generate(context, max_new_tokens=400, temperature=0.8, top_k=40)
565
+ smp = "".join(itos[i] for i in gen_ids[0].tolist())
566
+ with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
567
+ f.write(smp)
568
+ with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
569
+ json.dump(r_log, f, indent=2)
570
+ model.to(original_device)
571
+ model.train()
572
+
573
 
574
  # forward + backward
575
  x, y = get_batch(train_data, cfg.block_size,
 
604
  with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
605
  f.write(sample)
606
  with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
607
+ json.dump(routing_log, f, indent=2) # save full log for visualization
608
 
609
  print(f"\n[himoe] Sample + routing log saved to '{cfg.model_dir}/'")
610
 
visualizer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import numpy as np
6
+
7
def _layout_chars(chars, cols):
    """Assign a (row, col) grid cell to every drawable character.

    A newline starts a new row and occupies no cell. A row also wraps when
    it already holds ``cols`` characters and another character arrives.

    Args:
        chars: Sequence of single-character strings.
        cols: Maximum number of characters per row (must be >= 1).

    Returns:
        A ``(cells, rows)`` pair. ``cells`` is a list of
        ``(index, row, col)`` tuples, one per non-newline character, where
        ``index`` is the position in ``chars``. ``rows`` is the number of
        rows the layout spans (always at least 1).
    """
    cells = []
    row = col = 0
    for idx, char in enumerate(chars):
        if char == "\n" or col >= cols:
            row += 1
            col = 0
        if char == "\n":
            continue
        cells.append((idx, row, col))
        col += 1
    return cells, row + 1


def _load_mono_font(size):
    """Return a monospace TrueType font, or Pillow's default as a fallback.

    The macOS Courier New path is tried first, then font names commonly
    present on Linux and Windows, so the grid does not silently fall back
    to the small bitmap font on non-macOS machines.
    """
    candidates = (
        "/System/Library/Fonts/Supplemental/Courier New.ttf",  # macOS
        "DejaVuSansMono.ttf",  # common on Linux
        "cour.ttf",  # Courier New on Windows
    )
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except (OSError, ImportError):
            # OSError: font file not found/unreadable.
            # ImportError: Pillow built without FreeType support.
            continue
    return ImageFont.load_default()


def visualize_routing(model_dir="model", output_file="himoe_visual.png",
                      layer_idx=0, cols=60):
    """Render the generated sample as a colour-coded routing grid (PNG).

    Reads ``sample.txt`` and ``routing_log.json`` from ``model_dir``, draws
    one cell per character whose background colour encodes the routing
    entry for that character, adds a legend, and saves the image.

    Each routing-log entry is read as
    ``entry["moe"][layer_idx][0]`` and ``entry["exp"][layer_idx][0]``.
    NOTE(review): presumably index 0 is the top-ranked choice for the
    layer -- confirm against the code that writes the log.

    Args:
        model_dir: Directory containing ``sample.txt`` and
            ``routing_log.json``.
        output_file: Path of the PNG to write.
        layer_idx: Which layer's routing decision to visualise.
        cols: Number of character cells per row before wrapping.

    Returns:
        None. If either input file is missing, an error is printed and
        nothing is written.

    Raises:
        ValueError: If ``cols`` is smaller than 1.
    """
    import colorsys  # local import: keeps this block self-contained

    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    sample_file = os.path.join(model_dir, "sample.txt")
    routing_file = os.path.join(model_dir, "routing_log.json")

    if not os.path.exists(sample_file) or not os.path.exists(routing_file):
        print(f"Error: Required files missing in {model_dir}")
        return

    with open(sample_file, "r") as f:
        text = f.read()

    with open(routing_file, "r") as f:
        routing_log = json.load(f)

    chars = list(text)
    if len(chars) > len(routing_log):
        # Skip the context character
        chars = chars[1:]

    # Align text and log: only positions present in both are drawn.
    n = min(len(chars), len(routing_log))
    chars = chars[:n]
    routing_log = routing_log[:n]

    # --- Setup Visuals ---
    char_w, char_h = 24, 36  # Larger for zoom
    margin = 50
    legend_w = 300
    legend_moe_count = 6
    legend_experts = (0, 3, 7)  # a few representative expert shades

    # Lay the text out once so the canvas height and the drawing pass
    # always agree on where rows break.
    cells, total_rows = _layout_chars(chars, cols)

    img_w = cols * char_w + margin * 2 + legend_w
    img_h = max(total_rows * char_h + margin * 3, 1000)

    img = Image.new("RGB", (img_w, img_h), (20, 20, 25))
    draw = ImageDraw.Draw(img)
    font = _load_mono_font(22)

    # --- Color Mapping ---
    def get_color(moe_id, exp_id):
        """Hue selects the MoE block, lightness selects the expert."""
        hue = (moe_id * 60) % 360
        # 30% to 79% for expert ids 0..7; clamped so larger (or negative)
        # ids cannot push the RGB channels outside 0..255.
        lightness = max(0, min(100, 30 + exp_id * 7))
        r, g, b = colorsys.hls_to_rgb(hue / 360, lightness / 100, 0.7)
        return (int(r * 255), int(g * 255), int(b * 255))

    moe_colors = [get_color(i, 4) for i in range(legend_moe_count)]

    # --- Draw Text ---
    for idx, row, col in cells:
        char = chars[idx]
        x = margin + col * char_w
        y = margin + row * char_h

        entry = routing_log[idx]
        moe_id = entry["moe"][layer_idx][0]
        exp_id = entry["exp"][layer_idx][0]

        bg_color = get_color(moe_id, exp_id)
        draw.rectangle([x, y, x + char_w - 1, y + char_h - 1], fill=bg_color)

        if not char.isspace():
            # Perceived brightness decides between white and black text.
            luma = (bg_color[0] * 0.299
                    + bg_color[1] * 0.587
                    + bg_color[2] * 0.114)
            text_color = (255, 255, 255) if luma < 128 else (0, 0, 0)
            draw.text((x + 4, y + 4), char, fill=text_color, font=font)

    # --- Draw Legend ---
    lx = margin + cols * char_w + 40
    ly = margin
    draw.text((lx, ly), "HiMoE Routing Legend", fill=(255, 255, 255), font=font)
    ly += 40

    for mi in range(legend_moe_count):
        draw.text((lx, ly), f"MoE Block {mi+1}", fill=moe_colors[mi], font=font)
        ly += 25
        # Show a few expert shades
        ex = lx + 20
        for ei in legend_experts:
            swatch = get_color(mi, ei)
            draw.rectangle([ex, ly, ex + 15, ly + 15], fill=swatch)
            draw.text((ex + 25, ly - 2), f"Exp {ei+1}", fill=(200, 200, 200), font=font)
            ly += 20
        ly += 10

    img.save(output_file)
    print(f"Visualization saved to {output_file}")
120
+
121
+ if __name__ == "__main__":
122
+ visualize_routing()