# MetricMogul's picture
# Update app.py
# 74c671b verified
"""
SD 1.5 Lab — pick two conditioning slots (ControlNet or one of several
IP-Adapter variants), supply images, and see how different conditioners
interact during generation.

File layout:
    1. CONFIG — registry of available conditioners (ControlNets and IP-Adapter variants)
    2. STATE — global cache of loaded models (lazy loading)
    3. PREPROCESS — detectors that turn an input image into a condition map
    4. PIPELINE — diffusers pipeline assembly based on selected slots
    5. GENERATE — main function called by the UI
    6. UI — Gradio interface
"""
import gc
import torch
import numpy as np
from PIL import Image
import gradio as gr
from diffusers import (
StableDiffusionControlNetPipeline,
StableDiffusionPipeline,
ControlNetModel,
UniPCMultistepScheduler,
)
# Startup diagnostic: log the versions of the key packages so that
# compatibility problems can be traced from the logs alone.
import diffusers
import transformers
import huggingface_hub

print(
    "[versions] "
    + " ".join(
        f"{label}={module.__version__}"
        for label, module in (
            ("torch", torch),
            ("diffusers", diffusers),
            ("transformers", transformers),
            ("hf_hub", huggingface_hub),
        )
    )
)

# peft is optional here: report its version when importable, otherwise say so.
try:
    import peft
except ImportError:
    print("[versions] peft=not installed")
else:
    print(f"[versions] peft={peft.__version__}")
# Detectors are imported lazily inside functions to avoid loading
# everything into memory at startup.
# ============================================================
# 1. CONFIG
# ============================================================
# Use the GPU with half precision when CUDA is available; otherwise run on
# the CPU in full precision.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Base Stable Diffusion 1.5 checkpoint that every pipeline is built from.
BASE_MODEL = "runwayml/stable-diffusion-v1-5"
# ControlNet registry: slot name -> checkpoint repo plus the detector kind
# used to turn an input image into that ControlNet's condition map.
CONTROLNETS = {
    slot_name: {"repo": repo_id, "detector_kind": detector_kind}
    for slot_name, repo_id, detector_kind in (
        ("depth", "lllyasviel/sd-controlnet-depth", "midas"),
        ("normals", "lllyasviel/sd-controlnet-normal", "midas_normal"),
        ("pose", "lllyasviel/sd-controlnet-openpose", "openpose"),
        ("lineart", "lllyasviel/control_v11p_sd15_lineart", "lineart"),
        ("canny", "lllyasviel/sd-controlnet-canny", "canny"),
        ("scribble", "lllyasviel/sd-controlnet-scribble", "scribble"),
    )
}
# IP-Adapter registry. Every variant lives in the same repo and subfolder;
# only the weight file differs:
# - ip_adapter:           general composition / style transfer
# - ip_adapter_plus:      sharper detail preservation (uses patch tokens)
# - ip_adapter_plus_face: face-focused, better identity from face crops
# - ip_adapter_full_face: strongest identity, full-face variant
IP_ADAPTERS = {
    variant: {
        "repo": "h94/IP-Adapter",
        "subfolder": "models",
        "weight_name": weight_file,
    }
    for variant, weight_file in (
        ("ip_adapter", "ip-adapter_sd15.bin"),
        ("ip_adapter_plus", "ip-adapter-plus_sd15.bin"),
        ("ip_adapter_plus_face", "ip-adapter-plus-face_sd15.bin"),
        ("ip_adapter_full_face", "ip-adapter-full-face_sd15.bin"),
    )
}
# Dropdown options for a slot: every ControlNet, every IP-Adapter variant,
# then "none" for an unused slot.
SLOT_CHOICES = [*CONTROLNETS, *IP_ADAPTERS, "none"]
# ============================================================
# 2. STATE — lazy model cache
# ============================================================
# ControlNet name -> loaded ControlNetModel (filled by get_controlnet).
_controlnet_cache = dict()
# Detector kind -> loaded preprocessor (filled by get_detector).
_detector_cache = dict()
def get_controlnet(name):
    """Fetch the ControlNetModel registered under `name`.

    The model is loaded with `from_pretrained` the first time it is
    requested and served from `_controlnet_cache` afterwards.
    """
    if name in _controlnet_cache:
        return _controlnet_cache[name]

    repo = CONTROLNETS[name]["repo"]
    print(f"[load] ControlNet: {name} ({repo})")
    model = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE)
    _controlnet_cache[name] = model
    return _controlnet_cache[name]
