| from functools import partial |
| import json |
| import gradio as gr |
| import os |
|
|
| import torch |
| from PIL import Image |
| from SDLens import HookedStableDiffusionXLPipeline, CachedPipeline as CachedFLuxPipeline |
| from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline |
| from SAE import SparseAutoencoder |
| from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo, add_feature_on_area_flux |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from matplotlib.colors import ListedColormap |
| import threading |
| from einops import rearrange |
| import spaces |
| |
|
|
|
|
# Mapping from the short block codes shown in the UI to the fully qualified
# module paths used when registering hooks / reading the activation cache.
# SDXL: cross-attention blocks of the UNet at different depths.
code_to_block_sd = {
    "down.2.1": "unet.down_blocks.2.attentions.1",
    "mid.0": "unet.mid_block.attentions.0",
    "up.0.1": "unet.up_blocks.0.attentions.1",
    "up.0.0": "unet.up_blocks.0.attentions.0"
}
# FLUX: a single transformer block (block 18) is instrumented.
code_to_block_flux = {"18": "transformer.transformer_blocks.18"}
|
|
| FLUX_NAMES = ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"] |
| MODELS_CONFIG = { |
| "stabilityai/stable-diffusion-xl-base-1.0": { |
| "steps": 25, |
| "guidance_scale": 8.0, |
| "choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"], |
| "value": "down.2.1 (composition)", |
| "code_to_block": code_to_block_sd, |
| "max_steps": 50, |
| "is_flux": False, |
| "downsample_factor": 16, |
| "add_feature_on_area": add_feature_on_area_base, |
| "num_features": 5120, |
|
|
| }, |
| "stabilityai/sdxl-turbo": { |
| "steps": 1, |
| "guidance_scale": 0.0, |
| "choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"], |
| "value": "down.2.1 (composition)", |
| "code_to_block": code_to_block_sd, |
| "max_steps": 4, |
| "is_flux": False, |
| "downsample_factor": 32, |
| "add_feature_on_area": add_feature_on_area_turbo, |
| "num_features": 5120, |
| }, |
| "black-forest-labs/FLUX.1-schnell": { |
| "steps": 1, |
| "guidance_scale": 0.0, |
| "choices": ["18"], |
| "value": "18", |
| "code_to_block": code_to_block_flux, |
| "max_steps": 4, |
| "is_flux": True, |
| "exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102], |
| "downsample_factor": 8, |
| "add_feature_on_area": add_feature_on_area_flux, |
| "num_features": 12288 |
|
|
| }, |
|
|
| "black-forest-labs/FLUX.1-dev": { |
| "steps": 25, |
| "guidance_scale": 0.0, |
| "choices": ["18"], |
| "value": "18", |
| "code_to_block": code_to_block_flux, |
| "max_steps": 50, |
| "is_flux": True, |
| "exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102], |
| "downsample_factor": 8, |
| "add_feature_on_area": add_feature_on_area_flux, |
| "num_features": 12288 |
|
|
| } |
| } |
|
|
|
|
|
|
|
|
# Serializes access to the single shared GPU pipeline across concurrent
# Gradio callbacks (one generation at a time).
lock = threading.Lock()
|
|
|
|
|
|
|
|
|
|
def process_cache(cache, saes_dict, model_config, timestep=None):
    """Encode cached activations through each block's SAE and rank features.

    Args:
        cache: activation cache from a pipeline run. For FLUX this is an
            object exposing ``image_activation``; for SDXL it is a dict with
            ``"input"`` / ``"output"`` tensors keyed by block module path.
        saes_dict: mapping of block code -> SparseAutoencoder.
        model_config: one entry of MODELS_CONFIG.
        timestep: optional single diffusion step to analyze; None keeps all
            steps (features are then averaged across them downstream).

    Returns:
        (top_features_dict, sparse_maps_dict): per block code, the 40 feature
        indices with the highest mean activation (as a Python list), and the
        full sparse feature maps as a numpy array.
    """
    top_features_dict = {}
    sparse_maps_dict = {}

    for code in model_config['code_to_block'].keys():
        block = model_config["code_to_block"][code]
        sae = saes_dict[code]

        if model_config["is_flux"]:
            with torch.no_grad():
                features = sae.encode(torch.stack(cache.image_activation))
                # Zero out known excluded features so they never rank on top.
                features[..., model_config["exclude_list"]] = 0

            # Keep a single step if requested (guard against out-of-range).
            if timestep is not None and timestep < features.shape[0]:
                features = features[timestep:timestep+1]

            # Reshape token sequence to a spatial grid.
            # NOTE(review): assumes a 64x64 latent token grid (1024px image);
            # confirm for other resolutions.
            sparse_maps = rearrange(features, "t b (w h) n -> b t w h n", w=64, h=64).squeeze(0).squeeze(0)
        else:
            # SAE operates on the residual contribution of the block.
            diff = cache["output"][block] - cache["input"][block]
            # Batch of 2 presumably means CFG (uncond, cond) — keep cond only.
            # TODO confirm against run_with_cache's output layout.
            if diff.shape[0] == 2:
                diff = diff[1].unsqueeze(0)

            if timestep is not None and timestep < diff.shape[1]:
                diff = diff[:, timestep:timestep+1]

            # Move channels last, then drop the leading singleton dims.
            # NOTE(review): when timestep is None the step axis survives the
            # squeezes, so the mean over dims (0, 1) may average over
            # (step, height) rather than (height, width) — confirm intended.
            diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
            with torch.no_grad():
                sparse_maps = sae.encode(diff)

        # Mean activation per feature over the first two (spatial) dims.
        averages = torch.mean(sparse_maps, dim=(0, 1))

        top_features = torch.topk(averages, 40).indices

        top_features_dict[code] = top_features.cpu().tolist()
        sparse_maps_dict[code] = sparse_maps.cpu().numpy()

    return top_features_dict, sparse_maps_dict
