Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| import subprocess | |
| from tqdm import tqdm | |
| import requests | |
| import einops | |
| import math | |
| import random | |
| import pytorch_lightning as pl | |
| import spaces | |
def download_file(url, filename, block_size=1024 * 1024, timeout=60):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is first written to ``filename + ".part"`` and only renamed onto
    ``filename`` once the whole response has been received. An interrupted or
    failed download therefore never leaves a truncated file at ``filename``.
    This matters because ``setup_environment`` skips the download whenever
    that path already exists.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        block_size: Number of bytes requested per chunk from the response.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on other network failures.
    """
    partial_path = f"{filename}.part"
    try:
        # Using the response as a context manager guarantees the streamed
        # connection is released even if writing fails midway.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, a 404/5xx error page would be saved as the file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(file.write(data))
    except BaseException:
        # Remove the half-written temp file, then propagate the original error
        # (including KeyboardInterrupt).
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    # Only a fully received download ever appears under the final name.
    os.replace(partial_path, filename)
def setup_environment():
    """Fetch the CCSR code and checkpoint, and make the code importable.

    Side effects:
      * clones ``https://github.com/camenduru/CCSR.git`` (branch ``dev``) into
        ``./CCSR`` when that directory does not exist yet;
      * changes the process working directory to ``CCSR`` and appends it to
        ``sys.path``;
      * creates ``weights/`` and downloads ``weights/real-world_ccsr.ckpt``
        if it is not already present.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True surfaces git's failure here. Otherwise the failure shows
        # up later as a confusing FileNotFoundError from os.chdir or an
        # ImportError.
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )
    # Always enter the checkout (even if it already existed), so relative
    # paths such as configs/ and weights/ resolve against it.
    os.chdir("CCSR")
    sys.path.append(os.getcwd())
    os.makedirs("weights", exist_ok=True)
    if not os.path.exists("weights/real-world_ccsr.ckpt"):
        print("Downloading model checkpoint...")
        download_file(
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            "weights/real-world_ccsr.ckpt"
        )
    else:
        print("Model checkpoint already exists. Skipping download.")