def get_detector(kind):
    """Return the preprocessor for a detector kind, loading it on first use.

    Args:
        kind: One of "midas", "midas_normal", "openpose", "lineart",
            "scribble" or "canny".

    Returns:
        A controlnet_aux detector instance, or the placeholder string
        "canny" (no model is loaded for Canny; in this file
        preprocess_for_controlnet calls cv2.Canny directly).

    Raises:
        ValueError: If `kind` is not one of the known detector kinds.
    """
    if kind in _detector_cache:
        return _detector_cache[kind]
    print(f"[load] detector: {kind}")

    # Detector kind -> name of the class exported by controlnet_aux.
    annotator_classes = {
        "midas": "MidasDetector",
        "openpose": "OpenposeDetector",
        "lineart": "LineartDetector",
        "scribble": "HEDdetector",
    }

    if kind == "canny":
        det = "canny"
    elif kind == "midas_normal":
        # Depth and normals come from the same MidasDetector (the normal map
        # is selected at call time), so share one instance instead of
        # loading the same weights a second time.
        det = get_detector("midas")
    elif kind in annotator_classes:
        # Imported lazily so nothing is loaded into memory at startup.
        import controlnet_aux

        detector_cls = getattr(controlnet_aux, annotator_classes[kind])
        det = detector_cls.from_pretrained("lllyasviel/Annotators")
    else:
        raise ValueError(f"Unknown detector: {kind}")

    _detector_cache[kind] = det
    return det
# ============================================================
# 3. PREPROCESS
# ============================================================
def preprocess_for_controlnet(image, cn_name):
    """Run the detector that matches a ControlNet on the input image.

    Args:
        image: Input PIL image.
        cn_name: Key into CONTROLNETS; its "detector_kind" selects the
            preprocessing path.

    Returns:
        The condition map as a PIL.Image.

    Raises:
        KeyError: If `cn_name` is not a key of CONTROLNETS.
    """
    kind = CONTROLNETS[cn_name]["detector_kind"]
    detector = get_detector(kind)

    if kind == "canny":
        # Canny needs no model: run OpenCV and replicate the single-channel
        # edge map into three channels.
        import cv2

        edges = cv2.Canny(np.array(image), 100, 200)
        return Image.fromarray(np.stack([edges] * 3, axis=-1))

    if kind == "midas_normal":
        # With depth_and_normal=True the detector returns (depth, normal);
        # keep only the normal map.
        result = detector(image, depth_and_normal=True)
        return result[1] if isinstance(result, tuple) else result

    if kind == "scribble":
        # The scribble ControlNet expects thinned, binarised strokes rather
        # than the raw HED soft-edge map, so ask the detector for its
        # scribble post-processing.
        return detector(image, scribble=True)

    return detector(image)