|
|
|
|
def plot_image_heatmap(cache, block_select, radio, model_config):
    """Overlay the selected feature's activation heatmap on the cached image.

    Args:
        cache: state dict with "heatmaps" (per-block sparse maps, H x W x N)
            and "image" (the generated PIL image).
        block_select: dropdown string, e.g. "down.2.1 (composition)".
        radio: selected feature index (string or int).
        model_config: one entry of MODELS_CONFIG.

    Returns:
        RGBA PIL image: generated image with a semi-transparent jet overlay.
    """
    code = block_select.split()[0]
    feature = int(radio)

    heatmap = cache["heatmaps"][code][:, :, feature]
    # Nearest-neighbour upsample of the latent-resolution map to image size.
    # NOTE(review): 16/32 here are display upsampling factors, distinct from
    # model_config["downsample_factor"] used for paint masks — confirm.
    scaling_factor = 16 if model_config["is_flux"] else 32
    heatmap = np.kron(heatmap, np.ones((scaling_factor, scaling_factor)))
    image = cache["image"].convert("RGBA")

    # Jet colormap with the lowest bin fully transparent and the rest at 60%.
    jet = plt.cm.jet
    cmap = jet(np.arange(jet.N))
    cmap[:1, -1] = 0
    cmap[1:, -1] = 0.6
    cmap = ListedColormap(cmap)

    # Normalize to [0, 1]; guard against a constant map, which would
    # otherwise divide by zero and produce an all-NaN (garbage) overlay.
    hmin = np.min(heatmap)
    hmax = np.max(heatmap)
    if hmax > hmin:
        heatmap = (heatmap - hmin) / (hmax - hmin)
    else:
        heatmap = np.zeros_like(heatmap)

    heatmap_rgba = cmap(heatmap)
    heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
    heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)

    return heatmap_with_transparency
