bkhmsi committed
Commit c0742fe · 1 Parent(s): 55b6215

added expert ablation

Files changed (2)
  1. app.py +61 -20
  2. router_backend.py +4 -4
app.py CHANGED

```diff
@@ -81,7 +81,16 @@ def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str:
         return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
     return user_prompt
 
-def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
+def route_and_plot(
+    model_choice: str,
+    hf_token: str,
+    user_prompt: str,
+    assistant_prompt: str,
+    ablate_language: bool,
+    ablate_logic: bool,
+    ablate_social: bool,
+    ablate_world: bool,
+) -> Tuple[pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
     """
     Main pipeline:
     - Compose prompt (user + optional assistant)
@@ -91,29 +100,45 @@ def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant
     if hf_token.strip() == "":
         hf_token = None  # allow empty token
 
-    model_id = model_choice.strip()
-    if not model_id:
-        raise gr.Error("Please select a model or enter a custom model id.")
-    prompt = _compose_prompt(user_prompt, assistant_prompt)
-    if not prompt:
-        raise gr.Error("Please enter a prompt.")
-
+    ablations = []
+    if ablate_language:
+        ablations.append("language")
+    if ablate_logic:
+        ablations.append("logic")
+    if ablate_social:
+        ablations.append("social")
+    if ablate_world:
+        ablations.append("world")
+
     seed = 42
     use_mock = False
-    if use_mock:
-        msg = "Using mock data."
-        vals = _mock_routing(model_id, prompt, seed=seed)
+
+    if len(ablations) == 4:
+        msg = "Error message: you can't ablate all experts.<br>Falling back to mock data."
         generation = None
+        vals = _mock_routing(model_id, prompt, seed=seed)
     else:
-        try:
-            raw, generation = get_expert_routing(model_id, hf_token, prompt)  # <-- your real function
-            vals = _normalize_output(raw)
-            msg = "Routed with real backend."
-        except Exception as e:
-            # fallback to mock on error, but surface message
-            msg = f"Backend error: {e}\nFalling back to mock data."
+        model_id = model_choice.strip()
+        if not model_id:
+            raise gr.Error("Please select a model or enter a custom model id.")
+        prompt = _compose_prompt(user_prompt, assistant_prompt)
+        if not prompt:
+            raise gr.Error("Please enter a prompt.")
+
+        if use_mock:
+            msg = "Using mock data."
             vals = _mock_routing(model_id, prompt, seed=seed)
             generation = None
+        else:
+            try:
+                raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
+                vals = _normalize_output(raw)
+                msg = "Routed with real backend."
+            except Exception as e:
+                # fallback to mock on error, but surface message
+                msg = f"Backend error: {e}\nFalling back to mock data."
+                vals = _mock_routing(model_id, prompt, seed=seed)
+                generation = None
 
     df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
     colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
@@ -143,6 +168,20 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     with gr.Row():
         model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
         hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
+
+    with gr.Column():
+        with gr.Row():
+            gr.Markdown(
+                """
+                #### Ablate Experts
+                (Check to disable an expert; the routing percentages will be redistributed among the remaining experts)
+                """, label="Ablate Experts"
+            )
+        with gr.Row():
+            ablate_language = gr.Checkbox(value=False, label="Language Expert")
+            ablate_logic = gr.Checkbox(value=False, label="Logic Expert")
+            ablate_social = gr.Checkbox(value=False, label="Social Expert")
+            ablate_world = gr.Checkbox(value=False, label="World Expert")
 
     with gr.Row():
         user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
@@ -159,11 +198,13 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     with gr.Row():
         table = gr.Dataframe(label="Routing Percentages", interactive=False)
         plot = gr.Plot(label="Bar Plot")
-    status = gr.Markdown("")
+
+
+    status = gr.Markdown("", label="System Message")
 
     run.click(
         route_and_plot,
-        inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
+        inputs=[model_choice, hf_token, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
         outputs=[generation_output, table, plot, status],
     )
 
```
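The new checkbox group's help text says that an ablated expert's share is redistributed among the remaining experts. The redistribution itself happens inside the model (via `model_config.ablate`, see `router_backend.py` below), but the arithmetic the help text describes is a simple renormalization. A minimal sketch of that arithmetic; the `redistribute` helper and the example percentages are illustrative, not part of this commit:

```python
from typing import List
import numpy as np

EXPERTS = ["language", "logic", "social", "world"]

def redistribute(percents: np.ndarray, ablations: List[str]) -> np.ndarray:
    """Zero out ablated experts and renormalize the rest so they sum to 100."""
    keep = np.array([name not in ablations for name in EXPERTS], dtype=float)
    masked = percents * keep
    return 100 * masked / masked.sum()

# Example: ablating the logic expert hands its 25% to the other three,
# in proportion to their original shares.
print(redistribute(np.array([40.0, 25.0, 20.0, 15.0]), ["logic"]))
# ≈ [53.33, 0.0, 26.67, 20.0]
```

This also motivates the new guard in `route_and_plot`: with all four experts ablated there is nothing left to renormalize, so the app surfaces an error message and falls back to mock data instead.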
router_backend.py CHANGED

```diff
@@ -26,9 +26,9 @@ from models.micro_moe_llama import MiCRoLlamaMoE
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]]) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]:
+def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]], ablations: List[str] = None) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]:
 
-    model, tokenizer = build_model(model_id, hf_token)
+    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)
 
     if isinstance(prompt, str):
         generation, routing_weights = generate_continuation(model, tokenizer, prompt)
@@ -189,7 +189,7 @@ def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
 
     return routing_weights
 
-def build_model(model_id: str, hf_token: str, use_cache: bool = True):
+def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: bool = True):
 
     model_path, base_model, model_class = get_model_path(model_id)
 
@@ -203,7 +203,7 @@ def build_model(model_id: str, hf_token: str, use_cache: bool = True):
     model_config.use_bfloat16 = True
     model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
     model_config.use_cache = use_cache
-    model_config.ablate = []
+    model_config.ablate = ablations
 
     tokenizer = AutoTokenizer.from_pretrained(base_model, use_auth_token=hf_token)
     tokenizer.padding_side = "left"
```
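How `model_config.ablate` is consumed inside `MiCRoLlamaMoE` is not part of this diff. A common way to implement expert ablation in a MoE router, and only a guess at what happens here, is to mask the ablated experts' logits before the softmax so the remaining experts absorb the probability mass. The sketch below is an assumption in that spirit; `EXPERT_ORDER` and `route_with_ablation` are hypothetical names, not code from this repository:

```python
from typing import List, Optional
import torch

EXPERT_ORDER = ["language", "logic", "social", "world"]  # assumed expert ordering

def route_with_ablation(router_logits: torch.Tensor, ablate: Optional[List[str]]) -> torch.Tensor:
    """Softmax over expert logits with ablated experts forced to zero weight.

    router_logits: (..., num_experts) raw router scores.
    ablate: names taken from model_config.ablate, e.g. ["logic"].
    """
    logits = router_logits.clone()
    for idx, name in enumerate(EXPERT_ORDER):
        if name in (ablate or []):
            logits[..., idx] = float("-inf")  # excluded from the softmax
    return torch.softmax(logits, dim=-1)  # remaining experts renormalize to sum to 1
```

Under this reading, the app.py guard against ablating all four experts also makes sense: masking every logit would leave the softmax undefined.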