|
|
from __future__ import annotations |
|
|
|
|
|
import gc |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
import torch |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
from pytorch_lightning import seed_everything |
|
|
from torch import autocast |
|
|
|
|
|
# Directory holding this script, its parent, and the vendored Stable Diffusion
# checkout that sits next to the script.
ROOT = Path(__file__).resolve().parent
PARENT = ROOT.parent
STABLE_DIFFUSION_DIR = ROOT / "stable_diffusion"

# Put both directories at the front of the import path so the
# ``stable_diffusion`` imports further down resolve. Each directory is inserted
# at index 0 in turn, so STABLE_DIFFUSION_DIR ends up ahead of PARENT.
for _search_dir in (str(PARENT), str(STABLE_DIFFUSION_DIR)):
    if _search_dir not in sys.path:
        sys.path.insert(0, _search_dir)
del _search_dir
|
|
|
|
|
from stable_diffusion.ldm.models.diffusion.ddim import DDIMSampler |
|
|
from stable_diffusion.ldm.util import instantiate_from_config |
|
|
|
|
|
# Checkpoints and the sampling config are looked up next to this script.
WEIGHTS_DIR = ROOT.joinpath("weights")
CONFIG_PATH = ROOT.joinpath("generate_sd.yaml")

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
# Style names offered in the UI. A checkpoint whose file stem matches one of
# these is listed as a "Style Unlearned" model.
theme_available = [
    "Abstractionism",
    "Artist_Sketch",
    "Blossom_Season",
    "Bricks",
    "Byzantine",
    "Cartoon",
    "Cold_Warm",
    "Color_Fantasy",
    "Comic_Etch",
    "Crayon",
    "Cubism",
    "Dadaism",
    "Dapple",
    "Defoliation",
    "Early_Autumn",
    "Expressionism",
    "Fauvism",
    "French",
    "Glowing_Sunset",
    "Gorgeous_Love",
    "Greenfield",
    "Impressionism",
    "Ink_Art",
    "Joy",
    "Liquid_Dreams",
    "Magic_Cube",
    "Meta_Physics",
    "Meteor_Shower",
    "Monet",
    "Mosaic",
    "Neon_Lines",
    "On_Fire",
    "Pastel",
    "Pencil_Drawing",
    "Picasso",
    "Pop_Art",
    "Red_Blue_Ink",
    "Rust",
    "Seed_Images",
    "Sketch",
    "Sponge_Dabbed",
    "Structuralism",
    "Superstring",
    "Surrealism",
    "Ukiyoe",
    "Van_Gogh",
    "Vibrant_Flow",
    "Warm_Love",
    "Warm_Smear",
    "Watercolor",
    "Winter",
]
|
|
|
|
|
# Object classes offered in the UI. A checkpoint whose file stem matches one
# of these is listed as an "Object Unlearned" model.
class_available = [
    "Architectures",
    "Bears",
    "Birds",
    "Butterfly",
    "Cats",
    "Dogs",
    "Fishes",
    "Flame",
    "Flowers",
    "Frogs",
    "Horses",
    "Human",
    "Jellyfish",
    "Rabbits",
    "Sandwiches",
    "Sea",
    "Statues",
    "Towers",
    "Trees",
    "Waterfalls",
]
|
|
|
|
|
# Fail at import time when there is nowhere to look for checkpoints.
if not WEIGHTS_DIR.exists():
    raise FileNotFoundError(f"Weights directory not found: {WEIGHTS_DIR}")

# display name -> {"ckpt", "config", "category", "raw_name"} for every
# checkpoint found in WEIGHTS_DIR.
MODEL_CONFIGS = {}

# Per-family lookups consumed by the sidebar widgets.
original_display_name = None  # display name of the "original" checkpoint, if any
theme_model_for = {}          # style name  -> display name
class_model_for = {}          # class name  -> display name
other_models = set()          # display names of checkpoints in no known family


def _describe_checkpoint(stem):
    """Return ``(display_name, category)`` for a checkpoint file stem."""
    if stem.lower() == "original":
        return "Original (no unlearning)", "original"
    if stem in theme_available:
        return f"Style Unlearned: {stem}", "theme"
    if stem in class_available:
        return f"Object Unlearned: {stem}", "class"
    return stem, "other"


for _pattern in ("*.pth", "*.ckpt"):
    for _ckpt_file in WEIGHTS_DIR.glob(_pattern):
        _raw_name = _ckpt_file.stem
        _label, _family = _describe_checkpoint(_raw_name)

        # Record the checkpoint in the lookup that matches its family.
        if _family == "original":
            original_display_name = _label
        elif _family == "theme":
            theme_model_for[_raw_name] = _label
        elif _family == "class":
            class_model_for[_raw_name] = _label
        else:
            other_models.add(_label)

        # Every checkpoint shares the same sampling config file.
        MODEL_CONFIGS[_label] = {
            "ckpt": str(_ckpt_file),
            "config": str(CONFIG_PATH),
            "category": _family,
            "raw_name": _raw_name,
        }