|
|
|
|
def create_prompt_part(pipe, saes_dict, demo):
    """Build the "Explore" tab: generate an image and inspect which SAE
    features fire, shown as heatmaps overlaid on the generated image.

    Args:
        pipe: hooked (SDXL) or cached (FLUX) pipeline wrapper.
        saes_dict: mapping of block code -> SparseAutoencoder.
        demo: enclosing gr.Blocks, used to register the page-load event.

    Returns:
        The constructed gr.Tab.
    """
    model_config = MODELS_CONFIG[pipe.pipe.name_or_path]

    @spaces.GPU
    def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
        """Run the pipeline with activation caching; return the image and the
        per-session state dict consumed by the heatmap widgets."""
        # One generation at a time: a single shared model serves all sessions.
        lock.acquire()
        try:
            default_n_steps = model_config["steps"]
            default_guidance = model_config["guidance_scale"]

            # Fall back to per-model defaults when the UI passes None.
            n_steps = default_n_steps if num_steps is None else int(num_steps)
            guidance = default_guidance if guidance_scale is None else float(guidance_scale)

            timestep_int = None if timestep is None else int(timestep)

            if "FLUX" in pipe.pipe.name_or_path:
                # FLUX path: CachedPipeline.run stores activations on the pipe.
                images = pipe.run(
                    prompt,
                    num_inference_steps=n_steps,
                    width=1024,
                    height=1024,
                    cache_activations=True,
                    guidance_scale=guidance,
                    positions_to_cache=list(model_config["code_to_block"].values()),
                    inverse=False,
                )
                cache = pipe.activation_cache
            else:
                # SDXL path: run_with_cache returns the cache directly.
                # Fixed seed keeps generations reproducible across calls.
                images, cache = pipe.run_with_cache(
                    prompt,
                    positions_to_cache=list(model_config["code_to_block"].values()),
                    num_inference_steps=n_steps,
                    generator=torch.Generator(device="cpu").manual_seed(42),
                    guidance_scale=guidance,
                    save_input=True,
                    save_output=True
                )
        finally:
            lock.release()

        # Feature ranking happens outside the lock (read-only on the cache).
        top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, model_config, timestep_int)
        return images.images[0], {
            "image": images.images[0],
            "heatmaps": top_sparse_maps_dict,
            "features": top_features_dict
        }

    def update_radio(cache, block_select):
        """Repopulate the feature list for the currently selected block."""
        code = block_select.split()[0]
        return gr.update(choices=cache["features"][code])

    def update_img(cache, block_select, radio):
        """Overlay the selected feature's heatmap on the cached image."""
        new_img = plot_image_heatmap(cache, block_select, radio, model_config)
        return new_img

    with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
        # Per-session state: last generated image plus its feature maps.
        cache = gr.State(value={
            "image": None,
            "heatmaps": None,
            "features": []
        })
        with gr.Row():
            with gr.Column(scale=7):
                with gr.Row(equal_height=True):
                    prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
                    button = gr.Button("Generate", elem_classes="generate_button1")

                with gr.Row():
                    image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")

            with gr.Column(scale=4):
                block_select = gr.Dropdown(
                    choices=model_config["choices"],
                    value=model_config["value"],
                    label="Select block",
                    elem_id="block_select",
                    interactive=True
                )

                with gr.Group() as sdxl_base_controls:
                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=model_config["max_steps"],
                        value=model_config["steps"],
                        step=1,
                        label="Number of steps",
                        elem_id="steps_slider",
                        interactive=True,
                        visible=True
                    )
                    timestep_selector = gr.Slider(
                        minimum=0,
                        maximum=model_config["max_steps"]-1,
                        value=None,
                        step=1,
                        label="Timestep (leave empty for average across all steps)",
                        elem_id="timestep_selector",
                        interactive=True,
                        visible=True,
                    )
                    recompute_button = gr.Button("Recompute", elem_id="recompute_button")

                # Keep the timestep range consistent with the chosen step count.
                steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])

                radio = gr.Radio(choices=[], label="Select a feature", interactive=True)

        # Event wiring.
        button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
        cache.change(update_radio, [cache, block_select], outputs=[radio])
        block_select.select(update_radio, [cache, block_select], outputs=[radio])
        radio.select(update_img, [cache, block_select, radio], outputs=[image])
        recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
        # NOTE(review): this triggers a full GPU generation on every page load.
        demo.load(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])

    return explore_tab