# ============================================================
# 4. PIPELINE — assembled per slot configuration
# ============================================================
def build_pipeline(slot1, slot2):
    """
    Build a fresh diffusers pipeline for a pair of slot selections.

    Slots that name a ControlNet are collected in order. With at least one,
    a StableDiffusionControlNetPipeline is created (a single model for one
    ControlNet, a list of models for two); with none, a plain
    StableDiffusionPipeline is used. The first slot naming an IP-Adapter
    variant, if any, has its weights loaded BEFORE the pipeline is moved to
    DEVICE, so its image encoder ends up on the correct device.

    Returns (pipe, cn_slots, ip_slot): the pipeline, the list of selected
    ControlNet names, and the selected IP-Adapter name or None.
    """
    chosen = (slot1, slot2)
    cn_slots = [name for name in chosen if name in CONTROLNETS]

    ip_slot = None
    for name in chosen:
        if name in IP_ADAPTERS:
            ip_slot = name
            break

    shared_kwargs = {"torch_dtype": DTYPE, "safety_checker": None}
    if not cn_slots:
        pipe = StableDiffusionPipeline.from_pretrained(BASE_MODEL, **shared_kwargs)
    else:
        nets = [get_controlnet(name) for name in cn_slots]
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            BASE_MODEL,
            controlnet=nets[0] if len(nets) == 1 else nets,
            **shared_kwargs,
        )

    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    if ip_slot is not None:
        adapter = IP_ADAPTERS[ip_slot]
        print(f"[load] IP-Adapter: {ip_slot} ({adapter['weight_name']})")
        pipe.load_ip_adapter(
            adapter["repo"],
            subfolder=adapter["subfolder"],
            weight_name=adapter["weight_name"],
        )

    pipe = pipe.to(DEVICE)

    # NOTE: attention slicing is intentionally left disabled. In some
    # diffusers versions it conflicts with IP-Adapter and produces the
    # "tuple object has no attribute shape" error. The cost is higher peak
    # RAM in exchange for correctness. Re-enable
    # pipe.enable_attention_slicing() on CPU only after verifying that
    # IP-Adapter still works in the installed diffusers version.
    pipe.enable_vae_slicing()
    pipe.enable_vae_tiling()
    return pipe, cn_slots, ip_slot
# ============================================================
# 5. GENERATE
# ============================================================
def generate(
    prompt,
    negative_prompt,
    slot1, slot2,
    image1, image2,
    is_preprocessed1, is_preprocessed2,
    weight1, weight2,
    steps, guidance, seed,
):
    """
    Generate one 512x512 image from the prompt and up to two conditioners.

    Args:
        prompt, negative_prompt: Text prompts; an empty negative prompt is
            passed to the pipeline as None.
        slot1, slot2: Conditioner name per slot (a CONTROLNETS key, an
            IP_ADAPTERS key, or "none").
        image1, image2: Input PIL image for each slot.
        is_preprocessed1, is_preprocessed2: When true, the image is already
            a condition map and the detector is skipped.
        weight1, weight2: controlnet_conditioning_scale or IP-Adapter scale.
        steps, guidance, seed: Inference step count, CFG scale, RNG seed.

    Returns:
        (image, info). On invalid input `image` is None and `info` carries
        a message explaining what to fix; otherwise `info` summarises the
        settings that were used.
    """
    if slot1 == "none" and slot2 == "none":
        return None, "Both slots are empty — pick at least one conditioner."
    if seed is None:
        # An emptied seed field arrives as None; int(None) would crash.
        return None, "Seed is empty — enter an integer seed."

    # 1. Prepare conditioning inputs
    cn_images = []
    cn_weights = []
    ip_image = None
    ip_weight = 1.0
    slots = (
        (slot1, image1, is_preprocessed1, weight1),
        (slot2, image2, is_preprocessed2, weight2),
    )
    for slot, img, is_pre, w in slots:
        if slot == "none":
            continue
        if img is None:
            return None, f"Slot '{slot}' is selected but no image was provided."
        img = img.convert("RGB").resize((512, 512))
        if slot in CONTROLNETS:
            cond = img if is_pre else preprocess_for_controlnet(img, slot)
            cn_images.append(cond.resize((512, 512)))
            cn_weights.append(float(w))
        elif slot in IP_ADAPTERS:
            if ip_image is not None:
                # build_pipeline loads only the first IP-Adapter variant, so
                # a second one would pair that variant with the wrong image
                # and weight. Refuse instead of mixing them silently.
                return None, (
                    "Only one IP-Adapter slot is supported — set the other "
                    "slot to a ControlNet or 'none'."
                )
            ip_image = img
            ip_weight = float(w)

    # 2. Build pipeline
    pipe, _, ip_slot = build_pipeline(slot1, slot2)
    try:
        if ip_slot is not None:
            pipe.set_ip_adapter_scale(ip_weight)

        # 3. Call arguments
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
        call_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            generator=generator,
            height=512,
            width=512,
        )
        if cn_images:
            # A single ControlNet takes scalars; several take parallel lists.
            single = len(cn_images) == 1
            call_kwargs["image"] = cn_images[0] if single else cn_images
            call_kwargs["controlnet_conditioning_scale"] = (
                cn_weights[0] if single else cn_weights
            )
        if ip_slot is not None:
            # Pass the raw PIL image; diffusers handles CLIP encoding internally.
            call_kwargs["ip_adapter_image"] = ip_image

        # 4. Run generation
        result = pipe(**call_kwargs).images[0]
    finally:
        # 5. Cleanup — also runs when generation fails, so a crashed run
        # does not skip releasing the pipeline's memory.
        del pipe
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    info = (
        f"Slots: {slot1} ({weight1}) + {slot2} ({weight2}) | "
        f"steps={steps}, cfg={guidance}, seed={seed}, device={DEVICE}"
    )
    return result, info
