nxphi47 committed on
Commit
b17633a
·
verified ·
1 Parent(s): b5f5724

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +113 -110
src/streamlit_app.py CHANGED
@@ -933,9 +933,6 @@ with tab_stats:
933
  # TAB 2: STEP-BY-STEP ANIMATION
934
  # ============================================================================
935
  with tab_anim:
936
- st.subheader("Step-by-Step Algorithm Animation")
937
- st.caption("This animation follows LLA + LLAS with α capacity and min-tokens-per-GEMM (m) skip/force-assign behavior.")
938
-
939
  anim_num_gpus = 4
940
  anim_local_experts = 2
941
  anim_total_experts = anim_num_gpus * anim_local_experts
@@ -970,33 +967,36 @@ with tab_anim:
970
  st.session_state[f"anim_load_{idx}"] = int(v)
971
  st.session_state["anim_step"] = 0
972
 
973
- with st.expander("Animation Configuration", expanded=True):
974
- left, right = st.columns([1, 1], gap="large")
975
 
976
- with left:
977
- preset = st.selectbox("Preset", list(PRESETS.keys()), key="anim_preset")
978
- st.button("Apply Preset", key="anim_apply_preset", on_click=apply_preset_callback)
979
 
980
- with right:
981
- st.slider(
982
- "α (capacity factor)",
983
- 0.5, 1.5,
984
- step=0.05,
985
- key="anim_alpha"
986
- )
987
- st.slider(
988
- "m (min tokens per GEMM)",
989
- 1, 512,
990
- step=1,
991
- key="anim_min_gemm",
992
- help="LLAS rule: if candidate chunk c < m and remaining r > c, skip that GPU; else may force-assign."
993
- )
 
 
 
994
 
995
- st.markdown("**Expert Loads (native placement shown as E{i} -> GPU{i//2})**")
996
- load_cols = st.columns(anim_num_gpus)
 
997
  for gpu_idx in range(anim_num_gpus):
998
- with load_cols[gpu_idx]:
999
- st.caption(f"GPU {gpu_idx}")
1000
  for local_idx in range(anim_local_experts):
1001
  idx = gpu_idx * anim_local_experts + local_idx