|
|
def downsample_mask(image, factor):
    """Shrink a 2D mask by ``factor`` along each axis via block averaging.

    Assumes both dimensions of ``image`` are divisible by ``factor``.
    Returns an array of shape (H // factor, W // factor) whose cells hold
    the mean of the corresponding factor x factor block.
    """
    rows = image.shape[0] // factor
    cols = image.shape[1] // factor
    blocks = image.reshape((rows, factor, cols, factor))
    return blocks.mean(axis=(1, 3))
|
|
def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
    """Build the "Paint!" tab: generate an image, paint a region, and inject
    an SAE feature into the painted area during re-generation.

    Args:
        pipe: hooked (SDXL) or cached/edited (FLUX) pipeline wrapper.
        saes_dict: mapping of block code -> SparseAutoencoder.
        means_dict: mapping of block code -> per-feature mean activations,
            used to scale the injected feature.
        demo: enclosing gr.Blocks, used to register the page-load event.

    Returns:
        The constructed gr.Tab.
    """
    model_config = MODELS_CONFIG[pipe.pipe.name_or_path]

    @spaces.GPU
    def image_gen(prompt, num_steps, guidance_scale=None):
        """Plain (hook-free) generation used to seed the sketchpad."""
        lock.acquire()
        guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
        try:
            if "FLUX" in pipe.pipe.name_or_path:
                images = pipe.run(
                    prompt,
                    num_inference_steps=int(num_steps),
                    width=1024,
                    height=1024,
                    cache_activations=False,
                    guidance_scale=guidance,
                    inverse=False,
                )
            else:
                # Empty hook dict == vanilla generation; fixed seed for
                # reproducibility.
                images = pipe.run_with_hooks(
                    prompt,
                    position_hook_dict={},
                    num_inference_steps=int(num_steps),
                    generator=torch.Generator(device="cpu").manual_seed(42),
                    guidance_scale=guidance,
                )
        finally:
            lock.release()
        # Downscale FLUX's 1024px output so the sketchpad canvas matches.
        if images.images[0].size == (1024, 1024):
            return images.images[0].resize((512, 512))
        else:
            return images.images[0]

    @spaces.GPU
    def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, block_start, block_end, start_index=None, end_index=None, guidance_scale=None):
        """Re-generate with the selected feature added inside the painted mask.

        ``start_index``/``end_index`` bound the diffusion steps on which the
        TimedHook fires; ``block_start``/``block_end`` select the transformer
        layers edited (FLUX only).
        """
        block = block_str.split(" ")[0]

        # Binarize the painted layer's alpha channel, then pool it down to the
        # latent grid resolution.
        mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
        mask = downsample_mask(mask, model_config["downsample_factor"])
        mask = torch.tensor(mask, dtype=torch.float32, device="cuda")

        if mask.sum() == 0:
            gr.Info("No mask selected, please draw on the input image")

        # Normalize the hook's active-step window to [0, num_steps], keeping
        # it non-empty.
        if start_index is None:
            start_index = 0
        if end_index is None:
            end_index = int(num_steps)
        start_index = max(0, min(int(start_index), int(num_steps)))
        end_index = max(0, min(int(end_index), int(num_steps)))
        if start_index >= end_index:
            start_index = max(0, end_index - 1)

        def myhook(module, input, output):
            # Inject feature `brush_index` scaled by its mean activation and
            # the user-chosen strength, masked to the painted area.
            return model_config["add_feature_on_area"](
                saes_dict[block],
                brush_index,
                mask * means_dict[block][brush_index] * strength,
                module,
                input,
                output)
        hook = TimedHook(myhook, int(num_steps), np.arange(start_index, end_index))

        lock.acquire()
        guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
        try:
            if model_config["is_flux"]:
                image = pipe.run_with_edit(
                    prompt,
                    seed=42,
                    num_inference_steps=int(num_steps),
                    edit_fn=lambda input, output: hook(None, input, output),
                    layers_for_edit_fn=[i for i in range(block_start, block_end + 1)],
                    stream="image").images[0]
            else:
                image = pipe.run_with_hooks(
                    prompt,
                    position_hook_dict={model_config["code_to_block"][block]: hook},
                    num_inference_steps=int(num_steps),
                    generator=torch.Generator(device="cpu").manual_seed(42),
                    guidance_scale=guidance
                ).images[0]
        finally:
            lock.release()
        return image

    def feature_icon(block_str, brush_index, guidance_scale=None):
        """Generate from an empty prompt with the block output replaced by a
        single feature, to visualize what that feature encodes.

        NOTE(review): not wired to any UI event in this file — reachable only
        if a caller hooks it up.
        """
        block = block_str.split(" ")[0]
        if block in ["mid.0", "up.0.0"]:
            gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)

        # FIX: `is_base_model` was referenced below but never defined anywhere
        # in this module, so invoking feature_icon raised NameError. Derive it
        # from the loaded checkpoint name (base SDXL vs. turbo).
        is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"

        def hook(module, input, output):
            if is_base_model:
                return replace_with_feature_base(
                    saes_dict[block],
                    brush_index,
                    means_dict[block][brush_index] * saes_dict[block].k,
                    module,
                    input,
                    output
                )
            else:
                return replace_with_feature_turbo(
                    saes_dict[block],
                    brush_index,
                    means_dict[block][brush_index] * saes_dict[block].k,
                    module,
                    input,
                    output)
        lock.acquire()
        guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
        try:
            image = pipe.run_with_hooks(
                "",
                position_hook_dict={model_config["code_to_block"][block]: hook},
                num_inference_steps=model_config["steps"],
                generator=torch.Generator(device="cpu").manual_seed(42),
                guidance_scale=guidance,
            ).images[0]
        finally:
            lock.release()
        return image

    with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
        # Holds the clean (un-painted) generation so "Clear" can restore it.
        image_state = gr.State(value=None)
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
                with gr.Row():
                    num_steps = gr.Number(value=model_config["steps"], label="Number of steps", minimum=1, maximum=model_config["max_steps"], elem_id="num_steps", precision=0)
                with gr.Row():
                    button_generate = gr.Button("Generate", elem_id="generate_button")
            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column(scale=7):
                        with gr.Row():
                            block_select = gr.Dropdown(
                                choices=model_config["choices"],
                                value=model_config["value"],
                                label="Select block",
                                elem_id="block_select"
                            )
                            brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, elem_id="brush_index", precision=0)

                        with gr.Row():
                            gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=True)
                        with gr.Row():
                            start_index = gr.Number(value=0, label="Start index", minimum=0, maximum=model_config["max_steps"], elem_id="start_index", precision=0, visible=True)
                            end_index = gr.Number(value=model_config["steps"], label="End index", minimum=0, maximum=model_config["max_steps"], elem_id="end_index", precision=0, visible=True)
                    with gr.Column(scale=3):
                        with gr.Row():
                            strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
                        with gr.Row():
                            block_start = gr.Number(value=18, label="Block start", minimum=0, maximum=56, elem_id="block_start")
                            block_finish = gr.Number(value=56, label="Block finish", minimum=0, maximum=56, elem_id="block_finish")
                        with gr.Row():
                            button = gr.Button('Apply', elem_id="apply_button")

        with gr.Row():
            with gr.Column():
                i_image = gr.Sketchpad(
                    height=610,
                    layers=False, transforms=[], placeholder="Generate and paint!",
                    brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
                    container=False,
                    canvas_size=(512, 512),
                    label="Input Image")
                clear_button = gr.Button("Clear")
                # "Clear" restores the last clean generation onto the canvas.
                clear_button.click(lambda x: x, [image_state], [i_image])

            o_image = gr.Image(width=512, height=512, label="Output Image")

        # Event wiring.
        button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
        image_state.change(lambda x: x, [image_state], [i_image])

        def update_index_maxes(steps):
            # Keep the TimedHook bounds within the chosen step count.
            return gr.update(maximum=steps), gr.update(maximum=steps)

        num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])

        button.click(image_mod,
                     inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, block_start, block_finish, start_index, end_index],
                     outputs=o_image)

        # NOTE(review): this triggers a full GPU generation on every page load.
        demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])

    return intervene_tab
