# Spaces:
# Build error
# Build error
import gradio as gr
# --- Local image paths ---
# Default images shown when the page first loads (paths relative to the
# working directory the app is started from).
CAT_WITHOUT_GLASSES_PATH = "cat_without_glasses.png"
CAT_WITH_GLASSES_PATH = "cat_with_glasses.png"
UNICORN_WITH_HORN_PATH = "unicorn_with_horn.png"
UNICORN_NO_HORN_PATH = "unicorn_without_horn.png"

# --- Model / head info ---
NUM_HEADS = 8  # SDXL UNet attention heads per layer

# Cross-attention ("attn2") module names offered in the layer dropdown.
# The selected name is also embedded in pre-rendered image file names, so
# the spelling here must not change.
LAYER_CHOICES = [
    "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn2",
    "unet.down_blocks.1.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.1.attentions.1.transformer_blocks.0.attn2",
    "unet.down_blocks.2.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.2.attentions.1.transformer_blocks.0.attn2",
    "unet.mid_block.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.2.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.2.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.2.transformer_blocks.0.attn2",
]

# Labels offered in the head dropdown: "head_0" ... "head_<NUM_HEADS - 1>".
HEAD_CHOICES = [f"head_{i}" for i in range(NUM_HEADS)]

# --- Callbacks ---
def steer_spectacles(strength: int):
    """Pick the (left, right) cat images for a steering-slider position.

    No model is run; the function only selects pre-rendered image files.

    Args:
        strength: Slider value. It is truncated to a whole number before
            use so that a float such as ``5.0`` maps to ``Cat_step_5``.

    Returns:
        A 2-tuple of image paths ``(original, steered)``:
        - strength 0: the un-steered cat on both sides.
        - otherwise: the un-steered cat on the left and
          ``./cat_steering/Cat_step_<strength>.png.png`` on the right.
    """
    print(f"Steering strength was {strength}")

    # Sliders may deliver the value as a float; the file names use whole
    # numbers, so normalise before comparing or formatting.
    step = int(strength)

    if step == 0:
        print(f"Returning {CAT_WITHOUT_GLASSES_PATH} twice")
        return CAT_WITHOUT_GLASSES_PATH, CAT_WITHOUT_GLASSES_PATH

    # Use a local name here rather than rebinding CAT_WITH_GLASSES_PATH, so
    # the module-level default is not shadowed inside this function.
    # NOTE(review): the doubled ".png.png" suffix is kept exactly as it was;
    # presumably it matches the files on disk -- confirm.
    steered_path = f"./cat_steering/Cat_step_{step}.png.png"
    print(f"Returning {steered_path}")
    return CAT_WITHOUT_GLASSES_PATH, steered_path


def run_unicorn_ablation(selected_layer, selected_heads):
    """Return the image path to show for the chosen layer/head ablation.

    No model is run; the function builds the path of a pre-rendered image
    from the selection.

    Args:
        selected_layer: Layer name from the layer dropdown, e.g.
            ``"unet.mid_block.attentions.0.transformer_blocks.0.attn2"``.
        selected_heads: Head labels such as ``["head_3", "head_5"]``, or
            ``None``. Only the first three entries are used. The list is
            not modified.

    Returns:
        ``UNICORN_WITH_HORN_PATH`` when no layer or no head is selected.
        Otherwise a path of the form
        ``unicorn_steering/<single_heads|head_pairs|head_triples>/<layer>_h<i>[_h<j>[_h<k>]].png``
        where every ``.`` of the layer name is replaced by ``_`` and the
        head labels are in sorted order.
    """
    # Nothing to ablate (or no layer to look up): show the original image.
    if not selected_heads or not selected_layer:
        return UNICORN_WITH_HORN_PATH

    # Enforce the max of three, then sort a *copy* so the caller's list
    # (the dropdown's value) is left untouched.
    chosen = sorted(selected_heads[:3])

    # Directory is determined by how many heads are ablated (1, 2 or 3).
    subdirs = ("single_heads", "head_pairs", "head_triples")
    head_suffix = "".join(f"_h{label.replace('head_', '')}" for label in chosen)
    stem = f"unicorn_steering/{subdirs[len(chosen) - 1]}/{selected_layer}{head_suffix}"

    # File names use "_" where the layer name has "."; the extension is
    # appended afterwards so its dot survives.
    ablated_path = stem.replace(".", "_") + ".png"
    print(f"Unicorn no horn path was {ablated_path}")
    return ablated_path