# Nothing to serve if the directory held no checkpoint files at all.
if not MODEL_CONFIGS:
    raise RuntimeError(f"No .pth or .ckpt files found in {WEIGHTS_DIR}")
|
|
|
|
|
def load_model_from_config(config, ckpt_path: str, verbose: bool = False):
    """Instantiate the model from ``config`` and load weights from ``ckpt_path``.

    The checkpoint is read onto the CPU first, the weights are loaded
    non-strictly, and the model is then moved to ``DEVICE`` and put in eval
    mode.

    Args:
        config: Loaded config object; ``config.model`` is handed to
            ``instantiate_from_config``.
        ckpt_path: Path of the checkpoint file to read with ``torch.load``.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model, on ``DEVICE`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry and is not
            itself a non-empty mapping of tensors.
    """
    print(f"Loading model from {ckpt_path}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from a trusted source.
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    if "state_dict" in pl_sd:
        # Checkpoint that wraps the weights under "state_dict".
        sd = pl_sd["state_dict"]
    elif pl_sd and all(isinstance(value, torch.Tensor) for value in pl_sd.values()):
        # Bare state dict (name -> tensor), e.g. a file written with
        # torch.save(model.state_dict(), path). Checkpoint discovery in this
        # file also picks up *.pth files, so accept this layout as well.
        sd = pl_sd
    else:
        raise KeyError(
            f"Checkpoint {ckpt_path} has no 'state_dict' entry and is not a "
            "plain mapping of parameter tensors"
        )

    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches; they are only reported when
    # verbose is set.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model.to(DEVICE)
    model.eval()
    return model
|
|
|
|
|
def _sample_image(model, sampler, prompt, steps, cfg_text, H, W, ddim_eta):
    """Run DDIM sampling for one prompt and return the decoded PIL image."""
    from contextlib import nullcontext

    # Mixed precision only applies on CUDA; on CPU run without autocast.
    autocast_ctx = autocast("cuda") if DEVICE == "cuda" else nullcontext()

    with torch.no_grad(), autocast_ctx:
        # Models without an ``ema_scope`` attribute are sampled as-is.
        try:
            ema_ctx = model.ema_scope()
        except AttributeError:
            ema_ctx = nullcontext()

        with ema_ctx:
            # Empty-prompt conditioning for classifier-free guidance.
            uc = model.get_learned_conditioning([""])
            c = model.get_learned_conditioning(prompt)
            # The sampler is given a 4-channel shape at 1/8 of H and W.
            shape = [4, H // 8, W // 8]

            samples_ddim, _ = sampler.sample(
                S=steps,
                conditioning=c,
                batch_size=1,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=cfg_text,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                x_T=None,
            )

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map the decoder output from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp(
                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0
            )
            # Move to the CPU and put channels last for PIL.
            x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1)
            assert len(x_samples_ddim) == 1
            x_sample = x_samples_ddim[0].numpy()

    x_sample = (255.0 * x_sample).round().astype(np.uint8)
    return Image.fromarray(x_sample)


def generate_image_single(
    model_name: str,
    prompt: str,
    steps: int,
    cfg_text: float,
    seed: int,
    H: int,
    W: int,
    ddim_eta: float,
):
    """Load the selected checkpoint, generate one image, then release the model.

    Args:
        model_name: Key into ``MODEL_CONFIGS`` selecting checkpoint and config.
        prompt: Text prompt used as the conditioning input.
        steps: Number of DDIM steps (passed to the sampler as ``S``).
        cfg_text: Guidance scale (``unconditional_guidance_scale``).
        seed: Seed passed to ``seed_everything`` after the model is loaded.
        H: Image height; the sampler is given ``H // 8``.
        W: Image width; the sampler is given ``W // 8``.
        ddim_eta: DDIM ``eta`` value passed to the sampler.

    Returns:
        Tuple ``(img, prompt)``: the generated ``PIL.Image`` and the prompt
        that was used.

    Raises:
        KeyError: If ``model_name`` is not in ``MODEL_CONFIGS``.
    """
    model_cfg = MODEL_CONFIGS[model_name]
    ckpt_path = model_cfg["ckpt"]
    config_path = model_cfg["config"]

    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    sampler = DDIMSampler(model)

    try:
        seed_everything(seed)
        print(f"Prompt: {prompt}")
        # Sampling runs in a helper so its GPU intermediates (conditioning,
        # latents) are released when it returns, before the cache is emptied.
        img = _sample_image(model, sampler, prompt, steps, cfg_text, H, W, ddim_eta)
    finally:
        # Drop this function's references even when sampling fails. This is
        # best effort on the error path: a live traceback can still keep the
        # model reachable until the exception is discarded.
        del sampler
        del model
        # Collect first so objects kept alive only by reference cycles are
        # freed, then hand the now-unused cached CUDA blocks back.
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return img, prompt
|
|
|
|
|
|
|
|
|
|
|
# ----- Page header -----
st.set_page_config(page_title="Unlearning Styles Demo", layout="wide")
st.title("Machine Unlearning Demo - Styles and Objects")

# ----- Sidebar: model family -----
st.sidebar.header("Model selection")

# Offer a family only when at least one checkpoint of that family was found.
_family_candidates = (
    ("Original", original_display_name is not None),
    ("Style Unlearned", bool(theme_model_for)),
    ("Object Unlearned", bool(class_model_for)),
    ("Other", bool(other_models)),
)
model_family_options = [
    family for family, has_models in _family_candidates if has_models
]

model_family = st.sidebar.radio(
    "Which model family?",
    model_family_options,
    label_visibility="hidden",
)
|
|
|
|
|
# Resolve the sidebar choice to a key of MODEL_CONFIGS (None if no family
# matched).
selected_model_display_name = None

# For the two "unlearned" families: selectbox label and the mapping from the
# short name shown in the selectbox to the model's display name.
_unlearned_families = {
    "Style Unlearned": ("Unlearned style model", theme_model_for),
    "Object Unlearned": ("Unlearned object model", class_model_for),
}

if model_family == "Original":
    selected_model_display_name = original_display_name
    st.sidebar.markdown(f"**Using Model:** \n {original_display_name}")

elif model_family in _unlearned_families:
    _box_label, _display_name_for = _unlearned_families[model_family]
    _picked = st.sidebar.selectbox(_box_label, sorted(_display_name_for))
    selected_model_display_name = _display_name_for[_picked]
    st.sidebar.markdown(f"**Using Model:** \n {selected_model_display_name}")

elif model_family == "Other":
    # For these checkpoints the display name is the file stem itself, so the
    # selectbox value is used directly.
    selected_model_display_name = st.sidebar.selectbox(
        "Other models",
        sorted(other_models),
    )
|
|
|
|
|
# ----- Sidebar: generation settings -----
st.sidebar.header("Generation settings")

# The seed is later passed to seed_everything, which only accepts values in
# the unsigned 32-bit range. Bounding the widget keeps negative or oversized
# values from ever reaching it.
seed = st.sidebar.number_input(
    "Random seed",
    min_value=0,
    max_value=2**32 - 1,
    value=256,
    step=1,
)

# Fixed sampling parameters (not exposed in the UI); they are forwarded to
# generate_image_single when the Generate button is pressed.
steps = 100      # number of DDIM steps
cfg_text = 9.0   # guidance scale
H = 512          # image height
W = 512          # image width
ddim_eta = 0.0   # DDIM eta
|
|
|
|
|
|
|
|
# ----- Main area: prompt input -----
prompt_mode = st.radio(
    "Prompt mode",
    ["Preset Style/Object", "Free Text Prompt"],
    horizontal=True,
)

if prompt_mode == "Preset Style/Object":
    st.subheader("Style")
    theme = st.pills("Choose style", theme_available)

    st.subheader("Object")
    object_class = st.pills("Choose object", class_available)

    # The prompt exists only once both a style and an object are picked;
    # underscores in the style name are shown as spaces.
    prompt = (
        f"A {object_class} image in {theme.replace('_', ' ')} style."
        if theme and object_class
        else None
    )
else:
    # Free-text mode has no preset selection.
    theme = object_class = None

    st.subheader("Free Text Prompt")
    prompt = st.text_area(
        "Enter your prompt",
        placeholder="e.g., A beautiful sunset over mountains, digital art",
        height=100,
    )
|
|
|
|
|
st.markdown("---")

# ----- Generate -----
if st.button("Generate"):
    _is_preset = prompt_mode == "Preset Style/Object"

    # Pick the first validation problem, checked in this order: model, then
    # style/object (preset mode) or prompt text (free-text mode).
    if selected_model_display_name is None:
        _problem = "Please select a model in the sidebar."
    elif _is_preset and theme is None:
        _problem = "Please select a style."
    elif _is_preset and object_class is None:
        _problem = "Please select an object."
    elif not _is_preset and (not prompt or not prompt.strip()):
        _problem = "Please enter a prompt."
    else:
        _problem = None

    if _problem is not None:
        st.error(_problem)
    else:
        # Preset prompts are used verbatim; free text is stripped of
        # surrounding whitespace.
        _final_prompt = prompt if _is_preset else prompt.strip()

        with st.spinner("Generating image..."):
            img, used_prompt = generate_image_single(
                model_name=selected_model_display_name,
                prompt=_final_prompt,
                steps=int(steps),
                cfg_text=float(cfg_text),
                seed=int(seed),
                H=int(H),
                W=int(W),
                ddim_eta=float(ddim_eta),
            )

        st.image(
            img,
            caption=f"Model: {selected_model_display_name} | Prompt: {used_prompt}",
        )
|
|
|