added expert ablation
- app.py +61 -20
- router_backend.py +4 -4
app.py
CHANGED
```diff
@@ -81,7 +81,16 @@ def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str:
         return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
     return user_prompt
 
-def route_and_plot(
+def route_and_plot(
+    model_choice: str,
+    hf_token: str,
+    user_prompt: str,
+    assistant_prompt: str,
+    ablate_language: bool,
+    ablate_logic: bool,
+    ablate_social: bool,
+    ablate_world: bool,
+) -> Tuple[pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
     """
     Main pipeline:
     - Compose prompt (user + optional assistant)
@@ -91,29 +100,45 @@ def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant
     if hf_token.strip() == "":
         hf_token = None  # allow empty token
 
-    model_id = model_choice.strip()
-    if not model_id:
-        raise gr.Error("Please select a model or enter a custom model id.")
-    prompt = _compose_prompt(user_prompt, assistant_prompt)
-    if not prompt:
-        raise gr.Error("Please enter a prompt.")
-
+    ablations = []
+    if ablate_language:
+        ablations.append("language")
+    if ablate_logic:
+        ablations.append("logic")
+    if ablate_social:
+        ablations.append("social")
+    if ablate_world:
+        ablations.append("world")
+
     seed = 42
     use_mock = False
-    if use_mock:
-        msg = "Using mock data."
-        vals = _mock_routing(model_id, prompt, seed=seed)
+
+    if len(ablations) == 4:
+        msg = "Error message: you can't ablate all experts.<br>Falling back to mock data."
         generation = None
+        vals = _mock_routing(model_id, prompt, seed=seed)
     else:
-        try:
-            raw, generation = get_expert_routing(model_id, hf_token, prompt)  # <-- your real function
-            vals = _normalize_output(raw)
-            msg = "Routed with real backend."
-        except Exception as e:
-            # fallback to mock on error, but surface message
-            msg = f"Backend error: {e}\nFalling back to mock data."
+        model_id = model_choice.strip()
+        if not model_id:
+            raise gr.Error("Please select a model or enter a custom model id.")
+        prompt = _compose_prompt(user_prompt, assistant_prompt)
+        if not prompt:
+            raise gr.Error("Please enter a prompt.")
+
+        if use_mock:
+            msg = "Using mock data."
             vals = _mock_routing(model_id, prompt, seed=seed)
             generation = None
+        else:
+            try:
+                raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
+                vals = _normalize_output(raw)
+                msg = "Routed with real backend."
+            except Exception as e:
+                # fallback to mock on error, but surface message
+                msg = f"Backend error: {e}\nFalling back to mock data."
+                vals = _mock_routing(model_id, prompt, seed=seed)
+                generation = None
 
     df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
     colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
@@ -143,6 +168,20 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     with gr.Row():
         model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
         hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
+
+    with gr.Column():
+        with gr.Row():
+            gr.Markdown(
+                """
+                #### Ablate Experts
+                (Check to disable an expert; the routing percentages will be redistributed among the remaining experts)
+                """, label="Ablate Experts"
+            )
+        with gr.Row():
+            ablate_language = gr.Checkbox(value=False, label="Language Expert")
+            ablate_logic = gr.Checkbox(value=False, label="Logic Expert")
+            ablate_social = gr.Checkbox(value=False, label="Social Expert")
+            ablate_world = gr.Checkbox(value=False, label="World Expert")
 
     with gr.Row():
         user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
@@ -159,11 +198,13 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     with gr.Row():
         table = gr.Dataframe(label="Routing Percentages", interactive=False)
         plot = gr.Plot(label="Bar Plot")
-
+
+
+    status = gr.Markdown("", label="System Message")
 
     run.click(
         route_and_plot,
-        inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
+        inputs=[model_choice, hf_token, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
        outputs=[generation_output, table, plot, status],
    )
 
```
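The new checkboxes feed `route_and_plot` a plain list of expert names, and the panel text promises that an ablated expert's share is "redistributed among the remaining experts." That arithmetic happens downstream of `get_expert_routing` and isn't shown in this diff (note also that the all-ablated fallback reads `model_id` and `prompt` before the `else:` branch defines them), but the intended effect is a mask-and-renormalize step. A minimal sketch, with a hypothetical `redistribute` helper that is not part of this commit:

```python
# Hypothetical illustration of the redistribution described by the UI text;
# `redistribute` does not exist in this repo.
EXPERTS = ["Language", "Logic", "Social", "World"]  # assumed display order

def redistribute(percents, ablations):
    """Zero out ablated experts and renormalize the rest to sum to 100."""
    kept = [0.0 if name.lower() in ablations else p
            for name, p in zip(EXPERTS, percents)]
    total = sum(kept)
    if total == 0:
        # the situation app.py guards against with `len(ablations) == 4`
        raise ValueError("you can't ablate all experts")
    return [100.0 * p / total for p in kept]

print(redistribute([40.0, 30.0, 20.0, 10.0], ["logic"]))
# -> [57.14..., 0.0, 28.57..., 14.28...]
```

Renormalizing by the surviving total keeps the bar plot on the same 0-100 scale no matter how many experts are disabled.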
router_backend.py
CHANGED
```diff
@@ -26,9 +26,9 @@ from models.micro_moe_llama import MiCRoLlamaMoE
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]]) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]:
+def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]], ablations: List[str] = None) -> Union[List[float], Dict[str, float], Tuple[float, float, float, float]]:
 
-    model, tokenizer = build_model(model_id, hf_token)
+    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)
 
     if isinstance(prompt, str):
         generation, routing_weights = generate_continuation(model, tokenizer, prompt)
@@ -189,7 +189,7 @@ def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
 
     return routing_weights
 
-def build_model(model_id: str, hf_token: str, use_cache: bool = True):
+def build_model(model_id: str, hf_token: str, ablations: List[str], use_cache: bool = True):
 
     model_path, base_model, model_class = get_model_path(model_id)
 
@@ -203,7 +203,7 @@ def build_model(model_id: str, hf_token: str, use_cache: bool = True):
     model_config.use_bfloat16 = True
     model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
     model_config.use_cache = use_cache
-    model_config.ablate =
+    model_config.ablate = ablations
 
     tokenizer = AutoTokenizer.from_pretrained(base_model, use_auth_token=hf_token)
     tokenizer.padding_side = "left"
```
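On the backend, `build_model` now stashes the ablation list on the model config, but how `MiCRoLlamaMoE` consumes `model_config.ablate` is outside this diff. A plausible mechanism, sketched here purely as an assumption, is to set the ablated experts' router logits to negative infinity before the softmax, so the surviving weights renormalize automatically:

```python
import torch

# Assumption only: the real MiCRoLlamaMoE router is not shown in this diff.
EXPERT_ORDER = ["language", "logic", "social", "world"]  # hypothetical index order

def routing_weights(router_logits: torch.Tensor, ablate=None) -> torch.Tensor:
    """router_logits: (..., num_experts) -> weights summing to 1 over the
    non-ablated experts; ablated experts get exactly 0."""
    if ablate:
        mask = torch.zeros(router_logits.shape[-1], dtype=torch.bool)
        for name in ablate:
            mask[EXPERT_ORDER.index(name)] = True
        router_logits = router_logits.masked_fill(mask, float("-inf"))
    return torch.softmax(router_logits, dim=-1)

w = routing_weights(torch.randn(2, 4), ablate=["logic"])
print(w)  # column 1 is 0.0; every row still sums to 1
```

Masking logits rather than zeroing post-softmax weights guarantees the routing distribution still sums to 1 without a second normalization pass.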
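For completeness, the updated entry point as `app.py` now calls it; the model id, token, and prompt below are placeholders:

```python
from router_backend import get_expert_routing

# Placeholder arguments; this mirrors the call inside route_and_plot.
raw, generation = get_expert_routing(
    "your-org/your-micro-moe-model",  # hypothetical model id
    None,                             # an empty token box becomes None in app.py
    "Explain why the sky is blue.",   # plain-string prompt path
    ["logic", "world"],               # experts to disable
)
print(raw)         # raw routing output, normalized later by _normalize_output
print(generation)  # model continuation for string prompts
```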