|
|
|
|
|
|
def create_top_images_part(demo, pipe):
    """Build the "Top Images" tab: browse precomputed top-activating image
    grids for any SAE feature, served from static Hugging Face URLs.
    """
    model_config = MODELS_CONFIG[pipe.pipe.name_or_path]

    # Derive the model family from the pipeline wrapper type.
    if isinstance(pipe, HookedStableDiffusionXLPipeline):
        is_flux = False
    elif isinstance(pipe, CachedFLuxPipeline):
        is_flux = True
    else:
        raise AssertionError(f"Unknown pipe class: {type(pipe)}")

    def update_top_images(block_select, brush_index):
        """Resolve the static preview URL for the chosen block/feature."""
        block = block_select.split(" ")[0]
        if not is_flux:
            return f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
        # FLUX previews are sharded into two folders by feature index.
        part = 1 if brush_index <= 7000 else 2
        return f"https://huggingface.co/datasets/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{brush_index}.jpg"

    flux_choices = ["flux_18"]
    sd_choices = ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"]

    with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
        with gr.Row():
            block_select = gr.Dropdown(
                choices=flux_choices if is_flux else sd_choices,
                value="flux_18" if is_flux else "down.2.1 (composition)",
                label="Select block"
            )
            brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
        with gr.Row():
            image = gr.Image(width=600, height=600, label="Top Images")

        # Refresh the preview on block change, index change, and page load.
        for trigger in (block_select.select, brush_index.change, demo.load):
            trigger(update_top_images, [block_select, brush_index], outputs=[image])

    return top_images_tab
