# nirmalendu
# no head selected fix
# e2e40c6
import gradio as gr
# --- Local image paths ---
CAT_WITHOUT_GLASSES_PATH = "cat_without_glasses.png"
CAT_WITH_GLASSES_PATH = "cat_with_glasses.png"
UNICORN_WITH_HORN_PATH = "unicorn_with_horn.png"
UNICORN_NO_HORN_PATH = "unicorn_without_horn.png"

# --- Model / head info ---
NUM_HEADS = 8  # SDXL UNet attention heads per layer

# (UNet block, number of attention modules listed for that block), in the
# order the layers appear in the dropdown.
_ATTENTION_SITES = (
    ("down_blocks.0", 2),
    ("down_blocks.1", 2),
    ("down_blocks.2", 2),
    ("mid_block", 1),
    ("up_blocks.1", 3),
    ("up_blocks.2", 3),
    ("up_blocks.3", 3),
)

# Dotted names of the selectable `attn2` modules, one per attention index of
# every block in `_ATTENTION_SITES`.
LAYER_CHOICES = [
    f"unet.{block}.attentions.{attn_index}.transformer_blocks.0.attn2"
    for block, attn_count in _ATTENTION_SITES
    for attn_index in range(attn_count)
]

# Head labels offered in the UI: "head_0" ... "head_<NUM_HEADS - 1>".
HEAD_CHOICES = [f"head_{head_index}" for head_index in range(NUM_HEADS)]
# --- Callbacks ---
def steer_spectacles(strength: int):
    """Return the (left, right) cat image paths for a steering strength.

    Simple placeholder: no model is run, the function only picks image files.

    Args:
        strength: Steering strength. A whole-valued float such as ``5.0`` is
            treated as the corresponding integer.

    Returns:
        A 2-tuple of image paths. The first is always
        ``CAT_WITHOUT_GLASSES_PATH``. The second is the same path when
        ``strength`` is 0, otherwise the per-strength image under
        ``./cat_steering/``.
    """
    print(f"Steering strength was {strength}")
    if strength == 0:
        print(f"Returning {CAT_WITHOUT_GLASSES_PATH} twice")
        return CAT_WITHOUT_GLASSES_PATH, CAT_WITHOUT_GLASSES_PATH

    # A float like 5.0 would otherwise be formatted as "Cat_step_5.0...".
    if isinstance(strength, float) and strength.is_integer():
        strength = int(strength)

    # Use a distinct local name so the module-level CAT_WITH_GLASSES_PATH
    # constant is not shadowed.
    # NOTE(review): the doubled ".png.png" suffix is kept exactly as it was;
    # presumably it matches the files on disk - confirm.
    steered_path = f"./cat_steering/Cat_step_{strength}.png.png"
    print(f"Returning {steered_path}")
    return CAT_WITHOUT_GLASSES_PATH, steered_path
def run_unicorn_ablation(selected_layer, selected_heads):
    """Return the image path for ablating the chosen heads of one layer.

    No model is run: the function only builds the path of a pre-rendered
    image. In a real experiment the selected heads would be ablated instead.

    Args:
        selected_layer: Dotted layer name; its dots become underscores in the
            file name.
        selected_heads: Head labels such as ``["head_3", "head_1"]``, or
            ``None``. Only the first three are used. The caller's list is
            left unmodified.

    Returns:
        ``UNICORN_WITH_HORN_PATH`` when no head is selected. Otherwise a
        single path under ``unicorn_steering/`` in the ``single_heads``,
        ``head_pairs`` or ``head_triples`` subdirectory, depending on how
        many heads are used, with the head indices in ascending order.
    """
    if not selected_heads:
        return UNICORN_WITH_HORN_PATH

    # Enforce the max of three first, then order; sorted() works on a copy so
    # the argument is never mutated.
    suffixes = sorted(
        (label.replace("head_", "") for label in selected_heads[:3]),
        key=_head_suffix_sort_key,
    )

    subdir = {1: "single_heads", 2: "head_pairs", 3: "head_triples"}[len(suffixes)]
    head_part = "".join(f"_h{suffix}" for suffix in suffixes)
    raw_name = f"unicorn_steering/{subdir}/{selected_layer}{head_part}"
    ablated_path = raw_name.replace(".", "_") + ".png"

    print(f"Unicorn no horn path was {ablated_path}")
    return ablated_path


def _head_suffix_sort_key(suffix):
    """Sort key placing numeric suffixes first, by value ("2" before "10")."""
    if suffix.isdecimal():
        return (0, int(suffix), suffix)
    # Non-numeric suffixes cannot be compared by value; keep them last,
    # alphabetically, rather than raising.
    return (1, 0, suffix)