# Module-level initialisation; runs once when this script is imported/executed.
# setup_environment() must run before the imports below: it clones the CCSR
# repository, chdirs into it and appends it to sys.path, which is where the
# `ldm`, `model` and `utils` packages imported next are found.
setup_environment()
from ldm.xformers_state import disable_xformers  # NOTE(review): not referenced elsewhere in this file
from model.q_sampler import SpacedSampler
from model.ccsr_stage1 import ControlLDM  # NOTE(review): not referenced elsewhere in this file
from utils.common import instantiate_from_config, load_state_dict
from utils.image import auto_resize
# Build the stage-2 CCSR model from its config (path is relative to the CCSR checkout).
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# Load the checkpoint onto the CPU first; the whole model is moved to `device` below.
# NOTE(review): torch.load unpickles this file, and newer torch releases changed
# the default of its `weights_only` argument. If loading fails, check the
# installed torch version against what this checkpoint needs.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
# presumably Lightning's freeze(): disables grads / eval mode — confirm for this model class
model.freeze()
# Use the GPU when one is visible, else fall back to CPU; `process` creates its tensors on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def _round_up_to_multiple(value, multiple):
    """Return the smallest multiple of ``multiple`` that is >= ``value`` (positive ints)."""
    return -(-value // multiple) * multiple


def process(
    control_img: Image.Image,
    num_samples: int,
    sr_scale: float,
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    seed: int,
    tile_diffusion: bool,
    tile_diffusion_size: int,
    tile_diffusion_stride: int
):
    """Super-resolve ``control_img`` with CCSR and return the generated samples.

    Pipeline:
      1. Resize the image by ``sr_scale`` (bicubic). The result's size is the
         final output size.
      2. Pass it through ``auto_resize`` with 512 (or ``tile_diffusion_size``
         when tiling). Presumably this enlarges small images to that working
         size; confirm in ``utils.image``.
      3. Resize each side up to the nearest multiple of 64.
      4. Sample ``num_samples`` times, then resize each result back to the
         size from step 1.

    NOTE(review): ``spaces`` is imported at the top of the file but never used.
    If this Space runs on ZeroGPU hardware, this function probably needs the
    ``@spaces.GPU`` decorator; confirm.

    Args:
        control_img: Uploaded PIL image (any mode; converted to RGB here).
        num_samples: Number of independent samples to generate.
        sr_scale: Upscaling factor applied to the input size.
        strength: Value written to every entry of ``model.control_scales``.
        positive_prompt / negative_prompt: Text prompts forwarded to the sampler.
        cfg_scale: Classifier-free guidance scale forwarded to the sampler.
        steps: Number of sampling steps.
        use_color_fix: Use "adain" colour correction instead of "none".
        seed: Passed to ``pl.seed_everything``.
        tile_diffusion: Use the tiled sampler instead of whole-image sampling.
        tile_diffusion_size / tile_diffusion_stride: Tile geometry for the tiled sampler.

    Returns:
        A list of ``num_samples`` uint8 numpy arrays of shape (H, W, C), each
        resized to the ``sr_scale``-ed input size.

    Raises:
        gr.Error: if no input image was provided.
    """
    if control_img is None:
        raise gr.Error("Please upload an input image before running.")

    print(f"control image shape={control_img.size}\n"
          f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
          f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
          f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
          f"seed={seed}\n"
          f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}")
    pl.seed_everything(seed)

    # UI sliders may deliver floats; range() and the sampler need ints.
    num_samples = int(num_samples)
    steps = int(steps)
    tile_diffusion_size = int(tile_diffusion_size)
    tile_diffusion_stride = int(tile_diffusion_stride)

    # RGBA/LA/P/L uploads would otherwise yield an array whose channel count
    # does not match an RGB model after the "n h w c -> n c h w" rearrange.
    control_img = control_img.convert("RGB")

    # Resize input image; this size is also the final output size.
    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC
        )
    input_size = control_img.size

    # Working resolution: whole-image mode uses 512, tiled mode the tile size.
    control_img = auto_resize(control_img, tile_diffusion_size if tile_diffusion else 512)

    # Round each side UP to a multiple of 64. The former `(s // 64 + 1) * 64`
    # added a whole extra 64 px when a side was already a multiple of 64.
    control_img = control_img.resize(
        tuple(_round_up_to_multiple(s, 64) for s in control_img.size), Image.LANCZOS
    )

    # Convert to tensor (NCHW, [0,1]) directly on `device`. The model was
    # moved to `device` at import time, so no per-call transfers are needed.
    control = torch.tensor(
        np.array(control_img)[None] / 255.0, dtype=torch.float32, device=device
    ).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    height, width = control.size(-2), control.size(-1)
    model.control_scales = [strength] * 13

    sampler = SpacedSampler(model, var_type="fixed_small")
    # Loop invariants: initial-noise shape (4 channels at 1/8 resolution) and colour-fix mode.
    shape = (1, 4, height // 8, width // 8)
    color_fix_type = "adain" if use_color_fix else "none"

    preds = []
    for _ in tqdm(range(num_samples)):
        x_T = torch.randn(shape, device=device, dtype=torch.float32)
        if not tile_diffusion:
            samples = sampler.sample_ccsr(
                steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                cfg_scale=cfg_scale,
                color_fix_type=color_fix_type
            )
        else:
            samples = sampler.sample_with_tile_ccsr(
                tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
                steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
                positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T,
                cfg_scale=cfg_scale,
                color_fix_type=color_fix_type
            )
        x_samples = samples.clamp(0, 1)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 255)
            .cpu().numpy().clip(0, 255).astype(np.uint8)
        )
        # Undo the working-resolution resizes: return the sr_scale-ed input size.
        img = Image.fromarray(x_samples[0, ...]).resize(input_size, Image.LANCZOS)
        preds.append(np.array(img))
    return preds
def update_output_resolution(image, scale_choice, custom_scale):
    """Return a Markdown line describing the input size and the expected output size.

    Args:
        image: The uploaded image (only ``.size`` is read), or None.
        scale_choice: The "Output Resolution" dropdown label, e.g.
            "1920x1080 (1.50x)", "200x100 (200%)" or "Custom".
        custom_scale: The "Custom Scale" slider value, used when
            ``scale_choice`` is "Custom" (or unset).

    Returns:
        A human-readable resolution summary string.
    """
    if image is None:
        return "Upload an image to see the output resolution"
    width, height = image.size
    if scale_choice is None or scale_choice == "Custom":
        scale = custom_scale
    elif "%" in scale_choice:
        # "WxH (200%)" -> 2.0
        scale = float(scale_choice.split()[-1].strip("()%")) / 100
    else:
        # "WxH (1.50x)" -> 1.5
        scale = float(scale_choice.split()[-1].strip("()x"))
    # Round up exactly like `process`, which resizes each side to
    # math.ceil(side * sr_scale). int() truncation could under-report by 1 px.
    out_width = math.ceil(width * scale)
    out_height = math.ceil(height * scale)
    return f"Current resolution: {width}x{height}. Output resolution: {out_width}x{out_height}"