1002
  st.number_input(
@@ -1015,7 +1015,7 @@ with tab_anim:
1015
  total_now = sum(loads_now)
1016
  m_alpha_now = alpha_now * (total_now / anim_num_gpus) if anim_num_gpus > 0 else float(total_now)
1017
 
1018
- st.info(f"Current: α={alpha_now:.2f}, m={m_now}, Total={total_now}, m_alpha={m_alpha_now:.2f}")
1019
 
1020
  if st.button("Reset Animation Step", key="anim_reset_step"):
1021
  st.session_state["anim_step"] = 0
@@ -1035,90 +1035,93 @@ with tab_anim:
1035
  st.session_state["anim_step"] = current_step
1036
  state = anim_steps[current_step]
1037
 
1038
- # Controls
1039
- ctrl_col1, ctrl_col2, ctrl_col3, ctrl_col4, ctrl_col5 = st.columns([1, 1, 1, 1, 4])
1040
- with ctrl_col1:
1041
- if st.button("Reset", key="anim_reset"):
1042
- st.session_state["anim_step"] = 0
1043
- st.rerun()
1044
- with ctrl_col2:
1045
- if st.button("Prev", key="anim_prev") and current_step > 0:
1046
- st.session_state["anim_step"] -= 1
1047
- st.rerun()
1048
- with ctrl_col3:
1049
- if st.button("Next", key="anim_next") and current_step < len(anim_steps) - 1:
1050
- st.session_state["anim_step"] += 1
1051
- st.rerun()
1052
- with ctrl_col4:
1053
- if st.button("End", key="anim_end"):
1054
- st.session_state["anim_step"] = len(anim_steps) - 1
1055
- st.rerun()
1056
-
1057
- st.progress(current_step / max(len(anim_steps) - 1, 1), text=f"Step {current_step + 1} / {len(anim_steps)}")
 
 
 
 
 
 
 
 
 
 
1058
 
1059
- case_type = state.get("case_type")
1060
- if case_type in (1, 2, 3):
1061
- label = "Case 1" if case_type == 1 else "Case 2" if case_type == 2 else "Case 3"
1062
- st.write(f"**{label}** — {state['message']}")
1063
- else:
1064
- st.info(state["message"])
1065
-
1066
- viz_col1, viz_col2, viz_col3 = st.columns([1.3, 1.2, 1.5])
1067
-
1068
- with viz_col1:
1069
- st.markdown("##### Experts (sorted by load)")
1070
- exp_cols = st.columns(2)
1071
-
1072
- for idx in range(anim_total_experts):
1073
- if idx >= len(state["sorted_loads"]):
1074
- continue
1075
- load = int(state["sorted_loads"][idx])
1076
- original_idx = int(state["sorted_indices"][idx])
1077
- is_processed = idx in state.get("assignments", {})
1078
- is_current = idx == int(state["current_expert_idx"])
1079
-
1080
- color = EXPERT_COLORS[original_idx % len(EXPERT_COLORS)]
1081
- opacity = "0.4" if is_processed else "1"
1082
- border = "3px solid #facc15" if is_current else "1px solid #4b5563"
1083
-
1084
- with exp_cols[idx % 2]:
1085
- st.markdown(
1086
- f"""<div style="background-color: {color}22; border: {border}; border-radius: 6px;
1087
- padding: 6px; margin: 2px 0; opacity: {opacity};">
1088
- <span style="color: #9ca3af; font-size: 10px;">E{original_idx} -> GPU{original_idx // anim_local_experts}</span>
1089
- <span style="color: {color}; font-size: 16px; font-weight: bold; float: right;">{load}</span>
1090
- </div>""",
1091
- unsafe_allow_html=True
1092
- )
1093
-
1094
- with viz_col2:
1095
- st.markdown("##### GPU Topology")
1096
- st.plotly_chart(create_gpu_topology_chart(state, anim_num_gpus), use_container_width=True, key="anim_topology")
1097
- st.caption("Helpers exclude the native GPU. Overflow is possible via force-assign in LLAS.")
1098
-
1099
- with viz_col3:
1100
- st.markdown("##### GPU Loads")
1101
- st.plotly_chart(create_load_bars_chart(state, anim_num_gpus), use_container_width=True, key="anim_loads")
1102
-
1103
- st.markdown("##### Assignment Map")
1104
- st.caption("Format: (GPU, start, end)")
1105
- if state.get("assignments"):
1106
- rows = []
1107
- for idx, assigns in state["assignments"].items():
1108
- original_idx = int(state["sorted_indices"][idx])
1109
- native_gpu = original_idx // anim_local_experts
1110
- has_spill = any(int(a["gpu"]) != int(native_gpu) for a in assigns)
1111
 
1112
- assign_str = " ".join([f"(G{int(a['gpu'])},{int(a['start'])},{int(a['end'])})" for a in assigns])
 
 
1113
 
1114
- rows.append({
1115
- "Expert": f"E{original_idx}",
1116
- "Load": int(state["sorted_loads"][idx]),
1117
- "Assignments": assign_str,
1118
- "Spilled?": "Yes" if has_spill else "No",
1119
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
 
1121
- df = pd.DataFrame(rows)
1122
- st.dataframe(df, use_container_width=True, hide_index=True, height=220)
1123
- else:
1124
- st.caption("No assignments yet")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
  # TAB 2: STEP-BY-STEP ANIMATION
934
  # ============================================================================
935
  with tab_anim:
 
 
 
936
  anim_num_gpus = 4
937
  anim_local_experts = 2
938
  anim_total_experts = anim_num_gpus * anim_local_experts
 
967
  st.session_state[f"anim_load_{idx}"] = int(v)
968
  st.session_state["anim_step"] = 0
969
 
970
+ cfg_col, out_col = st.columns([0.32, 0.68], gap="large")
 
971
 
972
+ with cfg_col:
973
+ st.subheader("Animation Config")
974
+ st.caption("LLA + LLAS with α capacity and min-tokens-per-GEMM (m).")
975
 
976
+ preset = st.selectbox("Preset", list(PRESETS.keys()), key="anim_preset")
977
+ st.button("Apply Preset", key="anim_apply_preset", on_click=apply_preset_callback)
978
+
979
+ st.markdown("#### Parameters")
980
+ st.slider(
981
+ "α (capacity factor)",
982
+ 0.5, 1.5,
983
+ step=0.05,
984
+ key="anim_alpha"
985
+ )
986
+ st.slider(
987
+ "m (min tokens per GEMM)",
988
+ 1, 512,
989
+ step=1,
990
+ key="anim_min_gemm",
991
+ help="LLAS rule: if candidate chunk c < m and remaining r > c, skip that GPU; else may force-assign."
992
+ )
993
 
994
+ st.markdown("#### Expert Loads")
995
+ st.caption("E{i} → GPU{i//2}")
996
+ load_cols = st.columns(2)
997
  for gpu_idx in range(anim_num_gpus):
998
+ with load_cols[gpu_idx % 2]:
999
+ st.markdown(f"**GPU {gpu_idx}**")
1000
  for local_idx in range(anim_local_experts):
1001
  idx = gpu_idx * anim_local_experts + local_idx
1002
  st.number_input(
 
1015
  total_now = sum(loads_now)
1016
  m_alpha_now = alpha_now * (total_now / anim_num_gpus) if anim_num_gpus > 0 else float(total_now)
1017
 
1018
+ st.info(f"α={alpha_now:.2f}, m={m_now}, Total={total_now}, m_α={m_alpha_now:.2f}")
1019
 
1020
  if st.button("Reset Animation Step", key="anim_reset_step"):
1021
  st.session_state["anim_step"] = 0
 
1035
  st.session_state["anim_step"] = current_step
1036
  state = anim_steps[current_step]
1037
 
1038
+ with out_col:
1039
+ st.subheader("Step-by-Step Animation")
1040
+
1041
+ # Controls
1042
+ ctrl_col1, ctrl_col2, ctrl_col3, ctrl_col4, ctrl_col5 = st.columns([1, 1, 1, 1, 4])
1043
+ with ctrl_col1:
1044
+ if st.button("Reset", key="anim_reset"):
1045
+ st.session_state["anim_step"] = 0
1046
+ st.rerun()
1047
+ with ctrl_col2:
1048
+ if st.button("Prev", key="anim_prev") and current_step > 0:
1049
+ st.session_state["anim_step"] -= 1
1050
+ st.rerun()
1051
+ with ctrl_col3:
1052
+ if st.button("Next", key="anim_next") and current_step < len(anim_steps) - 1:
1053
+ st.session_state["anim_step"] += 1
1054
+ st.rerun()
1055
+ with ctrl_col4:
1056
+ if st.button("End", key="anim_end"):
1057
+ st.session_state["anim_step"] = len(anim_steps) - 1
1058
+ st.rerun()
1059
+
1060
+ st.progress(current_step / max(len(anim_steps) - 1, 1), text=f"Step {current_step + 1} / {len(anim_steps)}")
1061
+
1062
+ case_type = state.get("case_type")
1063
+ if case_type in (1, 2, 3):
1064
+ label = "Case 1" if case_type == 1 else "Case 2" if case_type == 2 else "Case 3"
1065
+ st.write(f"**{label}** — {state['message']}")
1066
+ else:
1067
+ st.info(state["message"])
1068
 
1069
+ viz_col1, viz_col2, viz_col3 = st.columns([1.3, 1.2, 1.5])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1070
 
1071
+ with viz_col1:
1072
+ st.markdown("##### Experts (sorted by load)")
1073
+ exp_cols = st.columns(2)
1074
 
1075
+ for idx in range(anim_total_experts):
1076
+ if idx >= len(state["sorted_loads"]):
1077
+ continue
1078
+ load = int(state["sorted_loads"][idx])
1079
+ original_idx = int(state["sorted_indices"][idx])
1080
+ is_processed = idx in state.get("assignments", {})
1081
+ is_current = idx == int(state["current_expert_idx"])
1082
+
1083
+ color = EXPERT_COLORS[original_idx % len(EXPERT_COLORS)]
1084
+ opacity = "0.4" if is_processed else "1"
1085
+ border = "3px solid #facc15" if is_current else "1px solid #4b5563"
1086
+
1087
+ with exp_cols[idx % 2]:
1088
+ st.markdown(
1089
+ f"""<div style="background-color: {color}22; border: {border}; border-radius: 6px;
1090
+ padding: 6px; margin: 2px 0; opacity: {opacity};">
1091
+ <span style="color: #9ca3af; font-size: 10px;">E{original_idx} -> GPU{original_idx // anim_local_experts}</span>
1092
+ <span style="color: {color}; font-size: 16px; font-weight: bold; float: right;">{load}</span>
1093
+ </div>""",
1094
+ unsafe_allow_html=True
1095
+ )
1096
 
1097
+ with viz_col2:
1098
+ st.markdown("##### GPU Topology")
1099
+ st.plotly_chart(create_gpu_topology_chart(state, anim_num_gpus), use_container_width=True, key="anim_topology")
1100
+ st.caption("Helpers exclude the native GPU. Overflow is possible via force-assign in LLAS.")
1101
+
1102
+ with viz_col3:
1103
+ st.markdown("##### GPU Loads")
1104
+ st.plotly_chart(create_load_bars_chart(state, anim_num_gpus), use_container_width=True, key="anim_loads")
1105
+
1106
+ st.markdown("##### Assignment Map")
1107
+ st.caption("Format: (GPU, start, end)")
1108
+ if state.get("assignments"):
1109
+ rows = []
1110
+ for idx, assigns in state["assignments"].items():
1111
+ original_idx = int(state["sorted_indices"][idx])
1112
+ native_gpu = original_idx // anim_local_experts
1113
+ has_spill = any(int(a["gpu"]) != int(native_gpu) for a in assigns)
1114
+
1115
+ assign_str = " ".join([f"(G{int(a['gpu'])},{int(a['start'])},{int(a['end'])})" for a in assigns])
1116
+
1117
+ rows.append({
1118
+ "Expert": f"E{original_idx}",
1119
+ "Load": int(state["sorted_loads"][idx]),
1120
+ "Assignments": assign_str,
1121
+ "Spilled?": "Yes" if has_spill else "No",
1122
+ })
1123
+
1124
+ df = pd.DataFrame(rows)
1125
+ st.dataframe(df, use_container_width=True, hide_index=True, height=220)
1126
+ else:
1127
+ st.caption("No assignments yet")