# Global CSS, including 😼 slider thumb. The markup lives in module-level
# constants so the layout code below stays readable.
_GLOBAL_STYLE_HTML = """
<style>
body {
background: #f5f6ff;
color: #222;
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
.section-title {
font-size: 1.6rem;
font-weight: 700;
margin: 0.3em 0 0.1em 0;
color: #222;
}
.section-subtitle {
font-size: 0.95rem;
color: #444;
margin-bottom: 0.8em;
}
/* Image "cards" */
.img-card {
padding: 10px;
border-radius: 18px;
border: 3px solid #ffcc66;
background: #fffaf0;
box-shadow: 0 4px 10px rgba(0,0,0,0.12);
}
.img-card img {
border-radius: 12px;
border: 2px dashed #ff9f43;
}
/* Cat slider thumb styling with 😼 emoji */
#cat_steer_slider input[type="range"] {
height: 28px;
}
#cat_steer_slider input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 40px;
height: 40px;
border-radius: 50%;
border: none;
background: transparent;
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='40' height='40'><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-size='26'>😼</text></svg>");
background-size: contain;
background-repeat: no-repeat;
background-position: center;
cursor: pointer;
}
#cat_steer_slider input[type="range"]::-moz-range-thumb {
width: 40px;
height: 40px;
border-radius: 50%;
border: none;
background: transparent;
background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='40' height='40'><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-size='26'>😼</text></svg>");
background-size: contain;
background-repeat: no-repeat;
background-position: center;
cursor: pointer;
}
#cat_steer_slider input[type="range"]::-webkit-slider-runnable-track {
height: 6px;
border-radius: 3px;
background: #ddd;
}
#cat_steer_slider input[type="range"]::-moz-range-track {
height: 6px;
border-radius: 3px;
background: #ddd;
}
</style>
"""

# Heading and blurb for the cat steering section.
_CAT_SECTION_HTML = """
<div>
<h1 class="section-title">Cat Steering Console 😼</h1>
<p class="section-subtitle">
We steer a normal cat image using <strong>contrastive activation addition (CAA)</strong>:
nudging hidden activations along a learned “wear spectacles” direction while keeping
other visual features as stable as possible.
</p>
</div>
"""

# Heading and blurb for the unicorn head ablation section.
_UNICORN_SECTION_HTML = """
<div>
<h2 class="section-title">Unicorn Head Ablation (SDXL)</h2>
<p class="section-subtitle">
Choose up to <strong>three</strong> attention heads (out of 64) to ablate.
In a real SDXL experiment, those heads would be zeroed; here we show a unicorn
with its horn intact alongside an example where the horn is removed by ablation.
</p>
</div>
"""

# NOTE(review): which components sit inside each Group/Row is inferred from
# the statement order - confirm against the intended layout.
with gr.Blocks() as demo:
    gr.HTML(_GLOBAL_STYLE_HTML)

    # -------- 1. Cat Steering (CAA) --------
    gr.HTML(_CAT_SECTION_HTML)
    with gr.Group():
        with gr.Row():
            cat_left = gr.Image(
                value=CAT_WITHOUT_GLASSES_PATH,
                label="Original cat (no glasses)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            cat_right = gr.Image(
                value=CAT_WITH_GLASSES_PATH,
                label="Steered cat",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        steer_slider = gr.Slider(
            minimum=0,
            maximum=35,
            value=35,
            step=5,
            label="Steer 😼 (CAA strength towards glasses)",
            info="Connect this to your actual CAA steering pipeline.",
            elem_id="cat_steer_slider",  # targeted by the slider CSS rules
        )
        # Moving the slider refreshes both cat images.
        steer_slider.input(
            fn=steer_spectacles,
            inputs=steer_slider,
            outputs=[cat_left, cat_right],
        )

    # -------- Transition text --------
    gr.Markdown(
        "### From steering to ablation\n"
        "Below, we move from **additive CAA steering** to **structured ablation**. "
        "Instead of pushing along a spectacles direction, we toggle specific SDXL attention "
        "heads on/off and interpret the unicorn without a horn as an ablation outcome."
    )

    # -------- 2. Unicorn Head Ablation --------
    gr.HTML(_UNICORN_SECTION_HTML)
    with gr.Group():
        with gr.Row():
            unicorn_original = gr.Image(
                value=UNICORN_WITH_HORN_PATH,
                label="Original unicorn (all heads active)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            unicorn_ablated = gr.Image(
                value=UNICORN_NO_HORN_PATH,
                label="Ablated unicorn",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        layer_selector = gr.Dropdown(
            choices=LAYER_CHOICES,
            multiselect=False,
            value="unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2",
            label="Attention layer to ablate",
            info="Select an attention layer to ablate.",
        )
        head_selector = gr.Dropdown(
            choices=HEAD_CHOICES,
            multiselect=True,
            max_choices=3,  # enforce the advertised "max 3" in the UI as well
            value=["head_0", "head_1"],
            label="Attention heads to ablate (max 3)",
            info="Select up to three head indices (0-7). In this demo, images are fixed placeholders.",
        )
        # Both dropdowns are inputs to the callback, so a change to either
        # one must refresh the ablated image; otherwise picking a different
        # layer would leave a stale picture on screen.
        for selector in (layer_selector, head_selector):
            selector.input(
                fn=run_unicorn_ablation,
                inputs=[layer_selector, head_selector],
                outputs=[unicorn_ablated],
            )

demo.launch()