# app.py — Canny-edge ControlNet demo (Gradio app for Stable Diffusion 1.5)
import os
from dataclasses import dataclass
from PIL import Image
import cv2
import numpy as np
import gradio as gr
import torch
import spaces # type: ignore
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnets.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# Custom CSS injected into gr.Blocks(css=...) to enlarge the default Gradio
# font sizes across the whole app (headings, body text, helper/info text).
BIG_CSS = """
/* Global bump */
.gradio-container {
font-size: 18px !important;
}
/* Force most UI text bigger */
.gradio-container * {
font-size: 18px !important;
}
/* Keep markdown headings bigger */
.gradio-container h1 { font-size: 28px !important; }
.gradio-container h2 { font-size: 24px !important; }
.gradio-container h3 { font-size: 20px !important; }
/* Slightly smaller helper/info text if you want */
.gradio-container .info,
.gradio-container .prose p,
.gradio-container .prose li {
font-size: 16px !important;
line-height: 1.35 !important;
}
"""
# -----------------------------
# Pipeline builder
# -----------------------------
def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a StableDiffusionControlNetPipeline from pre-loaded components.

    Args:
        base_model_name: Hub id of the base SD checkpoint.
        controlnet: ControlNet to attach to the UNet.
        vae, unet, text_encoder, tokenizer: already-loaded SD components,
            reused so they are not downloaded/instantiated twice.
        device: target device for inference.
        weight_dtype: dtype the pipeline runs in.
        use_unipc: when True, swap the default scheduler for UniPC multistep.

    Returns:
        The assembled pipeline, moved to ``device`` with progress bars disabled.
    """
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,  # demo app: skip the NSFW checker
        torch_dtype=weight_dtype,
    )
    if use_unipc:
        # UniPC converges in fewer steps than the default PNDM scheduler.
        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline.to(device)
@dataclass
class CannyCFG:
    """Tunable parameters for percentile-based Canny edge extraction."""

    use_clahe: bool = True  # apply CLAHE contrast equalisation before blurring
    clahe_clip: float = 2.0  # CLAHE clip limit
    clahe_grid: int = 8  # CLAHE tile grid is (clahe_grid x clahe_grid)
    gaussian_ksize: int = 5  # blur kernel size (forced odd by the consumer)
    gaussian_sigma: float = 1.2  # blur sigma; larger = smoother, fewer texture edges
    high_pct: float = 90.0  # higher -> fewer edges (stricter)
    low_ratio: float = 0.4  # low = low_ratio * high
    aperture_size: int = 3  # Sobel aperture for cv2.Canny; must be 3, 5 or 7
    l2_gradient: bool = True  # use L2 norm for gradient magnitude in cv2.Canny
def canny_percentile(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
    """Detect edges with Canny, deriving thresholds from the image itself.

    The high threshold is the ``cfg.high_pct`` percentile of the Sobel
    gradient magnitude and the low threshold is ``cfg.low_ratio`` times the
    high one, so edge density adapts to the image instead of using fixed
    absolute thresholds.

    Returns:
        A single-channel ("L") PIL image of the edge map.
    """
    gray = np.array(pil_img.convert("L"), dtype=np.uint8)

    # Optional local contrast equalisation for lighting-robust edges.
    if cfg.use_clahe:
        grid = int(cfg.clahe_grid)
        gray = cv2.createCLAHE(
            clipLimit=float(cfg.clahe_clip),
            tileGridSize=(grid, grid),
        ).apply(gray)

    # Denoise before gradient estimation; GaussianBlur needs an odd kernel.
    ksize = int(cfg.gaussian_ksize) | 1
    smoothed = cv2.GaussianBlur(gray, (ksize, ksize), float(cfg.gaussian_sigma))

    # Percentile-based thresholds from the Sobel gradient magnitude.
    grad_x = cv2.Sobel(smoothed, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(smoothed, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = cv2.magnitude(grad_x, grad_y)
    hi = float(np.percentile(magnitude, float(cfg.high_pct)))
    lo = float(cfg.low_ratio) * hi
    if hi <= lo:
        # Degenerate case (e.g. flat image where the percentile is 0):
        # keep the high threshold strictly above the low one.
        hi = lo + 1.0

    aperture = int(cfg.aperture_size)
    if aperture not in (3, 5, 7):
        # cv2.Canny only accepts these aperture sizes.
        aperture = 3

    edge_map = cv2.Canny(
        smoothed,
        threshold1=lo,
        threshold2=hi,
        apertureSize=aperture,
        L2gradient=bool(cfg.l2_gradient),
    )
    return Image.fromarray(edge_map, mode="L")
# -----------------------------
# Config
# -----------------------------
BASE_MODEL = "sd-legacy/stable-diffusion-v1-5"
# Hub fallback used when no local checkpoint is found.
WEIGHTS_REPO = "mvp-lab/ControlNet_Weight"
WEIGHTS_FILENAME = "diffusion_pytorch_model_1.safetensors"
# Local ControlNet checkpoint path; override with the CONTROLNET_WEIGHTS env var.
LOCAL_WEIGHTS = os.getenv(
    "CONTROLNET_WEIGHTS",
    "/home/nik/ImperialWork/GenerativeAi/sd15-controlnet-trainer/controlnet_laion/final/diffusion_pytorch_model.safetensors",
)
# Prefer the local file; otherwise download from the Hub.
# NOTE(review): this download runs at import time, so the first start can be slow.
if os.path.isfile(LOCAL_WEIGHTS):
    CONTROLNET_PATH = LOCAL_WEIGHTS
else:
    CONTROLNET_PATH = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILENAME, repo_type="model")
DTYPE = torch.float32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------
# Model load (once)
# -----------------------------
# SD 1.5 components are loaded individually so the ControlNet can be built
# from the same UNet and all parts can be handed to the pipeline builder.
vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE)
unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=DTYPE)
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=DTYPE)
# Inference only — freeze all base-model weights.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
# ControlNet initialised from the UNet; conditioning_channels=3 matches the
# RGB Canny image fed at inference time.
controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)
state = load_file(CONTROLNET_PATH)
# strict=False tolerates key mismatches; missing/unexpected keys are ignored.
# NOTE(review): consider logging them to catch a bad checkpoint early.
missing, unexpected = controlnet.load_state_dict(state, strict=False)
pipe = build_controlnet_pipe(
    base_model_name=BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    device=DEVICE,
    weight_dtype=DTYPE,
    use_unipc=True,
)
# -----------------------------
# Helpers: fixed resize policy (longest side = 512, keep aspect, divisible by 8)
# -----------------------------
def round_down_to_multiple(x: int, m: int = 8) -> int:
    """Return the largest multiple of *m* that is <= x, floored at m itself."""
    floored = x - (x % m)
    return floored if floored >= m else m
def resize_longest_side_div8(img: Image.Image, longest: int = 512) -> tuple[Image.Image, int, int]:
    """Resize so the longest side is ~``longest`` px, keeping aspect ratio.

    Both output dimensions are rounded down to a multiple of 8 (the SD latent
    stride); ``round_down_to_multiple`` already enforces a floor of 8 px, so
    tiny inputs are never collapsed to zero.

    Args:
        img: source image (any mode).
        longest: target length of the longer side before rounding.

    Returns:
        (resized image, new width, new height).

    Raises:
        ValueError: if the image reports a non-positive dimension.
    """
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError("Invalid image size")
    scale = float(longest) / float(max(w, h))
    # Round to nearest, then snap down to the stride. The previous extra
    # max(8, ...) clamps were dead code: round_down_to_multiple already
    # returns at least 8.
    tw = round_down_to_multiple(int(round(w * scale)), 8)
    th = round_down_to_multiple(int(round(h * scale)), 8)
    resized = img.resize((tw, th), resample=Image.BICUBIC)  # type: ignore
    return resized, tw, th
def compute_canny_rgb(img_rgb_resized: Image.Image, use_clahe: bool, edge_amount: float, smoothing: float) -> Image.Image:
    """Map the two UI sliders onto a CannyCFG and return the edge map as RGB.

    Args:
        img_rgb_resized: RGB image already resized to the working resolution.
        use_clahe: whether to stabilise contrast before edge detection.
        edge_amount: 0..1 slider; more -> more detected edges.
        smoothing: 0..1 slider; more -> stronger blur, fewer texture edges.
    """
    # edge_amount 0 -> 95th percentile (few edges), 1 -> 75th (many edges).
    percentile = float(np.clip(95.0 - 20.0 * float(edge_amount), 70.0, 99.0))
    # smoothing 0 -> sigma 0.6, 1 -> sigma 2.8.
    sigma = 0.6 + 2.2 * float(smoothing)
    config = CannyCFG(
        use_clahe=bool(use_clahe),
        clahe_clip=2.0,
        clahe_grid=8,
        gaussian_ksize=5,
        gaussian_sigma=float(sigma),
        high_pct=percentile,
        low_ratio=0.4,
        aperture_size=3,
        l2_gradient=True,
    )
    # ControlNet conditioning expects a 3-channel image.
    return canny_percentile(img_rgb_resized, config).convert("RGB")
def update_canny_preview(input_image, use_clahe, edge_amount, smoothing):
    """Recompute the Canny preview whenever the input image or a slider changes.

    Returns (preview image, cached edges for generation, width, height);
    placeholders (None, None, 512, 512) when no image is loaded yet.
    """
    if input_image is None:
        return None, None, 512, 512
    # Gradio may deliver a numpy array instead of a PIL image.
    pil = input_image if isinstance(input_image, Image.Image) else Image.fromarray(input_image)
    resized, w, h = resize_longest_side_div8(pil.convert("RGB"), longest=512)
    edges = compute_canny_rgb(
        resized,
        use_clahe=use_clahe,
        edge_amount=float(edge_amount),
        smoothing=float(smoothing),
    )
    # Same edge map goes to the visible preview and to the generation state.
    return edges, edges, w, h
@spaces.GPU
@torch.inference_mode()
def generate_from_canny(
    canny: Image.Image,
    width: int,
    height: int,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    num_inference_steps: int,
    num_images: int,
    controlnet_conditioning_scale: float,
):
    """Run the ControlNet pipeline once per requested sample.

    Returns (first image or None, list of all generated images) so the UI can
    display one image and keep the rest for paging.

    Raises:
        gr.Error: if no Canny image is cached or num_images < 1.
    """
    if canny is None:
        raise gr.Error("Canny conditioning image missing. Upload an image first.")
    n = int(num_images)
    if n < 1:
        raise gr.Error("num_images must be >= 1")
    # Fixed per-slot seeds (0..n-1) make results reproducible across runs.
    generators = [torch.Generator(device=DEVICE).manual_seed(seed) for seed in range(n)]
    results = pipe(
        prompt=[prompt] * n,
        negative_prompt=[negative_prompt] * n,
        image=[canny] * n,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        generator=generators,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    ).images  # type: ignore
    return (results[0] if results else None), results
def next_image(images, idx):
    """Step the gallery cursor forward one image, wrapping past the end."""
    if not images:
        return None, 0, "0 / 0"
    total = len(images)
    new_idx = (int(idx) + 1) % total
    return images[new_idx], new_idx, f"{new_idx + 1} / {total}"
def prev_image(images, idx):
    """Step the gallery cursor back one image, wrapping before the start."""
    if not images:
        return None, 0, "0 / 0"
    total = len(images)
    new_idx = (int(idx) - 1) % total
    return images[new_idx], new_idx, f"{new_idx + 1} / {total}"
# -----------------------------
# UI
# -----------------------------
IMG_H = 360  # uniform-ish size for both preview boxes
with gr.Blocks(css=BIG_CSS) as demo:
    gr.Markdown("# Canny-Edge ControlNet Demo")
    gr.Markdown("**Note:** Trained on aesthetic/artistic images — best results come from similar, stylised inputs.")
    # state
    # canny_state holds the exact edge image fed to the pipeline; width/height
    # track the div-by-8 working resolution chosen by update_canny_preview.
    canny_state = gr.State(None)
    width_state = gr.State(512)
    height_state = gr.State(512)
    gen_images_state = gr.State([])  # list[PIL]
    gen_index_state = gr.State(0)
    with gr.Row():
        # ---- Left: Canny + Canny controls ----
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                image_mode="RGB",
                height=IMG_H,
            )
            canny_preview = gr.Image(
                label="Canny edges",
                type="pil",
                height=IMG_H,
            )
            gr.Markdown("### Edge controls")
            use_clahe = gr.Checkbox(
                label="Stabilise contrast (CLAHE)",
                value=True,
                info="Helps edges stay consistent under different lighting/contrast.",
            )
            edge_amount = gr.Slider(
                label="Edge Amount",
                minimum=0.0, maximum=1.0, value=0.6, step=0.01,
                info="More = detect more edges (more detail). Less = cleaner outline.",
            )
            smoothing = gr.Slider(
                label="Smoothing",
                minimum=0.0, maximum=1.0, value=0.4, step=0.01,
                info="More = reduce tiny texture/noise edges, cleaner structure.",
            )
        # ---- Right: Generated output + generation controls ----
        with gr.Column(scale=1):
            generated = gr.Image(
                label="Generated image",
                type="pil",
                height=IMG_H,
            )
            with gr.Row():
                prev_btn = gr.Button("◀ Prev")
                page_label = gr.Markdown("0 / 0")
                next_btn = gr.Button("Next ▶")
            gr.Markdown("### Generation controls")
            positive_prompt = gr.Textbox(
                label="Positive Prompt",
                value="",
                lines=2,
                info="Describe what you want. The edges guide the structure.",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="",
                lines=2,
                info="Things to avoid (e.g. blurry, deformed, low quality).",
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0, maximum=15.0, value=7.5, step=0.1,
                    info="Higher = follow text prompt more strongly (can drift from edges).",
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="Control Strength",
                    minimum=0.0, maximum=2.0, value=1.0, step=0.05,
                    info="Higher = follow edges more strongly. Too high can reduce creativity.",
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=10, maximum=80, value=50, step=1,
                    info="More steps can improve quality but is slower.",
                )
                num_images = gr.Slider(
                    label="Samples",
                    minimum=1, maximum=8, value=4, step=1,
                    info="How many images to generate.",
                )
            run_btn = gr.Button("Generate", variant="primary")
    # Auto-update Canny preview on changes (CPU)
    # Every control re-runs the full preview; each .change passes the whole
    # input list so the handler always sees current values.
    auto_inputs = [input_image, use_clahe, edge_amount, smoothing]
    for c in auto_inputs:
        c.change(
            fn=update_canny_preview,
            inputs=auto_inputs,
            outputs=[canny_preview, canny_state, width_state, height_state],
        )
    # Generate (GPU) -> store list -> show first -> update paging label
    run_btn.click(
        fn=generate_from_canny,
        inputs=[
            canny_state,
            width_state,
            height_state,
            positive_prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
            num_images,
            controlnet_conditioning_scale,
        ],
        outputs=[generated, gen_images_state],  # visible output first => proper "Generating..." UX
    ).then(
        # Chained step: reset the pager to image 1 once generation finishes.
        fn=lambda imgs: (0, f"1 / {len(imgs)}") if imgs else (0, "0 / 0"),
        inputs=[gen_images_state],
        outputs=[gen_index_state, page_label],
    )
    # Paging buttons (CPU)
    next_btn.click(
        fn=next_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )
    prev_btn.click(
        fn=prev_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )
if __name__ == "__main__":
    demo.launch()