# UI layout: a cat "steering" section (a slider swaps pre-rendered images)
# followed by a unicorn head-ablation section (two dropdowns select a
# pre-rendered image).
with gr.Blocks() as demo:
    # Global CSS, including 😼 slider thumb
    gr.HTML("""
    <style>
    body {
        background: #f5f6ff;
        color: #222;
        font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
    }
    .section-title {
        font-size: 1.6rem;
        font-weight: 700;
        margin: 0.3em 0 0.1em 0;
        color: #222;
    }
    .section-subtitle {
        font-size: 0.95rem;
        color: #444;
        margin-bottom: 0.8em;
    }
    /* Image "cards" */
    .img-card {
        padding: 10px;
        border-radius: 18px;
        border: 3px solid #ffcc66;
        background: #fffaf0;
        box-shadow: 0 4px 10px rgba(0,0,0,0.12);
    }
    .img-card img {
        border-radius: 12px;
        border: 2px dashed #ff9f43;
    }
    /* Cat slider thumb styling with 😼 emoji */
    #cat_steer_slider input[type="range"] {
        height: 28px;
    }
    #cat_steer_slider input[type="range"]::-webkit-slider-thumb {
        -webkit-appearance: none;
        appearance: none;
        width: 40px;
        height: 40px;
        border-radius: 50%;
        border: none;
        background: transparent;
        background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='40' height='40'><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-size='26'>😼</text></svg>");
        background-size: contain;
        background-repeat: no-repeat;
        background-position: center;
        cursor: pointer;
    }
    #cat_steer_slider input[type="range"]::-moz-range-thumb {
        width: 40px;
        height: 40px;
        border-radius: 50%;
        border: none;
        background: transparent;
        background-image: url("data:image/svg+xml;utf8,<svg xmlns='http://www.w3.org/2000/svg' width='40' height='40'><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-size='26'>😼</text></svg>");
        background-size: contain;
        background-repeat: no-repeat;
        background-position: center;
        cursor: pointer;
    }
    #cat_steer_slider input[type="range"]::-webkit-slider-runnable-track {
        height: 6px;
        border-radius: 3px;
        background: #ddd;
    }
    #cat_steer_slider input[type="range"]::-moz-range-track {
        height: 6px;
        border-radius: 3px;
        background: #ddd;
    }
    </style>
    """)

    # -------- 1. Cat Steering (CAA) --------
    gr.HTML("""
    <div>
        <h1 class="section-title">Cat Steering Console 😼</h1>
        <p class="section-subtitle">
            We steer a normal cat image using <strong>contrastive activation addition (CAA)</strong>:
            nudging hidden activations along a learned “wear spectacles” direction while keeping
            other visual features as stable as possible.
        </p>
    </div>
    """)
    with gr.Group():
        with gr.Row():
            cat_left = gr.Image(
                value=CAT_WITHOUT_GLASSES_PATH,
                label="Original cat (no glasses)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            cat_right = gr.Image(
                value=CAT_WITH_GLASSES_PATH,
                label="Steered cat",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        # elem_id ties this slider to the "#cat_steer_slider" CSS rules above.
        steer_slider = gr.Slider(
            minimum=0,
            maximum=35,
            value=35,
            step=5,
            label="Steer 😼 (CAA strength towards glasses)",
            info="Connect this to your actual CAA steering pipeline.",
            elem_id="cat_steer_slider",
        )
        # Moving the slider replaces both cat images.
        steer_slider.input(
            fn=steer_spectacles,
            inputs=steer_slider,
            outputs=[cat_left, cat_right],
        )

    # -------- Transition text --------
    gr.Markdown(
        "### From steering to ablation\n"
        "Below, we move from **additive CAA steering** to **structured ablation**. "
        "Instead of pushing along a spectacles direction, we toggle specific SDXL attention "
        "heads on/off and interpret the unicorn without a horn as an ablation outcome."
    )

    # -------- 2. Unicorn Head Ablation --------
    gr.HTML("""
    <div>
        <h2 class="section-title">Unicorn Head Ablation (SDXL)</h2>
        <p class="section-subtitle">
            Choose up to <strong>three</strong> attention heads (out of 64) to ablate.
            In a real SDXL experiment, those heads would be zeroed; here we show a unicorn
            with its horn intact alongside an example where the horn is removed by ablation.
        </p>
    </div>
    """)
    with gr.Group():
        with gr.Row():
            # Never used as an event output, so it always shows the original.
            unicorn_original = gr.Image(
                value=UNICORN_WITH_HORN_PATH,
                label="Original unicorn (all heads active)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            unicorn_ablated = gr.Image(
                value=UNICORN_NO_HORN_PATH,
                label="Ablated unicorn",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        layer_selector = gr.Dropdown(
            choices=LAYER_CHOICES,
            multiselect=False,
            value="unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2",
            label="Attention layer to ablate",
            info="Select an attention layer to ablate.",
        )
        head_selector = gr.Dropdown(
            choices=HEAD_CHOICES,
            multiselect=True,
            value=["head_0", "head_1"],
            label="Attention heads to ablate (max 3)",
            info="Select up to three head indices (0-7). In this demo, images are fixed placeholders.",
        )
        # Refresh the ablated image when EITHER dropdown is edited. The image
        # path depends on both the layer and the heads, so listening to the
        # head dropdown alone would leave a stale image after a layer change.
        for selector in (layer_selector, head_selector):
            selector.input(
                fn=run_unicorn_ablation,
                inputs=[layer_selector, head_selector],
                outputs=[unicorn_ablated],
            )

demo.launch()