nxphi47 commited on
Commit
0aff8c0
·
verified ·
1 Parent(s): f23ec41

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +1118 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,1121 @@
1
- import altair as alt
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Tuple
8
+ import torch
9
+
10
@dataclass
class WeightTransferPlan:
    """One expert-weight transfer implied by a spill.

    When tokens of an expert are offloaded to a helper rank, that rank needs
    a copy of the expert's weights. Token indices form a half-open range
    ``[token_start, token_end)`` within the expert's token stream.
    """
    expert_id: int   # global expert index whose weights must be copied
    src_rank: int    # native owner rank of the expert's weights
    dst_rank: int    # helper rank that will process the spilled tokens
    token_start: int # first spilled token index (inclusive)
    token_end: int   # end of the spilled token slice (exclusive)
17
+
18
+
19
@dataclass
class LLEPLptPlan:
    """Result of ``compute_llep_lpt_plan``."""
    # expert_id -> list of (gpu, token_start, token_end) half-open slices
    # covering that expert's tokens; zero-load experts are omitted.
    lpt_plan: Dict[int, List[Tuple[int, int, int]]]
    # Weight copies required for every spilled (non-native) slice above.
    weight_transfers: List[WeightTransferPlan]
    # Final assigned token load per GPU (1-D tensor of length ep_size).
    gpu_loads: torch.Tensor
24
+
25
+
26
def compute_gpu_imbalance_ratio(global_expert_counts: torch.Tensor, ep_size: int, num_local_experts: int) -> float:
    """Return the GPU-level imbalance ratio: max(gpu_load) / mean(gpu_load).

    Per-GPU loads are obtained by summing each GPU's block of
    ``num_local_experts`` contiguous expert counts. A degenerate all-zero
    load vector reports a perfectly balanced ratio of 1.0.
    """
    per_gpu = global_expert_counts.reshape(ep_size, num_local_experts).sum(dim=1).float()
    average = per_gpu.mean()
    if average.item() == 0:
        return 1.0
    return (per_gpu.max() / average).item()
36
+
37
+
38
def compute_expert_imbalance_ratio(global_expert_counts: torch.Tensor, ignore_zeros: bool = False) -> float:
    """Return the expert-level imbalance ratio: max(v) / mean(v).

    Note:
        - The paper pseudocode applies max(v) / mean(v) to the raw expert
          load vector v.
        - When many experts carry zero load, mean(v) shrinks and inflates
          the ratio; ``ignore_zeros`` drops zero-load experts first.
    Degenerate inputs (empty after filtering, or zero mean) yield 1.0.
    """
    loads = global_expert_counts.float()
    if ignore_zeros:
        loads = loads[loads > 0]
    if loads.numel() == 0:
        return 1.0
    average = loads.mean()
    return 1.0 if average.item() == 0 else (loads.max() / average).item()
55
+
56
+
57
def compute_llep_lpt_plan(
    global_expert_counts: torch.Tensor,
    ep_size: int,
    num_local_experts: int,
    max_tokens_factor: float = 1.1,
    min_tokens_per_gemm: int = 512,
) -> LLEPLptPlan:
    """
    LLA/LLAS-style plan construction.

    Mapping to the pseudocode:
    - alpha == max_tokens_factor
    - m_alpha = alpha * (sum(v) / P)
    - pending load g_p starts as native loads g_n; for each expert, subtract e from native pending.
    - available on gpu o is (m_alpha - g_a[o] - g_p[o])
    - LLAS: pick least effective-load GPU among other GPUs; respect min_tokens_per_gemm skip rule,
      else force assign to least-loaded (even if it exceeds capacity).

    Args:
        global_expert_counts: 1-D tensor of per-expert token counts.
        ep_size: number of expert-parallel ranks, P.
        num_local_experts: experts natively owned by each rank.
        max_tokens_factor: alpha; capacity factor over the mean per-GPU load.
        min_tokens_per_gemm: m; a candidate chunk c with c < m while tokens
            remain (r > c) causes that helper to be skipped (LLAS rule).

    Returns:
        LLEPLptPlan mapping each non-zero-load expert to its
        (gpu, token_start, token_end) slices, the implied weight transfers,
        and the final per-GPU assigned loads.

    Bug fix vs. the previous version: Case 3 built the helper-GPU ranking
    ONCE before its spill loop and reused stale effective loads/capacities
    on later rounds (unlike Case 2, and unlike the animation version, which
    refresh each round). The shared spill helper below recomputes the
    ranking every round for both cases.
    """
    num_experts = global_expert_counts.size(0)
    total_tokens = int(global_expert_counts.sum().item())
    alpha = float(max_tokens_factor)

    # Paper: m_alpha = alpha * (1/P) * sum(v)
    m_alpha = alpha * (total_tokens / ep_size) if ep_size > 0 else float(total_tokens)
    max_tokens_per_gpu = max(int(np.ceil(m_alpha)), 1)

    # Native load per GPU: g_n (experts are owned contiguously by rank).
    native_load_per_gpu = [0] * ep_size
    for expert_id in range(num_experts):
        native_gpu = expert_id // num_local_experts
        native_load_per_gpu[native_gpu] += int(global_expert_counts[expert_id].item())

    pending_native_load = list(native_load_per_gpu)  # g_p
    assigned_load = [0] * ep_size                    # g_a

    # Sort experts by load, decreasing: hat(v)
    expert_counts_sorted = sorted(
        ((e, int(global_expert_counts[e].item())) for e in range(num_experts)),
        key=lambda x: -x[1],
    )

    lpt_plan: Dict[int, List[Tuple[int, int, int]]] = {}
    weight_transfers: List[WeightTransferPlan] = []

    def effective_load(gpu_id: int) -> int:
        # g_a + g_p
        return assigned_load[gpu_id] + pending_native_load[gpu_id]

    def capacity_remaining(gpu_id: int) -> int:
        # m_alpha - g_a - g_p
        return max_tokens_per_gpu - effective_load(gpu_id)

    def _spill_llas(expert_id: int, native_gpu: int, remaining: int,
                    token_offset: int, assignments: List[Tuple[int, int, int]]) -> None:
        """LLAS spill loop shared by Case 2 and Case 3.

        Helper ranking is recomputed every round so capacities/loads reflect
        chunks already placed this expert (the fix described above).
        """
        while remaining > 0:
            # Other GPUs ordered by effective load (g_a + g_p), refreshed.
            other_gpus_sorted = sorted(
                ((g, effective_load(g), capacity_remaining(g))
                 for g in range(ep_size) if g != native_gpu),
                key=lambda x: x[1],
            )

            if not other_gpus_sorted:
                # Degenerate fallback (single GPU): keep everything on native.
                if assignments and assignments[0][0] == native_gpu:
                    # Case 2 entered with a native chunk: extend it in place.
                    assignments[0] = (native_gpu, 0, assignments[0][2] + remaining)
                else:
                    assignments.append((native_gpu, token_offset, token_offset + remaining))
                assigned_load[native_gpu] += remaining
                return

            placed = False
            for helper_gpu, _, helper_cap in other_gpus_sorted:
                if helper_cap <= 0:
                    continue

                chunk = min(remaining, helper_cap)

                # LLAS skip rule: if chunk < m and r > chunk => skip this GPU.
                if chunk < min_tokens_per_gemm and remaining > chunk:
                    continue

                assignments.append((helper_gpu, token_offset, token_offset + chunk))
                assigned_load[helper_gpu] += chunk
                weight_transfers.append(
                    WeightTransferPlan(expert_id, native_gpu, helper_gpu, token_offset, token_offset + chunk)
                )
                token_offset += chunk
                remaining -= chunk
                placed = True
                break

            if not placed:
                # Force assign to the least effective-load GPU (may exceed cap)
                # so the loop always terminates.
                helper_gpu = other_gpus_sorted[0][0]
                assignments.append((helper_gpu, token_offset, token_offset + remaining))
                assigned_load[helper_gpu] += remaining
                weight_transfers.append(
                    WeightTransferPlan(expert_id, native_gpu, helper_gpu, token_offset, token_offset + remaining)
                )
                return

    for expert_id, expert_tokens in expert_counts_sorted:
        if expert_tokens <= 0:
            continue  # zero-load experts get no plan entry

        native_gpu = expert_id // num_local_experts

        # g_p[native] -= e : this expert is no longer "pending" on its owner.
        pending_native_load[native_gpu] -= expert_tokens

        # na = m_alpha - g_a[native] - g_p[native]
        native_available = capacity_remaining(native_gpu)

        assignments: List[Tuple[int, int, int]] = []

        if native_available >= expert_tokens:
            # Case 1: the native GPU can absorb the whole expert.
            assignments.append((native_gpu, 0, expert_tokens))
            assigned_load[native_gpu] += expert_tokens
        elif native_available > 0:
            # Case 2: native takes what fits; the remainder spills via LLAS.
            native_chunk = min(native_available, expert_tokens)
            assignments.append((native_gpu, 0, native_chunk))
            assigned_load[native_gpu] += native_chunk
            _spill_llas(expert_id, native_gpu, expert_tokens - native_chunk, native_chunk, assignments)
        else:
            # Case 3: native is already full; everything spills via LLAS.
            _spill_llas(expert_id, native_gpu, expert_tokens, 0, assignments)

        lpt_plan[expert_id] = assignments

    return LLEPLptPlan(lpt_plan=lpt_plan, weight_transfers=weight_transfers, gpu_loads=torch.tensor(assigned_load))