|
|
|
|
def create_top_images_plus_search_part(retriever, demo, pipe):
    """Build a "Top Images" tab with optional text search over SAE features.

    When ``retriever`` is None (text search disabled), only lookup by feature
    index is active; the search textbox and candidate radio are created hidden.

    Args:
        retriever: optional text-to-feature retrieval backend exposing
            ``query_text(text, block) -> mapping of feature index -> score``.
        demo: enclosing gr.Blocks, used to register the page-load event.
        pipe: hooked (SDXL) or cached (FLUX) pipeline wrapper.

    Returns:
        The constructed gr.Tab.
    """
    model_config = MODELS_CONFIG[pipe.pipe.name_or_path]

    # Derive the model family from the pipeline wrapper type.
    if isinstance(pipe, HookedStableDiffusionXLPipeline):
        is_flux = False
    elif isinstance(pipe, CachedFLuxPipeline):
        is_flux = True
    else:
        raise AssertionError(f"Unknown pipe class: {type(pipe)}")

    def update_cache(block_select, search_by_text, search_by_index):
        """Resolve a preview URL from either a feature index or a text query.

        Returns (url, state dict for the candidate-features radio)."""
        if search_by_text == "":
            # Direct index lookup — no retrieval backend needed.
            top_indices = []
            index = search_by_index
            block = block_select.split(" ")[0]

            if is_flux:
                # FLUX previews are sharded into two folders by feature index.
                part = 1 if index <= 7000 else 2
                # NOTE(review): URL lacks the "/datasets/" segment used for the
                # same repo in create_top_images_part — confirm which is valid.
                url = f"https://huggingface.co/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{index}.jpg"
            else:
                url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
            return url, {"image": url, "feature_idx": index, "features": top_indices}
        else:
            # Text-search path: requires the retrieval backend.
            if retriever is None:
                raise ValueError("Feature retrieval is not enabled")
            lock.acquire()
            try:
                top_indices = list(retriever.query_text(search_by_text, block_select.split(" ")[0]).keys())
            finally:
                lock.release()
            block = block_select.split(" ")[0]
            top_indices = list(map(int, top_indices))
            index = top_indices[0]  # best-scoring match first
            url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
            return url, {"image": url, "feature_idx": index, "features": top_indices[:20]}

    def update_radio(cache):
        """Show the retrieved candidate features as radio choices."""
        return gr.update(choices=cache["features"], value=cache["feature_idx"])

    def update_img(cache, block_select, index):
        """Resolve the preview URL for an explicitly chosen feature index."""
        # NOTE(review): always builds the SDXL URL, even when is_flux is True —
        # confirm whether this tab is ever used with a FLUX pipeline.
        block = block_select.split(" ")[0]
        url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
        return url

    with gr.Tab("Top Images", elem_classes="tabs") as explore_tab:
        # Per-session state: current preview URL plus retrieval candidates.
        cache = gr.State(value={
            "image": None,
            "feature_idx": None,
            "features": []
        })
        with gr.Row():
            with gr.Column(scale=7):
                with gr.Row():
                    image = gr.Image(width=600, height=600, image_mode="RGB", label="Top images")

            with gr.Column(scale=4):
                block_select = gr.Dropdown(
                    choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
                    value="flux_18" if is_flux else "down.2.1 (composition)",
                    label="Select block",
                    elem_id="block_select",
                    interactive=True
                )
                search_by_index = gr.Number(value=0, label="Search by index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
                # Hidden until a retrieval backend is supplied.
                search_by_text = gr.Textbox(lines=1, label="Search by text", value="", visible=False)
                radio = gr.Radio(choices=[], label="Select a feature", interactive=True, visible=False)

        # Event wiring.
        search_by_text.change(update_cache,
                              [block_select, search_by_text, search_by_index],
                              outputs=[image, cache])
        block_select.select(update_cache,
                            [block_select, search_by_text, search_by_index],
                            outputs=[image, cache])
        cache.change(update_radio, [cache], outputs=[radio])
        radio.select(update_img, [cache, block_select, radio], outputs=[image])
        search_by_index.change(update_img, [cache, block_select, search_by_index], outputs=[image])
        demo.load(update_img,
                  [cache, block_select, search_by_index],
                  outputs=[image])

    return explore_tab
|
|
|
|
| def create_intro_part(): |
| with gr.Tab("Instructions", elem_classes="tabs") as intro_tab: |
| gr.Markdown( |
| '''# Unpacking SDXL Turbo with Sparse Autoencoders |
| ## Demo Overview |
| This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model. |
| |
| ## How to Use |
| ### Explore |
| * Enter a prompt in the text box and click on the "Generate" button to generate an image. |
| * You can observe the active features in different blocks plot on top of the generated image. |
| ### Top Images |
| * For each feature, you can view the top images that activate the feature the most. |
| ### Paint! |
| * Generate an image using the prompt. |
| * Paint on the generated image to apply interventions. |
| * Use the "Feature Icon" button to understand how the selected brush functions. |
| |
| ### Remarks |
| * Not all brushes mix well with all images. Experiment with different brushes and strengths. |
| * Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks. |
| * This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo. |
| |
| ### Interesting features to try |
| To get started, try the following features: |
| - down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon) |
| - up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur) |
| ''' |
| ) |
| |
| return intro_tab |
|
|
|
|
def create_demo(pipe, saes_dict, means_dict, use_retrieval=True):
    """Assemble the full Gradio app: Explore, Top Images, and Paint! tabs.

    Args:
        pipe: hooked (SDXL) or cached (FLUX) pipeline wrapper.
        saes_dict: mapping of block code -> SparseAutoencoder.
        means_dict: mapping of block code -> per-feature mean activations.
        use_retrieval: kept for backward compatibility; currently a no-op —
            text-based feature retrieval is disabled in this build.

    Returns:
        The constructed gr.Blocks demo (call .launch() on it).
    """
    custom_css = """
    .tabs button {
        font-size: 20px !important;  /* Adjust font size for tab text */
        padding: 10px !important;  /* Adjust padding to make the tabs bigger */
        font-weight: bold !important;  /* Adjust font weight to make the text bold */
    }
    .generate_button1 {
        max-width: 160px !important;
        margin-top: 20px !important;
        margin-bottom: 20px !important;
    }
    """
    # FIX: the original `if use_retrieval: retriever = None else: retriever = None`
    # assigned None on both branches; the dead conditional is collapsed here.
    # `retriever` is a placeholder for wiring create_top_images_plus_search_part
    # back in once a retrieval backend exists.
    retriever = None

    with gr.Blocks(css=custom_css) as demo:
        with create_prompt_part(pipe, saes_dict, demo):
            pass
        with create_top_images_part(demo, pipe):
            pass
        with create_intervene_part(pipe, saes_dict, means_dict, demo):
            pass

    return demo
|
|
|
|
| if __name__ == "__main__": |
| import os |
| import gradio as gr |
| import torch |
| from SDLens import HookedStableDiffusionXLPipeline |
| from SAE import SparseAutoencoder |
| from huggingface_hub import hf_hub_download |
| from huggingface_hub import login |
| login(token=os.environ["HF_TOKEN"]) |
|
|
| dtype = torch.float16 |
| pipe = EditedFluxPipeline.from_pretrained( |
| "black-forest-labs/FLUX.1-schnell", |
| device_map="balanced", |
| torch_dtype=dtype |
| ) |
| pipe.set_progress_bar_config(disable=True) |
| pipe = CachedFLuxPipeline(pipe) |
|
|
| |
| DEVICE = "cuda" |
|
|
| |
| HF_REPO_ID = "antoniomari/SAE_flux_18" |
| HF_BRANCH = "main" |
|
|
| |
| block_code = "18" |
| block_name = code_to_block_flux[block_code] |
|
|
| saes_dict = {} |
| means_dict = {} |
|
|
| |
| state_dict_path = hf_hub_download( |
| repo_id=HF_REPO_ID, |
| filename="state_dict.pth", |
| revision=HF_BRANCH |
| ) |
|
|
| config_path = hf_hub_download( |
| repo_id=HF_REPO_ID, |
| filename="config.json", |
| revision=HF_BRANCH |
| ) |
|
|
| mean_path = hf_hub_download( |
| repo_id=HF_REPO_ID, |
| filename="mean.pt", |
| revision=HF_BRANCH |
| ) |
|
|
| |
| with open(config_path, "r") as f: |
| config = json.load(f) |
|
|
| sae = SparseAutoencoder(**config) |
| checkpoint = torch.load(state_dict_path, map_location=DEVICE) |
| state_dict = checkpoint["state_dict"] |
| sae.load_state_dict(state_dict) |
| sae = sae.to(DEVICE, dtype=torch.float16).eval() |
| means = torch.load(mean_path, map_location=DEVICE).to(dtype) |
|
|
| saes_dict[block_code] = sae |
| means_dict[block_code] = means |
|
|
| demo = create_demo(pipe, saes_dict, means_dict) |
| demo.launch() |
|
|