def update_scale_choices(image):
    """Rebuild the "Output Resolution" dropdown for a newly uploaded image.

    Offers each preset resolution whose aspect ratio is within 0.1 of the
    image's and which would require an upscale (factor > 1). Each is labelled
    "WxH (F.FFx)", where F is the larger of the two per-axis factors. If no
    preset qualifies, it offers 2x/4x/8x percentage options instead. "Custom"
    is always the last entry, and the first entry is selected. With no image,
    only "Custom" is offered.
    """
    if image is None:
        return gr.update(choices=["Custom"], value="Custom")

    src_w, src_h = image.size
    src_ratio = src_w / src_h
    presets = (
        # 16:9
        (1280, 720), (1920, 1080), (2560, 1440), (3840, 2160),
        # 1:1
        (1440, 1440), (2048, 2048), (2560, 2560), (3840, 3840),
    )

    options = []
    for target_w, target_h in presets:
        # Skip presets whose aspect ratio is too far from the image's.
        if abs(target_w / target_h - src_ratio) >= 0.1:
            continue
        factor = max(target_w / src_w, target_h / src_h)
        if factor > 1:
            options.append(f"{target_w}x{target_h} ({factor:.2f}x)")

    if not options:
        # No preset fits: fall back to fixed multipliers shown as percentages.
        options = [f"{src_w * mult}x{src_h * mult} ({mult * 100}%)" for mult in (2, 4, 8)]

    options.append("Custom")
    return gr.update(choices=options, value=options[0])
# Custom CSS passed to gr.Blocks(css=...) for the app below.
# `.input-image` and `.output-gallery` are attached via `elem_classes` on the
# input image and the result gallery.
# NOTE(review): `.container`, `.output-image` and `.gr-form` are not attached
# to any component via elem_classes in this file; they presumably target
# Gradio's own DOM classes or are leftovers. Confirm before relying on them.
css = """
.container {max-width: 1200px; margin: auto; padding: 20px;}
.input-image {width: 100%; max-height: 500px; object-fit: contain;}
.output-gallery {display: flex; flex-wrap: wrap; justify-content: center;}
.output-image {margin: 10px; max-width: 45%; height: auto;}
.gr-form {border: 1px solid #e0e0e0; border-radius: 8px; padding: 16px; margin-bottom: 16px;}
"""
with gr.Blocks(css=css) as block:
    gr.HTML("<h1 style='text-align: center;'>CCSR Upscaler</h1>")

    with gr.Row():
        # Left column: upload, target resolution and the Run button.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image", elem_classes="input-image")
            sr_scale = gr.Dropdown(
                label="Output Resolution",
                choices=["Custom"],
                value="Custom",
                interactive=True,
            )
            custom_scale = gr.Slider(
                label="Custom Scale",
                minimum=1,
                maximum=8,
                value=4,
                step=0.1,
                visible=True,
            )
            output_resolution = gr.Markdown("Upload an image to see the output resolution")
            run_button = gr.Button(value="Run", variant="primary")

        # Right column: sampler settings, collapsed by default.
        with gr.Column(scale=1):
            with gr.Accordion("Advanced Options", open=False):
                num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                positive_prompt = gr.Textbox(label="Positive Prompt", value="")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
                tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
                tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
                tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)

    # Results row underneath both columns.
    with gr.Row():
        result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")

    def update_custom_scale(choice):
        """Show the free-form scale slider only while "Custom" is selected."""
        is_custom = choice == "Custom"
        return gr.update(visible=is_custom)

    def get_scale_value(choice, custom):
        """Turn a dropdown label ("WxH (1.50x)", "WxH (200%)" or "Custom") into a float factor."""
        if choice == "Custom":
            return custom
        if "%" in choice:
            percent = choice.split()[-1].strip("()%")
            return float(percent) / 100
        multiplier = choice.split()[-1].strip("()x")
        return float(multiplier)

    sr_scale.change(update_custom_scale, inputs=[sr_scale], outputs=[custom_scale])

    # Order must match process()'s positional parameters. custom_scale is
    # appended last and folded into the sr_scale argument by the lambda.
    inputs = [
        input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
        cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
        tile_diffusion_stride,
    ]
    run_button.click(
        fn=lambda *args: process(*args[:2], get_scale_value(args[2], args[-1]), *args[3:-1]),
        inputs=inputs + [custom_scale],
        outputs=[result_gallery],
    )

    input_image.change(update_scale_choices, inputs=[input_image], outputs=[sr_scale])

    # Refresh the resolution summary whenever any of its three inputs changes.
    for trigger in (input_image, sr_scale, custom_scale):
        trigger.change(
            update_output_resolution,
            inputs=[input_image, sr_scale, custom_scale],
            outputs=[output_resolution],
        )

    # The resolution dropdown is only editable while an image is loaded.
    input_image.change(
        lambda x: gr.update(interactive=x is not None),
        inputs=[input_image],
        outputs=[sr_scale],
    )

block.launch(share=True)