# ============================================================
# 6. UI
# ============================================================
def make_slot_ui(slot_idx):
    """Create the widgets for one conditioning slot inside a gr.Group.

    Returns the four components in order: conditioner-type dropdown, image
    input, "already preprocessed" checkbox, weight slider.
    """
    with gr.Group():
        gr.Markdown(f"### Slot {slot_idx}")
        kind_dropdown = gr.Dropdown(
            label="Conditioner type",
            choices=SLOT_CHOICES,
            value="none",
        )
        source_image = gr.Image(label="Image", type="pil")
        skip_detector = gr.Checkbox(
            value=False,
            label="Already a condition map (skip detector)",
        )
        scale_slider = gr.Slider(
            label="Weight (conditioning scale)",
            value=1.0,
            minimum=0.0,
            maximum=2.0,
            step=0.05,
        )
        return kind_dropdown, source_image, skip_detector, scale_slider
# Top-level Gradio app: header text, prompt and sampling controls, the two
# conditioner slots, and the Generate button wired to generate().
# NOTE(review): the nesting of the Row/Column containers below was
# reconstructed from a copy of this file that had lost its indentation —
# confirm it matches the intended layout.
with gr.Blocks(title="SD 1.5 Lab — dual conditioner playground") as demo:
    # Header with a short description and the device this process runs on.
    gr.Markdown(
        "# SD 1.5 Lab\n"
        "Pick two conditioners (ControlNet or IP-Adapter variant), supply images, "
        "and see how they combine.\n\n"
        "**IP-Adapter variants:** `ip_adapter` (general), `ip_adapter_plus` "
        "(sharper detail), `ip_adapter_plus_face` (face-focused), "
        "`ip_adapter_full_face` (strong identity).\n\n"
        f"Current device: **{DEVICE}**. On CPU, generating a 512×512 image "
        "takes roughly 15–30 minutes — that's expected."
    )
    # Text prompts and sampling parameters.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a photo of a cat in a garden")
            negative_prompt = gr.Textbox(label="Negative prompt", value="")
            with gr.Row():
                steps = gr.Slider(5, 50, value=20, step=1, label="Steps")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG")
                seed = gr.Number(value=42, label="Seed", precision=0)
    # The two conditioner slots, side by side.
    with gr.Row():
        s1_type, s1_img, s1_pre, s1_w = make_slot_ui(1)
        s2_type, s2_img, s2_pre, s2_w = make_slot_ui(2)
    run_btn = gr.Button("Generate", variant="primary")
    output_img = gr.Image(label="Result")
    info = gr.Textbox(label="Info", interactive=False)
    # The order of `inputs` must match the parameter order of generate().
    run_btn.click(
        fn=generate,
        inputs=[
            prompt, negative_prompt,
            s1_type, s2_type,
            s1_img, s2_img,
            s1_pre, s2_pre,
            s1_w, s2_w,
            steps, guidance, seed,
        ],
        outputs=[output_img, info],
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()