Spaces:
Sleeping
Sleeping
| import os | |
| from dataclasses import dataclass | |
| from PIL import Image | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import spaces # type: ignore | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL | |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel | |
| from diffusers.models.controlnets.controlnet import ControlNetModel | |
| from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline | |
| from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
# CSS injected via gr.Blocks(css=...) to enlarge the default Gradio font
# sizes — the stock theme is small for demo / presentation use.
BIG_CSS = """
/* Global bump */
.gradio-container {
font-size: 18px !important;
}
/* Force most UI text bigger */
.gradio-container * {
font-size: 18px !important;
}
/* Keep markdown headings bigger */
.gradio-container h1 { font-size: 28px !important; }
.gradio-container h2 { font-size: 24px !important; }
.gradio-container h3 { font-size: 20px !important; }
/* Slightly smaller helper/info text if you want */
.gradio-container .info,
.gradio-container .prose p,
.gradio-container .prose li {
font-size: 16px !important;
line-height: 1.35 !important;
}
"""
| # ----------------------------- | |
| # Pipeline builder | |
| # ----------------------------- | |
def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a Stable Diffusion ControlNet pipeline from pre-loaded parts.

    The individual components (VAE, UNet, text encoder, tokenizer, ControlNet)
    are passed in already constructed; the safety checker is disabled.
    When ``use_unipc`` is set, the default scheduler is swapped for UniPC.
    The pipeline is moved to ``device`` and its progress bar is silenced.
    """
    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=weight_dtype,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name, **components
    )
    if use_unipc:
        # Re-derive the UniPC scheduler from the existing scheduler config.
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.set_progress_bar_config(disable=True)
    return pipe.to(device)
# FIX: the original was a plain class with bare annotations, so calls like
# CannyCFG(use_clahe=..., high_pct=...) in compute_canny_rgb raised
# TypeError ("takes no arguments"). @dataclass (already imported at the top
# of the file) generates the keyword __init__ those call sites need.
@dataclass
class CannyCFG:
    """Tunable parameters for percentile-based Canny edge extraction."""

    use_clahe: bool = True       # apply CLAHE contrast normalisation first
    clahe_clip: float = 2.0      # CLAHE clip limit
    clahe_grid: int = 8          # CLAHE tile grid size (grid x grid)
    gaussian_ksize: int = 5      # Gaussian blur kernel size (forced odd later)
    gaussian_sigma: float = 1.2  # Gaussian blur sigma
    high_pct: float = 90.0       # higher -> fewer edges (stricter)
    low_ratio: float = 0.4       # low = low_ratio * high
    aperture_size: int = 3       # Sobel aperture for cv2.Canny (3/5/7 only)
    l2_gradient: bool = True     # use L2 norm for gradient magnitude
def canny_percentile(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
    """Canny edge detection with thresholds derived from the image itself.

    The high threshold is taken as a percentile of the Sobel gradient
    magnitude, so edge density adapts to each image instead of relying on
    fixed absolute thresholds. Returns a single-channel ("L") edge image.
    """
    gray = np.asarray(pil_img.convert("L"), dtype=np.uint8)

    # Optional local-contrast normalisation before edge detection.
    if cfg.use_clahe:
        gray = cv2.createCLAHE(
            clipLimit=float(cfg.clahe_clip),
            tileGridSize=(int(cfg.clahe_grid),) * 2,
        ).apply(gray)

    # Denoise with a Gaussian; "| 1" forces an odd kernel size as cv2 requires.
    ksize = int(cfg.gaussian_ksize) | 1
    smoothed = cv2.GaussianBlur(gray, (ksize, ksize), float(cfg.gaussian_sigma))

    # Data-driven hysteresis thresholds from the gradient-magnitude histogram.
    grad_x = cv2.Sobel(smoothed, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(smoothed, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = cv2.magnitude(grad_x, grad_y)
    high = float(np.percentile(magnitude, float(cfg.high_pct)))
    low = float(cfg.low_ratio) * high
    if high <= low:
        high = low + 1.0  # keep the threshold pair strictly ordered

    aperture = int(cfg.aperture_size)
    if aperture not in (3, 5, 7):
        aperture = 3  # cv2.Canny accepts only 3, 5, or 7

    edges = cv2.Canny(
        smoothed,
        threshold1=low,
        threshold2=high,
        apertureSize=aperture,
        L2gradient=bool(cfg.l2_gradient),
    )
    return Image.fromarray(edges, mode="L")
# -----------------------------
# Config
# -----------------------------
BASE_MODEL = "sd-legacy/stable-diffusion-v1-5"
WEIGHTS_REPO = "mvp-lab/ControlNet_Weight"
WEIGHTS_FILENAME = "diffusion_pytorch_model_1.safetensors"
# Optional local override via env var for development.
# NOTE(review): the default is a machine-specific absolute path — harmless
# elsewhere (it simply falls through to the Hub download below), but consider
# removing it before publishing.
LOCAL_WEIGHTS = os.getenv(
    "CONTROLNET_WEIGHTS",
    "/home/nik/ImperialWork/GenerativeAi/sd15-controlnet-trainer/controlnet_laion/final/diffusion_pytorch_model.safetensors",
)
if os.path.isfile(LOCAL_WEIGHTS):
    CONTROLNET_PATH = LOCAL_WEIGHTS
else:
    # Download the trained ControlNet weights from the Hub (runs at import time).
    CONTROLNET_PATH = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILENAME, repo_type="model")
DTYPE = torch.float32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------
# Model load (once)
# -----------------------------
# All SD1.5 components are loaded once at import time (typical Gradio Space
# pattern so the first request doesn't pay the load cost).
vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE)
unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=DTYPE)
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=DTYPE)
# Inference only — freeze every trainable module.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
# Build a ControlNet matching the UNet's architecture, then load the trained
# weights from the safetensors checkpoint.
controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)
state = load_file(CONTROLNET_PATH)
# NOTE(review): strict=False tolerates key mismatches silently; `missing` and
# `unexpected` are captured but never inspected — consider logging them so a
# bad checkpoint doesn't fail quietly.
missing, unexpected = controlnet.load_state_dict(state, strict=False)
pipe = build_controlnet_pipe(
    base_model_name=BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    device=DEVICE,
    weight_dtype=DTYPE,
    use_unipc=True,
)
| # ----------------------------- | |
| # Helpers: fixed resize policy (longest side = 512, keep aspect, divisible by 8) | |
| # ----------------------------- | |
def round_down_to_multiple(x: int, m: int = 8) -> int:
    """Return the largest multiple of *m* that is <= x, floored at m itself."""
    floored = x - (x % m)
    return floored if floored > m else m
def resize_longest_side_div8(img: Image.Image, longest: int = 512) -> tuple[Image.Image, int, int]:
    """Resize so the longest side equals *longest*, keeping aspect ratio.

    Both output dimensions are snapped down to multiples of 8 (the SD VAE
    latent-space requirement) and clamped to at least 8 px.
    Returns (resized_image, width, height).
    """
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError("Invalid image size")
    factor = float(longest) / float(max(w, h))
    new_w = max(8, round_down_to_multiple(int(round(w * factor)), 8))
    new_h = max(8, round_down_to_multiple(int(round(h * factor)), 8))
    resized = img.resize((new_w, new_h), resample=Image.BICUBIC)  # type: ignore
    return resized, new_w, new_h
def compute_canny_rgb(img_rgb_resized: Image.Image, use_clahe: bool, edge_amount: float, smoothing: float) -> Image.Image:
    """Map the two UI sliders onto CannyCFG knobs and return an RGB edge map.

    edge_amount: 0 -> percentile 95 (sparse edges), 1 -> 75 (dense edges).
    smoothing:   0 -> sigma 0.6 (keep texture), 1 -> sigma 2.8 (clean lines).
    """
    # Slider-to-parameter mapping; percentile clamped to a sane range.
    percentile = float(np.clip(95.0 - 20.0 * float(edge_amount), 70.0, 99.0))
    sigma = 0.6 + 2.2 * float(smoothing)
    cfg = CannyCFG(
        use_clahe=bool(use_clahe),
        clahe_clip=2.0,
        clahe_grid=8,
        gaussian_ksize=5,
        gaussian_sigma=float(sigma),
        high_pct=percentile,
        low_ratio=0.4,
        aperture_size=3,
        l2_gradient=True,
    )
    # ControlNet conditioning expects 3 channels, hence the RGB conversion.
    return canny_percentile(img_rgb_resized, cfg).convert("RGB")
def update_canny_preview(input_image, use_clahe, edge_amount, smoothing):
    """Recompute the Canny preview whenever the image or a slider changes.

    Returns (preview_image, canny_state, width, height); the edge map is
    emitted twice — once for display, once stored as generation state.
    With no input yet, clears the preview and resets size state to 512x512.
    """
    if input_image is None:
        return None, None, 512, 512
    # Gradio may hand over a numpy array depending on component config.
    pil = input_image if isinstance(input_image, Image.Image) else Image.fromarray(input_image)
    resized, width, height = resize_longest_side_div8(pil.convert("RGB"), longest=512)
    edges = compute_canny_rgb(
        resized,
        use_clahe=use_clahe,
        edge_amount=float(edge_amount),
        smoothing=float(smoothing),
    )
    return edges, edges, width, height
def generate_from_canny(
    canny: Image.Image,
    width: int,
    height: int,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    num_inference_steps: int,
    num_images: int,
    controlnet_conditioning_scale: float,
):
    """Run the ControlNet pipeline on a precomputed Canny conditioning image.

    Returns (first_image, all_images) so the UI can display one image
    immediately while the full batch is stored for paging.

    Raises:
        gr.Error: if no conditioning image exists or num_images < 1.

    NOTE(review): `spaces` is imported at the top of the file but this
    GPU-heavy entry point is not decorated with @spaces.GPU — confirm
    whether this Space targets ZeroGPU hardware, where the decorator is
    required for GPU allocation.
    """
    if canny is None:
        raise gr.Error("Canny conditioning image missing. Upload an image first.")
    if int(num_images) < 1:
        raise gr.Error("num_images must be >= 1")
    # Fixed seeds 0..n-1 -> the batch is reproducible across clicks.
    gens = [torch.Generator(device=DEVICE).manual_seed(i) for i in range(int(num_images))]
    # Prompts and the conditioning image are replicated per sample.
    imgs = pipe(
        prompt=[prompt] * int(num_images),
        negative_prompt=[negative_prompt] * int(num_images),
        image=[canny] * int(num_images),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        generator=gens,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    ).images  # type: ignore
    first = imgs[0] if imgs else None
    return first, imgs
def next_image(images, idx):
    """Advance the gallery to the next image, wrapping past the end.

    Returns (image, new_index, "i / n" label); with an empty list, clears
    the display and resets to "0 / 0".
    """
    total = len(images) if images else 0
    if total == 0:
        return None, 0, "0 / 0"
    new_idx = (int(idx) + 1) % total
    return images[new_idx], new_idx, f"{new_idx + 1} / {total}"
def prev_image(images, idx):
    """Step the gallery back one image, wrapping before the start.

    Returns (image, new_index, "i / n" label); with an empty list, clears
    the display and resets to "0 / 0".
    """
    total = len(images) if images else 0
    if total == 0:
        return None, 0, "0 / 0"
    new_idx = (int(idx) - 1) % total
    return images[new_idx], new_idx, f"{new_idx + 1} / {total}"
# -----------------------------
# UI
# -----------------------------
IMG_H = 360  # uniform-ish size for both preview boxes

# Declarative Gradio layout: left column holds the input image + Canny
# controls, right column holds the generated output + generation controls.
with gr.Blocks(css=BIG_CSS) as demo:
    gr.Markdown("# Canny-Edge ControlNet Demo")
    gr.Markdown("**Note:** Trained on aesthetic/artistic images — best results come from similar, stylised inputs.")
    # state
    canny_state = gr.State(None)        # last computed Canny conditioning image
    width_state = gr.State(512)         # generation width (div-8, from resize)
    height_state = gr.State(512)        # generation height (div-8, from resize)
    gen_images_state = gr.State([])     # list[PIL] — full generated batch
    gen_index_state = gr.State(0)       # current page index into the batch
    with gr.Row():
        # ---- Left: Canny + Canny controls ----
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                image_mode="RGB",
                height=IMG_H,
            )
            canny_preview = gr.Image(
                label="Canny edges",
                type="pil",
                height=IMG_H,
            )
            gr.Markdown("### Edge controls")
            use_clahe = gr.Checkbox(
                label="Stabilise contrast (CLAHE)",
                value=True,
                info="Helps edges stay consistent under different lighting/contrast.",
            )
            edge_amount = gr.Slider(
                label="Edge Amount",
                minimum=0.0, maximum=1.0, value=0.6, step=0.01,
                info="More = detect more edges (more detail). Less = cleaner outline.",
            )
            smoothing = gr.Slider(
                label="Smoothing",
                minimum=0.0, maximum=1.0, value=0.4, step=0.01,
                info="More = reduce tiny texture/noise edges, cleaner structure.",
            )
        # ---- Right: Generated output + generation controls ----
        with gr.Column(scale=1):
            generated = gr.Image(
                label="Generated image",
                type="pil",
                height=IMG_H,
            )
            with gr.Row():
                prev_btn = gr.Button("◀ Prev")
                page_label = gr.Markdown("0 / 0")
                next_btn = gr.Button("Next ▶")
            gr.Markdown("### Generation controls")
            positive_prompt = gr.Textbox(
                label="Positive Prompt",
                value="",
                lines=2,
                info="Describe what you want. The edges guide the structure.",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="",
                lines=2,
                info="Things to avoid (e.g. blurry, deformed, low quality).",
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0, maximum=15.0, value=7.5, step=0.1,
                    info="Higher = follow text prompt more strongly (can drift from edges).",
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="Control Strength",
                    minimum=0.0, maximum=2.0, value=1.0, step=0.05,
                    info="Higher = follow edges more strongly. Too high can reduce creativity.",
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=10, maximum=80, value=50, step=1,
                    info="More steps can improve quality but is slower.",
                )
                num_images = gr.Slider(
                    label="Samples",
                    minimum=1, maximum=8, value=4, step=1,
                    info="How many images to generate.",
                )
            run_btn = gr.Button("Generate", variant="primary")
    # Auto-update Canny preview on changes (CPU)
    # Any of the four controls re-triggers the preview with all four values.
    auto_inputs = [input_image, use_clahe, edge_amount, smoothing]
    for c in auto_inputs:
        c.change(
            fn=update_canny_preview,
            inputs=auto_inputs,
            outputs=[canny_preview, canny_state, width_state, height_state],
        )
    # Generate (GPU) -> store list -> show first -> update paging label
    run_btn.click(
        fn=generate_from_canny,
        inputs=[
            canny_state,
            width_state,
            height_state,
            positive_prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
            num_images,
            controlnet_conditioning_scale,
        ],
        outputs=[generated, gen_images_state],  # visible output first => proper "Generating..." UX
    ).then(
        # Chained step resets paging to the first image of the new batch.
        fn=lambda imgs: (0, f"1 / {len(imgs)}") if imgs else (0, "0 / 0"),
        inputs=[gen_images_state],
        outputs=[gen_index_state, page_label],
    )
    # Paging buttons (CPU)
    next_btn.click(
        fn=next_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )
    prev_btn.click(
        fn=prev_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )

if __name__ == "__main__":
    demo.launch()