244
+
245
+
246
+ # ============================================================================
247
+ # ANIMATION TAB FUNCTIONS
248
+ # ============================================================================
249
+
250
# Hex color palette for experts in the animation tab.
# NOTE(review): not referenced anywhere in the visible portion of this file —
# presumably consumed by per-expert chart styling elsewhere; confirm usage.
EXPERT_COLORS = ['#3b82f6', '#8b5cf6', '#ec4899', '#14b8a6', '#f97316', '#84cc16', '#06b6d4', '#f43f5e']
251
+
252
+
253
def get_effective_load_anim(assigned: List[int], pending: List[int], gpu_id: int) -> int:
    """Return the effective load (g_a + g_p) of ``gpu_id`` for the animation."""
    already_assigned = assigned[gpu_id]
    still_pending = pending[gpu_id]
    return already_assigned + still_pending
255
+
256
+
257
def generate_animation_steps(
    expert_loads: List[int],
    alpha: float,
    num_gpus: int,
    local_experts_per_gpu: int,
    min_tokens_per_gemm: int,
) -> List[dict]:
    """
    Step-by-step LLA/LLAS animation.

    Produces a list of snapshot dicts ("steps") for the UI to scrub through.
    The logic mirrors the planner pseudocode:
    - pending starts as native loads
    - for each expert in sorted order: pending[native] -= e
    - na = m_alpha - assigned[native] - pending[native]
    - case 1/2/3 and LLAS spill with skip rule and force-assign fallback

    Each expert contributes two steps: an "evaluate" snapshot (before
    assignment) and an "assign" snapshot (after). A final "complete" step
    summarizes case counts. Snapshots are shallow ``dict(state)`` copies;
    mutated sub-structures (g_pending, assignments, ...) are re-created as
    new objects each iteration so earlier snapshots stay intact.

    NOTE(review): unlike the planner, zero-load experts are NOT skipped here;
    they fall into Case 1 (na >= 0) and inflate the Case-1 count — confirm
    this is intended for the animation.
    """
    total_experts = num_gpus * local_experts_per_gpu
    # Truncate/convert the UI-provided loads to exactly one int per expert.
    loads = [int(x) for x in expert_loads[:total_experts]]

    steps: List[dict] = []

    # hat(v): expert indices sorted by load, decreasing.
    sorted_indices = sorted(range(total_experts), key=lambda i: loads[i], reverse=True)
    sorted_loads = [loads[i] for i in sorted_indices]

    total_load = int(sum(sorted_loads))
    # m_alpha = alpha * (sum(v) / P); kept as float here (planner uses ceil).
    m_alpha = float(alpha) * (total_load / num_gpus) if num_gpus > 0 else float(total_load)
    max_per_gpu = float(m_alpha)

    # g_n: native load per GPU (experts owned contiguously by rank).
    native_loads = [0] * num_gpus
    for i in range(total_experts):
        native_loads[i // local_experts_per_gpu] += loads[i]

    # Initial "init" snapshot; this dict is the rolling animation state.
    state = {
        "sorted_indices": sorted_indices,
        "sorted_loads": sorted_loads,
        "total_load": total_load,
        "max_per_gpu": max_per_gpu,
        "min_tokens_per_gemm": int(min_tokens_per_gemm),
        "g_pending": list(native_loads),
        "g_assigned": [0] * num_gpus,
        "assignments": {},
        "current_expert_idx": -1,
        "phase": "init",
        "message": f"Sorted experts by load. Total={total_load}, m_alpha={max_per_gpu:.2f} (α={alpha:.2f}, m={min_tokens_per_gemm})",
        "case_type": None,
        "highlight_gpu": None,
        "spill_flows": [],
        "spill_targets": [],
    }
    steps.append(dict(state))

    def cap_remaining(g_assigned: List[int], g_pending: List[int], gpu_id: int) -> float:
        # m_alpha - (g_a + g_p); may be fractional because m_alpha is a float.
        return max_per_gpu - float(get_effective_load_anim(g_assigned, g_pending, gpu_id))

    for i in range(total_experts):
        expert_load = int(state["sorted_loads"][i])
        original_idx = int(state["sorted_indices"][i])
        native_gpu = original_idx // local_experts_per_gpu

        # g_p[native] -= e (on a fresh list so prior snapshots are preserved).
        new_pending = list(state["g_pending"])
        new_pending[native_gpu] -= expert_load

        # na = m_alpha - g_a[native] - g_p[native]
        na = cap_remaining(state["g_assigned"], new_pending, native_gpu)

        # "evaluate" snapshot: expert highlighted, nothing assigned yet.
        state = dict(state)
        state["g_pending"] = new_pending
        state["current_expert_idx"] = i
        state["highlight_gpu"] = native_gpu
        state["phase"] = "evaluate"
        state["message"] = f"Expert E{original_idx} (load={expert_load}) native=GPU{native_gpu}. na={max(0.0, na):.2f}"
        state["spill_flows"] = []
        state["spill_targets"] = []
        state["case_type"] = None
        steps.append(dict(state))

        new_assigned = list(state["g_assigned"])
        assignments = []
        spill_flows = []
        spill_targets = []

        # Case 1: native GPU can absorb the whole expert.
        if na >= expert_load:
            assignments.append({"gpu": native_gpu, "start": 0, "end": expert_load})
            new_assigned[native_gpu] += expert_load
            state["case_type"] = 1
            state["message"] = f"Case 1: native GPU{native_gpu} takes all {expert_load}"

        # Case 2: native takes a partial chunk, remainder spills via LLAS.
        elif na > 0:
            # Fractional capacity is floored to a whole-token native chunk.
            native_chunk = int(np.floor(na))
            native_chunk = max(0, min(native_chunk, expert_load))

            assignments.append({"gpu": native_gpu, "start": 0, "end": native_chunk})
            new_assigned[native_gpu] += native_chunk

            remaining = expert_load - native_chunk
            token_offset = native_chunk

            while remaining > 0:
                # Helper ranking refreshed each round with current loads.
                helper_gpus = []
                for g in range(num_gpus):
                    if g == native_gpu:
                        continue
                    eff_load = float(get_effective_load_anim(new_assigned, new_pending, g))
                    avail = cap_remaining(new_assigned, new_pending, g)
                    helper_gpus.append({"gpu": g, "eff_load": eff_load, "avail": avail})
                helper_gpus.sort(key=lambda x: x["eff_load"])

                if not helper_gpus:
                    # Degenerate fallback (single GPU): extend the native chunk.
                    assignments[-1]["end"] += remaining
                    new_assigned[native_gpu] += remaining
                    remaining = 0
                    break

                assigned_flag = False
                for helper in helper_gpus:
                    if helper["avail"] <= 0:
                        continue

                    c = int(min(remaining, np.floor(helper["avail"])))
                    if c <= 0:
                        continue

                    # LLAS skip rule: chunk < m while tokens remain => skip GPU.
                    if c < min_tokens_per_gemm and remaining > c:
                        continue

                    assignments.append({"gpu": helper["gpu"], "start": token_offset, "end": token_offset + c})
                    spill_flows.append({"from": native_gpu, "to": helper["gpu"], "amount": c})
                    spill_targets.append(helper["gpu"])
                    new_assigned[helper["gpu"]] += c
                    token_offset += c
                    remaining -= c
                    assigned_flag = True
                    break

                if not assigned_flag:
                    # Force assign to least effective-load helper (may exceed
                    # capacity) so the loop terminates.
                    helper = helper_gpus[0]
                    c = remaining
                    assignments.append({"gpu": helper["gpu"], "start": token_offset, "end": token_offset + c})
                    spill_flows.append({"from": native_gpu, "to": helper["gpu"], "amount": c})
                    spill_targets.append(helper["gpu"])
                    new_assigned[helper["gpu"]] += c
                    token_offset += c
                    remaining = 0

            state["case_type"] = 2
            spill_target_str = ", ".join([f"GPU{g}" for g in sorted(set(spill_targets))]) if spill_targets else "none"
            state["message"] = f"Case 2: native GPU{native_gpu} takes {native_chunk}, spill {expert_load - native_chunk} -> {spill_target_str}"

        # Case 3: native full (na <= 0); everything spills via LLAS.
        else:
            remaining = expert_load
            token_offset = 0

            while remaining > 0:
                # Helper ranking refreshed each round with current loads.
                helper_gpus = []
                for g in range(num_gpus):
                    if g == native_gpu:
                        continue
                    eff_load = float(get_effective_load_anim(new_assigned, new_pending, g))
                    avail = cap_remaining(new_assigned, new_pending, g)
                    helper_gpus.append({"gpu": g, "eff_load": eff_load, "avail": avail})
                helper_gpus.sort(key=lambda x: x["eff_load"])

                if not helper_gpus:
                    # Degenerate fallback (single GPU): keep all on native.
                    assignments.append({"gpu": native_gpu, "start": 0, "end": expert_load})
                    new_assigned[native_gpu] += expert_load
                    remaining = 0
                    break

                assigned_flag = False
                for helper in helper_gpus:
                    if helper["avail"] <= 0:
                        continue

                    c = int(min(remaining, np.floor(helper["avail"])))
                    if c <= 0:
                        continue

                    # LLAS skip rule: chunk < m while tokens remain => skip GPU.
                    if c < min_tokens_per_gemm and remaining > c:
                        continue

                    assignments.append({"gpu": helper["gpu"], "start": token_offset, "end": token_offset + c})
                    spill_flows.append({"from": native_gpu, "to": helper["gpu"], "amount": c})
                    spill_targets.append(helper["gpu"])
                    new_assigned[helper["gpu"]] += c
                    token_offset += c
                    remaining -= c
                    assigned_flag = True
                    break

                if not assigned_flag:
                    # Force assign (may exceed capacity) to guarantee progress.
                    helper = helper_gpus[0]
                    c = remaining
                    assignments.append({"gpu": helper["gpu"], "start": token_offset, "end": token_offset + c})
                    spill_flows.append({"from": native_gpu, "to": helper["gpu"], "amount": c})
                    spill_targets.append(helper["gpu"])
                    new_assigned[helper["gpu"]] += c
                    token_offset += c
                    remaining = 0

            state["case_type"] = 3
            spill_target_str = ", ".join([f"GPU{g}" for g in sorted(set(spill_targets))]) if spill_targets else "none"
            state["message"] = f"Case 3: native GPU{native_gpu} full; spill all {expert_load} -> {spill_target_str}"

        # "assign" snapshot: commit this expert's result into the state.
        state["g_assigned"] = new_assigned
        state["assignments"] = dict(state["assignments"])
        state["assignments"][i] = assignments
        state["spill_flows"] = spill_flows
        state["spill_targets"] = sorted(list(set(spill_targets)))
        state["phase"] = "assign"
        steps.append(dict(state))

    # Tally cases over the "assign" snapshots (evaluate steps have case None).
    case_counts = {1: 0, 2: 0, 3: 0}
    for s in steps:
        if s.get("case_type") in case_counts:
            case_counts[int(s["case_type"])] += 1

    # Final "complete" snapshot with the summary message.
    state["phase"] = "complete"
    state["message"] = f"Complete. Case1={case_counts[1]}, Case2={case_counts[2]}, Case3={case_counts[3]}"
    state["current_expert_idx"] = -1
    state["highlight_gpu"] = None
    state["spill_flows"] = []
    state["spill_targets"] = []
    steps.append(dict(state))

    return steps
488
+
489
+
490
def create_gpu_topology_chart(state: dict, num_gpus: int) -> go.Figure:
    """
    GPU topology with spill arrows and overflow indication.

    Draws each GPU as a box with a fill bar (assigned/m_alpha), labels the
    native/helper/overflow roles for the current animation step, and renders
    one arrow per spill flow in ``state["spill_flows"]``.

    NOTE(review): assumes num_gpus >= 1 — the axis-range min()/max() below
    raise on an empty position list; confirm callers never pass 0.
    """
    fig = go.Figure()

    # Grid layout: 2 columns for small clusters, 4 columns otherwise.
    if num_gpus <= 4:
        gpu_positions = [(i % 2, 1 - i // 2) for i in range(num_gpus)]
    else:
        cols = 4
        gpu_positions = [(i % cols, -(i // cols)) for i in range(num_gpus)]

    max_load = float(state["max_per_gpu"])
    assigned = state["g_assigned"]

    for gpu_id in range(num_gpus):
        x, y = gpu_positions[gpu_id]
        a = float(assigned[gpu_id])

        # Fill fraction of capacity; clamped for drawing, raw for coloring.
        fill_pct = (a / max_load) if max_load > 0 else 0.0
        fill_pct_clamped = min(fill_pct, 1.0)

        is_highlighted = gpu_id == state.get("highlight_gpu")
        is_spill_target = gpu_id in state.get("spill_targets", [])

        # Tokens beyond capacity (only meaningful when capacity > 0).
        overflow = (a - max_load) if max_load > 0 and a > max_load else 0.0

        # Border color priority: native (yellow) > helper (orange) >
        # overflowing (red) > idle (gray).
        if is_highlighted:
            box_color = "#facc15"
        elif is_spill_target:
            box_color = "#f97316"
        elif overflow > 0:
            box_color = "#ef4444"
        else:
            box_color = "#4b5563"

        # GPU box outline.
        fig.add_shape(
            type="rect", x0=x - 0.3, y0=y - 0.15, x1=x + 0.3, y1=y + 0.15,
            fillcolor="#1f2937", line=dict(color=box_color, width=3)
        )

        # Inner fill bar; turns red once at/over capacity.
        bar_width = 0.5 * fill_pct_clamped
        bar_color = "#ef4444" if fill_pct >= 1 else "#3b82f6"
        fig.add_shape(
            type="rect", x0=x - 0.25, y0=y - 0.08, x1=x - 0.25 + bar_width, y1=y - 0.02,
            fillcolor=bar_color, line=dict(width=0)
        )

        # GPU name label.
        fig.add_annotation(
            x=x, y=y + 0.05, text=f"<b>GPU {gpu_id}</b>",
            showarrow=False, font=dict(color="white", size=12)
        )

        # "assigned / capacity" label, with "(+overflow)" suffix when over.
        text = f"{a:.0f} / {max_load:.0f}"
        if overflow > 0:
            text = f"{a:.0f} / {max_load:.0f} (+{overflow:.0f})"
        fig.add_annotation(
            x=x, y=y - 0.05, text=text,
            showarrow=False, font=dict(color="white", size=10)
        )

        # Role tag beneath the box (same priority order as the border).
        if is_highlighted:
            fig.add_annotation(x=x, y=y - 0.22, text="NATIVE", showarrow=False, font=dict(color="#facc15", size=9))
        elif is_spill_target:
            fig.add_annotation(x=x, y=y - 0.22, text="HELPER", showarrow=False, font=dict(color="#f97316", size=9))
        elif overflow > 0:
            fig.add_annotation(x=x, y=y - 0.22, text="OVER", showarrow=False, font=dict(color="#ef4444", size=9))

    # One arrow + midpoint amount label per spill flow of the current step.
    for flow in state.get("spill_flows", []):
        from_pos = gpu_positions[flow["from"]]
        to_pos = gpu_positions[flow["to"]]

        fig.add_annotation(
            x=to_pos[0], y=to_pos[1],
            ax=from_pos[0], ay=from_pos[1],
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True,
            arrowhead=2, arrowsize=1.5, arrowwidth=3,
            arrowcolor="#f97316"
        )

        mid_x = (from_pos[0] + to_pos[0]) / 2
        mid_y = (from_pos[1] + to_pos[1]) / 2
        fig.add_annotation(
            x=mid_x, y=mid_y + 0.1,
            text=f"<b>{flow['amount']}</b>",
            showarrow=False,
            font=dict(color="#f97316", size=12),
            bgcolor="#1f2937"
        )

    # Axis ranges padded around the occupied grid cells.
    y_min = min(p[1] for p in gpu_positions) - 0.4
    y_max = max(p[1] for p in gpu_positions) + 0.4
    x_min = min(p[0] for p in gpu_positions) - 0.5
    x_max = max(p[0] for p in gpu_positions) + 0.5

    fig.update_layout(
        xaxis=dict(range=[x_min, x_max], showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(range=[y_min, y_max], showgrid=False, zeroline=False, showticklabels=False, scaleanchor="x"),
        plot_bgcolor="#1f2937",
        paper_bgcolor="#1f2937",
        margin=dict(l=10, r=10, t=10, b=10),
        height=280
    )
    return fig
595
+
596
+
597
def create_load_bars_chart(state: dict, num_gpus: int) -> go.Figure:
    """
    Horizontal per-GPU load bars with a dashed capacity (m_alpha) marker.

    Bar colors: yellow = highlighted (native) GPU of the current step,
    orange = spill target, red = over capacity, blue = normal.
    """
    capacity = float(state["max_per_gpu"])
    labels = [f"GPU {i}" for i in range(num_gpus)]
    loads = [float(x) for x in state["g_assigned"]]

    highlight = state.get("highlight_gpu")
    targets = state.get("spill_targets", [])

    def _bar_color(idx: int) -> str:
        # Priority mirrors the topology chart: native > helper > over > normal.
        if idx == highlight:
            return "#facc15"
        if idx in targets:
            return "#f97316"
        if loads[idx] > capacity:
            return "#ef4444"
        return "#3b82f6"

    bar_colors = [_bar_color(i) for i in range(num_gpus)]

    # Leave 10% headroom past the larger of capacity and the max load.
    axis_limit = max(capacity * 1.1, (max(loads) * 1.1 if loads else 1.0), 1.0)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=labels, x=loads, orientation="h",
        marker_color=bar_colors,
        text=[f"{v:.0f}/{capacity:.0f}" for v in loads],
        textposition="inside",
        textfont=dict(color="white")
    ))
    # Dashed red line marks the per-GPU capacity m_alpha.
    fig.add_vline(x=capacity, line_dash="dash", line_color="#ef4444", line_width=2)

    fig.update_layout(
        xaxis=dict(title="Tokens", range=[0, axis_limit]),
        yaxis=dict(autorange="reversed"),
        plot_bgcolor="#1f2937",
        paper_bgcolor="#1f2937",
        font=dict(color="white"),
        margin=dict(l=10, r=10, t=10, b=30),
        height=max(160, num_gpus * 40),
        showlegend=False
    )
    return fig
639
+
640
+
641
+ # ============================================================================
642
+ # STATISTICS TAB FUNCTIONS
643
+ # ============================================================================
644
+
645
def generate_loads(n_experts: int, n_tokens: int, k: int, skew: float) -> np.ndarray:
    """Sample synthetic per-expert token counts.

    ``skew`` controls hotspot severity: higher skew yields a smaller
    Dirichlet concentration and hence spikier expert probabilities. The
    returned counts always sum to ``n_tokens * k`` (each token routes to
    k experts).
    """
    concentration = 10.0 * ((1.0 - skew) ** 2) + 0.05
    expert_probs = np.random.dirichlet(np.full(n_experts, concentration))
    return np.random.multinomial(n_tokens * k, expert_probs)
649
+
650
+
651
def plot_gpu_load(data: List[dict], title: str, ep_world_size: int, gpu_color_map: dict) -> go.Figure:
    """Stacked per-GPU bar chart of processed work, colored by native owner.

    Each segment shows tokens a GPU processes on behalf of some owner GPU;
    spilled work is hatched ("/"). Within each GPU, Native segments stack
    below Spill segments.

    Args:
        data: rows with keys "GPU", "Owner", "Type" ("Native"/"Spill"), "Tokens".
        title: chart title.
        ep_world_size: number of GPUs. NOTE(review): unused in this body —
            kept for API symmetry or dead parameter; confirm with callers.
        gpu_color_map: owner GPU id -> color string.
    """
    fig = go.Figure()
    df = pd.DataFrame(data)
    # Empty traffic: return the blank figure rather than grouping nothing.
    if df.empty:
        return fig

    # Aggregate tokens per (processing GPU, owner, native/spill), then order
    # Native before Spill within each GPU so stacking is consistent.
    df_grouped = df.groupby(["GPU", "Owner", "Type"])["Tokens"].sum().reset_index()
    type_order = {"Native": 0, "Spill": 1}
    df_grouped["TypeOrder"] = df_grouped["Type"].map(type_order)
    df_grouped = df_grouped.sort_values(by=["GPU", "TypeOrder"]).reset_index(drop=True)

    # One single-bar trace per aggregated segment (stacked via barmode).
    for _, row in df_grouped.iterrows():
        gpu_id = int(row["GPU"])
        owner_id = int(row["Owner"])
        val = float(row["Tokens"])
        is_spill = row["Type"] == "Spill"

        fig.add_trace(go.Bar(
            name=f"Exp from GPU {owner_id}",
            x=[f"GPU {gpu_id}"],
            y=[val],
            marker_color=gpu_color_map[owner_id],
            marker_pattern_shape="/" if is_spill else "",
            marker_line_color="black",
            marker_line_width=0.5,
            showlegend=False,
            hoverinfo="text",
            hovertext=f"Processing work for native owner GPU {owner_id}<br>Tokens: {val:.0f}<br>{'SPILL' if is_spill else 'NATIVE'}"
        ))

    fig.update_layout(
        barmode="stack",
        title=title,
        height=300,
        margin=dict(l=20, r=20, t=40, b=20)
    )
    return fig
688
+
689
+
690
def plot_expert_distribution(data: List[dict], title: str, gpu_color_map: dict) -> go.Figure:
    """Stacked per-expert bar chart: which GPU processed each expert's tokens.

    Native work is drawn solid; spilled work is hatched ("/"). Within each
    expert, Native segments stack below Spill segments.
    """
    frame = pd.DataFrame(data)
    if frame.empty:
        return go.Figure()

    fig = go.Figure()

    # Aggregate tokens per (expert, processing GPU, native/spill), ordering
    # Native before Spill within each expert for consistent stacking.
    grouped = frame.groupby(["Expert", "GPU", "Type"])["Tokens"].sum().reset_index()
    grouped["TypeOrder"] = grouped["Type"].map({"Native": 0, "Spill": 1})
    grouped = grouped.sort_values(by=["Expert", "TypeOrder"]).reset_index(drop=True)

    # One single-bar trace per aggregated segment (stacked via barmode).
    for _, record in grouped.iterrows():
        expert_id = int(record["Expert"])
        gpu_id = int(record["GPU"])
        tokens = float(record["Tokens"])
        spilled = record["Type"] == "Spill"

        fig.add_trace(go.Bar(
            name=f"GPU {gpu_id}",
            x=[f"E{expert_id}"],
            y=[tokens],
            marker_color=gpu_color_map[gpu_id],
            marker_pattern_shape="/" if spilled else "",
            marker_line_color="black",
            marker_line_width=0.5,
            showlegend=False,
            hoverinfo="text",
            hovertext=f"Processed by GPU {gpu_id}<br>Tokens: {tokens:.0f}<br>{'SPILL' if spilled else 'NATIVE'}"
        ))

    fig.update_layout(
        barmode="stack",
        title=title,
        height=300,
        margin=dict(l=20, r=20, t=40, b=20)
    )
    fig.update_xaxes(type="category")
    return fig
728
+
729
+
730
+ # ============================================================================
731
+ # MAIN STREAMLIT APP
732
+ # ============================================================================
733
+
734
+ st.set_page_config(layout="wide", page_title="LLEP Simulator & Visualizer")
735
+
736
+ st.title("Least-Loaded Expert Parallelism (LLEP)")
737
+ st.markdown("Compare **Standard EP** against the **LLEP (LLA/LLAS)** plan and visualize step-by-step spilling.")
738
+ st.markdown("""
739
+ **Authors:** [Xuan-Phi Nguyen](https://scholar.google.com/), [Shrey Pandit](https://scholar.google.com/), [Austin Xu](https://scholar.google.com/), [Caiming Xiong](https://scholar.google.com/), [Shafiq Joty](https://scholar.google.com/)
740
+ **Affiliation:** Salesforce AI Research
741
+ **Contact:** xnguyen@salesforce.com
742
+ """)
743
+
744
+ tab_stats, tab_anim = st.tabs(["Statistics & Comparison", "Step-by-Step Animation"])
745
+
746
+ # ============================================================================
747
+ # TAB 1: STATISTICS & COMPARISON
748
+ # ============================================================================
749
+ with tab_stats:
750
+ cfg_col, out_col = st.columns([0.36, 0.64], gap="large")
751
+
752
+ with cfg_col:
753
+ st.subheader("Statistics Config")
754
+
755
+ num_experts = st.selectbox("Num Experts", [32, 64, 128, 256], index=0, key="stats_experts")
756
+ ep_world_size = st.selectbox("World Size (GPUs)", [4, 8, 16, 32], index=1, key="stats_gpus")
757
+ experts_per_gpu = num_experts // ep_world_size
758
+
759
+ st.markdown("#### Traffic Config")
760
+ total_tokens = st.selectbox("Batch Tokens", [4096, 8192, 16384, 32768, 65536, 131072], index=3, key="stats_tokens")
761
+ top_k = st.slider("Top K", 1, num_experts // 2, min(4, num_experts // 2), key="stats_topk")
762
+ imbalance = st.slider("Skew (Imbalance)", 0.0, 0.99, 0.6, key="stats_skew", help="Higher = more hotspots")
763
+
764
+ st.markdown("#### LLEP / LLA Config")
765
+ alpha_capacity = st.slider(
766
+ "α (capacity factor)",
767
+ 1.0, 2.0, 1.1, 0.05,
768
+ key="stats_alpha",
769
+ help="m_alpha = α * (sum(v)/P). Lower α -> more spilling."
770
+ )
771
+ min_tokens_per_gemm = st.slider(
772
+ "Min tokens per GEMM (m)",
773
+ 1, 4096, 512, 32,
774
+ key="stats_min_gemm",
775
+ help="If a candidate chunk c < m and remaining r > c, we skip that GPU (LLAS rule)."
776
+ )
777
+
778
+ st.markdown("#### Activation Threshold (λ)")
779
+ imbalance_metric = st.radio(
780
+ "Imbalance metric used for λ check",
781
+ ["Expert-level (paper)", "GPU-level (practical)"],
782
+ index=0,
783
+ key="stats_metric",
784
+ help="Paper pseudocode uses max(v)/mean(v) on expert loads v. Your earlier code used GPU aggregation."
785
+ )
786
+ ignore_zeros = st.checkbox(
787
+ "Ignore zero-load experts for expert-level mean",
788
+ value=True,
789
+ key="stats_ignore_zeros",
790
+ help="Prevents max(v)/mean(v) from exploding when many experts are unused."
791
+ )
792
+ imbalance_threshold = st.slider(
793
+ "λ (threshold)",
794
+ 1.0, 10.0, 1.3, 0.1,
795
+ key="stats_lambda",
796
+ help="If ratio < λ, we use standard EP. Else, compute LLA/LLAS plan."
797
+ )
798
+
799
+ regen = st.button("Regenerate Traffic", key="stats_regen")
800
+
801
+ # Generate Synthetic Data (scoped to this tab, no sidebar bleed)
802
+ config_key = (num_experts, total_tokens, top_k, imbalance, "stats")
803
+ if ("stats_config_key" not in st.session_state) or (st.session_state["stats_config_key"] != config_key) or regen:
804
+ st.session_state["stats_config_key"] = config_key
805
+ st.session_state["stats_expert_loads"] = generate_loads(num_experts, total_tokens, top_k, imbalance)
806
+
807
+ expert_loads = st.session_state["stats_expert_loads"]
808
+ expert_loads_tensor = torch.tensor(expert_loads, dtype=torch.int64)
809
+
810
+ # Standard EP
811
+ ep_gpu_loads = [0] * ep_world_size
812
+ ep_expert_assignment = []
813
+ for e_id, count in enumerate(expert_loads):
814
+ if int(count) == 0:
815
+ continue
816
+ owner_gpu = e_id // experts_per_gpu
817
+ ep_gpu_loads[owner_gpu] += int(count)
818
+ ep_expert_assignment.append({
819
+ "Expert": int(e_id),
820
+ "GPU": int(owner_gpu),
821
+ "Tokens": int(count),
822
+ "Type": "Native",
823
+ "Owner": int(owner_gpu),
824
+ })
825
+
826
+ # Ratios
827
+ ratio_expert = compute_expert_imbalance_ratio(expert_loads_tensor, ignore_zeros=bool(ignore_zeros))
828
+ ratio_gpu = compute_gpu_imbalance_ratio(expert_loads_tensor, ep_world_size, experts_per_gpu)
829
+
830
+ if imbalance_metric == "Expert-level (paper)":
831
+ imbalance_ratio = ratio_expert
832
+ else:
833
+ imbalance_ratio = ratio_gpu
834
+
835
+ use_lpt = imbalance_ratio >= float(imbalance_threshold)
836
+
837
+ # LLEP (LLA/LLAS) plan
838
+ if use_lpt:
839
+ llep_result = compute_llep_lpt_plan(
840
+ expert_loads_tensor,
841
+ ep_world_size,
842
+ experts_per_gpu,
843
+ max_tokens_factor=float(alpha_capacity),
844
+ min_tokens_per_gemm=int(min_tokens_per_gemm),
845
+ )
846
+ llep_expert_assignment = []
847
+ for e_id, assigns in llep_result.lpt_plan.items():
848
+ native_owner = int(e_id) // experts_per_gpu
849
+ for (assigned_gpu, start_t, end_t) in assigns:
850
+ count = int(end_t - start_t)
851
+ if count <= 0:
852
+ continue
853
+ is_spill = (int(assigned_gpu) != int(native_owner))
854
+ llep_expert_assignment.append({
855
+ "Expert": int(e_id),
856
+ "GPU": int(assigned_gpu),
857
+ "Tokens": count,
858
+ "Type": "Spill" if is_spill else "Native",
859
+ "Owner": int(native_owner),
860
+ })
861
+ else:
862
+ llep_result = LLEPLptPlan(lpt_plan={}, weight_transfers=[], gpu_loads=torch.tensor(ep_gpu_loads))
863
+ llep_expert_assignment = ep_expert_assignment.copy()
864
+
865
+ colors = px.colors.qualitative.Plotly
866
+ gpu_color_map = {i: colors[i % len(colors)] for i in range(ep_world_size)}
867
+
868
+ with out_col:
869
+ st.subheader("Status")
870
+ st.write(
871
+ pd.DataFrame([{
872
+ "expert_ratio max/mean": f"{ratio_expert:.2f}x",
873
+ "gpu_ratio max/mean": f"{ratio_gpu:.2f}x",
874
+ "metric_used": imbalance_metric,
875
+ "λ": float(imbalance_threshold),
876
+ "activated": bool(use_lpt),
877
+ "α": float(alpha_capacity),
878
+ "m": int(min_tokens_per_gemm),
879
+ }])
880
+ )
881
+
882
+ if not use_lpt:
883
+ st.warning(
884
+ f"LLA skipped: ratio {imbalance_ratio:.2f}x < λ {imbalance_threshold:.2f}. Using standard EP."
885
+ )
886
+
887
+ st.markdown("---")
888
+
889
+ # GPU Load Comparison
890
+ st.subheader("1. GPU Load Comparison")
891
+ c_load1, c_load2 = st.columns(2)
892
+
893
+ with c_load1:
894
+ st.markdown("##### Standard EP")
895
+ st.caption("Each GPU processes its native experts only.")
896
+ st.plotly_chart(plot_gpu_load(ep_expert_assignment, "", ep_world_size, gpu_color_map), use_container_width=True, key="ep_gpu_load")
897
+
898
+ with c_load2:
899
+ st.markdown("##### LLEP / LLA (Solid=Native, Hatched=Spill)" if use_lpt else "##### LLEP (standard EP fallback)")
900
+ st.caption("Overloaded GPUs spill to least-loaded helpers, following LLAS rules." if use_lpt else "Imbalance below λ, so no spilling.")
901
+ st.plotly_chart(plot_gpu_load(llep_expert_assignment, "", ep_world_size, gpu_color_map), use_container_width=True, key="llep_gpu_load")
902
+
903
+ # Expert Assignment
904
+ st.subheader("2. Experts' GPU Assignment")
905
+ c_exp1, c_exp2 = st.columns(2)
906
+
907
+ with c_exp1:
908
+ st.markdown("##### Standard EP (Fixed)")
909
+ st.caption("Each expert is assigned to exactly one GPU.")
910
+ st.plotly_chart(plot_expert_distribution(ep_expert_assignment, "", gpu_color_map), use_container_width=True, key="ep_expert_dist")
911
+
912
+ with c_exp2:
913
+ st.markdown("##### LLEP (Split across GPUs)" if use_lpt else "##### LLEP (standard EP fallback)")
914
+ st.caption("Experts may be split across GPUs when spilling is needed." if use_lpt else "Same as standard EP.")
915
+ st.plotly_chart(plot_expert_distribution(llep_expert_assignment, "", gpu_color_map), use_container_width=True, key="llep_expert_dist")
916
+
917
+ legend_html = " &nbsp; ".join(
918
+ f"<span style='display:inline-block;width:14px;height:14px;background-color:{gpu_color_map[i]};border:1px solid black;vertical-align:middle;'></span> GPU {i}"
919
+ for i in range(ep_world_size)
920
+ )
921
+ st.markdown(f"**Legend:** {legend_html}", unsafe_allow_html=True)
922
+
923
+ with st.expander("Show Plan Details"):
924
+ st.write("Weight Transfers Needed:", len(llep_result.weight_transfers))
925
+ if len(llep_result.weight_transfers) > 0:
926
+ st.dataframe([vars(x) for x in llep_result.weight_transfers])
927
+
928
+
929
# ============================================================================
# TAB 2: STEP-BY-STEP ANIMATION
# ============================================================================
# Interactive walkthrough of the LLA + LLAS assignment algorithm on a tiny
# fixed topology (4 GPUs x 2 local experts). Widget state lives under
# "anim_*" session-state keys so preset application and step navigation
# survive reruns.
# NOTE(review): reconstructed from a flattened diff view — placement of the
# st.info/reset-button group is assumed inside the configuration expander;
# confirm against the original file.
with tab_anim:
    st.subheader("Step-by-Step Algorithm Animation")
    st.caption("This animation follows LLA + LLAS with α capacity and min-tokens-per-GEMM (m) skip/force-assign behavior.")

    # Small fixed topology so every expert/GPU fits on screen.
    anim_num_gpus = 4
    anim_local_experts = 2
    anim_total_experts = anim_num_gpus * anim_local_experts

    # Initialize widget-backed state once (before any widget is created),
    # so presets can mutate these keys via the on_click callback.
    if "anim_alpha" not in st.session_state:
        st.session_state["anim_alpha"] = 1.0
    if "anim_min_gemm" not in st.session_state:
        st.session_state["anim_min_gemm"] = 1
    if "anim_step" not in st.session_state:
        st.session_state["anim_step"] = 0
    for idx in range(anim_total_experts):
        key = f"anim_load_{idx}"
        if key not in st.session_state:
            default = [150, 50, 20, 20, 100, 40, 40, 20][idx]
            st.session_state[key] = int(default)

    # Canned scenarios: α plus the 8 expert loads (E0..E7).
    PRESETS = {
        "No Spill (high α)": {"alpha": 1.5, "loads": [50, 50, 50, 50, 50, 50, 50, 50]},
        "Some Spills": {"alpha": 1.0, "loads": [150, 50, 20, 20, 100, 40, 40, 20]},
        "Many Spills (low α)": {"alpha": 0.8, "loads": [150, 50, 20, 20, 100, 40, 40, 20]},
        "Extreme Imbalance": {"alpha": 0.6, "loads": [200, 10, 10, 10, 180, 10, 10, 10]},
    }

    # Define callback to apply preset BEFORE widgets are instantiated.
    # Runs at the start of the rerun triggered by the "Apply Preset" click,
    # which is the only safe moment to write widget-backed keys.
    def apply_preset_callback():
        preset_name = st.session_state.get("anim_preset", "Some Spills")
        if preset_name in PRESETS:
            st.session_state["anim_alpha"] = float(PRESETS[preset_name]["alpha"])
            # m is kept as-is (presets only set α and loads).
            st.session_state["anim_min_gemm"] = st.session_state.get("anim_min_gemm", 1)
            for idx, v in enumerate(PRESETS[preset_name]["loads"]):
                st.session_state[f"anim_load_{idx}"] = int(v)
            st.session_state["anim_step"] = 0  # restart the animation

    with st.expander("Animation Configuration", expanded=True):
        left, right = st.columns([1, 1], gap="large")

        with left:
            preset = st.selectbox("Preset", list(PRESETS.keys()), key="anim_preset")
            st.button("Apply Preset", key="anim_apply_preset", on_click=apply_preset_callback)

        with right:
            # These sliders have no `value=` argument on purpose: they read
            # their current value from the pre-seeded session-state keys.
            st.slider(
                "α (capacity factor)",
                0.5, 1.5,
                step=0.05,
                key="anim_alpha"
            )
            st.slider(
                "m (min tokens per GEMM)",
                1, 512,
                step=1,
                key="anim_min_gemm",
                help="LLAS rule: if candidate chunk c < m and remaining r > c, skip that GPU; else may force-assign."
            )

        st.markdown("**Expert Loads (native placement shown as E{i} -> GPU{i//2})**")
        load_cols = st.columns(anim_num_gpus)
        for gpu_idx in range(anim_num_gpus):
            with load_cols[gpu_idx]:
                st.caption(f"GPU {gpu_idx}")
                for local_idx in range(anim_local_experts):
                    idx = gpu_idx * anim_local_experts + local_idx
                    # NOTE(review): passing both `value=` and a key that is
                    # already seeded in session state makes Streamlit emit a
                    # "widget created with a default value but also set via
                    # Session State" warning; dropping `value=` (like the
                    # sliders above) would silence it — confirm intent.
                    st.number_input(
                        f"E{idx}",
                        min_value=0,
                        max_value=500,
                        value=int(st.session_state[f"anim_load_{idx}"]),
                        step=1,
                        key=f"anim_load_{idx}"
                    )

        # Snapshot current knob values for the summary line below.
        loads_now = [int(st.session_state[f"anim_load_{i}"]) for i in range(anim_total_experts)]
        alpha_now = float(st.session_state["anim_alpha"])
        m_now = int(st.session_state["anim_min_gemm"])

        total_now = sum(loads_now)
        # Per-GPU capacity m_alpha = α * (total tokens / num GPUs); the
        # guard is defensive (anim_num_gpus is a constant 4 here).
        m_alpha_now = alpha_now * (total_now / anim_num_gpus) if anim_num_gpus > 0 else float(total_now)

        st.info(f"Current: α={alpha_now:.2f}, m={m_now}, Total={total_now}, m_alpha={m_alpha_now:.2f}")

        if st.button("Reset Animation Step", key="anim_reset_step"):
            st.session_state["anim_step"] = 0
            st.rerun()

    # Build steps from current widget values (so changes are visible immediately)
    anim_steps = generate_animation_steps(
        expert_loads=[int(st.session_state[f"anim_load_{i}"]) for i in range(anim_total_experts)],
        alpha=float(st.session_state["anim_alpha"]),
        num_gpus=anim_num_gpus,
        local_experts_per_gpu=anim_local_experts,
        min_tokens_per_gemm=int(st.session_state["anim_min_gemm"]),
    )

    # Clamp the persisted step index: the step list length changes whenever
    # the loads/α/m change, so an old index may be out of range.
    current_step = int(st.session_state["anim_step"])
    current_step = max(0, min(current_step, len(anim_steps) - 1))
    st.session_state["anim_step"] = current_step
    state = anim_steps[current_step]

    # Playback controls (each navigates by mutating anim_step and rerunning).
    ctrl_col1, ctrl_col2, ctrl_col3, ctrl_col4, ctrl_col5 = st.columns([1, 1, 1, 1, 4])
    with ctrl_col1:
        if st.button("Reset", key="anim_reset"):
            st.session_state["anim_step"] = 0
            st.rerun()
    with ctrl_col2:
        # Button is rendered first; the bounds check only gates the action.
        if st.button("Prev", key="anim_prev") and current_step > 0:
            st.session_state["anim_step"] -= 1
            st.rerun()
    with ctrl_col3:
        if st.button("Next", key="anim_next") and current_step < len(anim_steps) - 1:
            st.session_state["anim_step"] += 1
            st.rerun()
    with ctrl_col4:
        if st.button("End", key="anim_end"):
            st.session_state["anim_step"] = len(anim_steps) - 1
            st.rerun()

    # max(..., 1) avoids division by zero when there is a single step.
    st.progress(current_step / max(len(anim_steps) - 1, 1), text=f"Step {current_step + 1} / {len(anim_steps)}")

    # Case 1/2/3 correspond to the algorithm's branch taken at this step;
    # any other value is an informational message.
    case_type = state.get("case_type")
    if case_type in (1, 2, 3):
        label = "Case 1" if case_type == 1 else "Case 2" if case_type == 2 else "Case 3"
        st.write(f"**{label}** — {state['message']}")
    else:
        st.info(state["message"])

    viz_col1, viz_col2, viz_col3 = st.columns([1.3, 1.2, 1.5])

    with viz_col1:
        st.markdown("##### Experts (sorted by load)")
        exp_cols = st.columns(2)

        # Render each expert card: dimmed once assigned, highlighted border
        # for the expert currently being placed.
        for idx in range(anim_total_experts):
            if idx >= len(state["sorted_loads"]):
                continue
            load = int(state["sorted_loads"][idx])
            original_idx = int(state["sorted_indices"][idx])
            is_processed = idx in state.get("assignments", {})
            is_current = idx == int(state["current_expert_idx"])

            color = EXPERT_COLORS[original_idx % len(EXPERT_COLORS)]
            opacity = "0.4" if is_processed else "1"
            border = "3px solid #facc15" if is_current else "1px solid #4b5563"

            with exp_cols[idx % 2]:
                st.markdown(
                    f"""<div style="background-color: {color}22; border: {border}; border-radius: 6px;
                    padding: 6px; margin: 2px 0; opacity: {opacity};">
                    <span style="color: #9ca3af; font-size: 10px;">E{original_idx} -> GPU{original_idx // anim_local_experts}</span>
                    <span style="color: {color}; font-size: 16px; font-weight: bold; float: right;">{load}</span>
                    </div>""",
                    unsafe_allow_html=True
                )

    with viz_col2:
        st.markdown("##### GPU Topology")
        st.plotly_chart(create_gpu_topology_chart(state, anim_num_gpus), use_container_width=True, key="anim_topology")
        st.caption("Helpers exclude the native GPU. Overflow is possible via force-assign in LLAS.")

    with viz_col3:
        st.markdown("##### GPU Loads")
        st.plotly_chart(create_load_bars_chart(state, anim_num_gpus), use_container_width=True, key="anim_loads")

        st.markdown("##### Assignment Map")
        st.caption("Format: (GPU, start, end)")
        if state.get("assignments"):
            # One row per assigned expert; "Spilled?" flags any chunk that
            # landed off the expert's native GPU.
            rows = []
            for idx, assigns in state["assignments"].items():
                original_idx = int(state["sorted_indices"][idx])
                native_gpu = original_idx // anim_local_experts
                has_spill = any(int(a["gpu"]) != int(native_gpu) for a in assigns)

                assign_str = " ".join([f"(G{int(a['gpu'])},{int(a['start'])},{int(a['end'])})" for a in assigns])

                rows.append({
                    "Expert": f"E{original_idx}",
                    "Load": int(state["sorted_loads"][idx]),
                    "Assignments": assign_str,
                    "Spilled?": "Yes" if has_spill else "No",
                })

            df = pd.DataFrame(rows)
            st.dataframe(df, use_container_width=True, hide_index=True, height=220)
        else:
            st.caption("No assignments yet")