# Colorization / app.py
# Author: Noename. Last commit: "Update app.py" (765fbda, verified), 8 kB.
import gc
import os
from PIL import Image
import json
import random
import cv2
import einops
import gradio as gr
import numpy as np
import torch
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from torch.nn.functional import threshold, normalize, interpolate
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from einops import rearrange, repeat
import argparse
# Model configuration. Each value may be overridden with an environment
# variable; the defaults are the values this app ships with.
# (``device`` is chosen by the hardware detection that follows.)
# NOTE(review): the ``runwayml/stable-diffusion-v1-5`` Hub repo may no longer
# be available; if loading fails, set COLORIZATION_PRETRAINED_MODEL to a mirror.
pretrained_model = os.environ.get("COLORIZATION_PRETRAINED_MODEL", 'runwayml/stable-diffusion-v1-5')
# Path/ID handed to ControlNetModel.from_pretrained; the name is later rebound
# to the loaded model.
controlnet = os.environ.get("COLORIZATION_CONTROLNET", 'models')
# One of 'fp32', 'fp16', 'bf16'.
precision = os.environ.get("COLORIZATION_PRECISION", 'bf16')
# Pick the compute device: CUDA first, then Apple-Silicon MPS, else CPU.
# ``enable_xformers`` is bound on every path (not only the CUDA one) so later
# code can read it without depending on which branch ran.
enable_xformers = False
if torch.cuda.is_available():
    device = "cuda"
    # xformers is optional: enable memory-efficient attention only if it imports.
    try:
        import xformers  # noqa: F401
    except ImportError:
        pass
    else:
        enable_xformers = True
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # getattr guards torch builds that predate the mps backend module.
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")
# Translate the textual precision setting into the torch dtype used when
# loading the models; any other setting is a configuration error.
torch_dtype = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}.get(precision)
if torch_dtype is None:
    raise ValueError(f"Invalid precision: {precision}")
# Load the colorization ControlNet (``controlnet`` is rebound from its
# checkpoint path to the loaded model) and wrap it in the base SD pipeline.
controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype, use_safetensors=True)
# The app runs without a safety checker, so do not load its weights at all
# (previously they were loaded and then discarded).
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    pretrained_model,
    controlnet=controlnet,
    torch_dtype=torch_dtype,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Move to the target device once; the branches below only add
# hardware-specific attention optimizations.
pipe = pipe.to(device)
if device == "cuda":
    if enable_xformers:
        pipe.enable_xformers_memory_efficient_attention()
        print("xformers optimization enabled")
elif device == "mps":
    pipe.enable_attention_slicing()
    print("Attention slicing enabled for Apple Silicon")
else:
    # CPU: attention slicing to reduce peak memory.
    pipe.enable_attention_slicing()
# Segformer checkpoint (fine-tuned for human parsing, per its name) plus its
# matching preprocessor; process() runs them on every colorized sample.
# Neither is moved to ``device``, so segmentation runs on the CPU.
segmodel = SegformerForSemanticSegmentation.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
)
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "matei-dorian/segformer-b5-finetuned-human-parsing"
)
def LGB_TO_RGB(gray_image, rgb_image):
    """Combine the lightness of ``gray_image`` with the colour of ``rgb_image``.

    Both images are converted to CIE L*a*b*; the result keeps the L channel
    of ``gray_image`` and the a/b channels of ``rgb_image``.

    Args:
        gray_image: RGB image [H, W, 3] whose lightness is kept (in this app, the
            gray input replicated to 3 channels). Presumably uint8 — confirm.
        rgb_image: RGB image [h, w, 3] whose colour is kept, same dtype as
            ``gray_image``. It is resized to (H, W) if its size differs.

    Returns:
        RGB image [H, W, 3].
    """
    height, width = gray_image.shape[:2]
    if rgb_image.shape[:2] != (height, width):
        # Defensive: the generated image can differ slightly in size from the
        # input, which would break the per-pixel channel copy below.
        rgb_image = cv2.resize(rgb_image, (width, height), interpolation=cv2.INTER_LANCZOS4)
    # Take L from the gray image's own LAB conversion. Writing RGB2GRAY luma
    # straight into L is wrong for 8-bit images, where OpenCV stores L as
    # L*·255/100, and it visibly darkens mid-tones relative to the input.
    gray_lab = cv2.cvtColor(gray_image, cv2.COLOR_RGB2LAB)
    lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
    lab_image[:, :, 0] = gray_lab[:, :, 0]
    return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
def _segment_mask(image, threshold, height, width):
    """Return a binary int ``[height, width]`` mask of segmented pixels in *image*.

    Marks every class map whose logit exceeds *threshold* and counts the marked
    classes per pixel over all labels except index 0 (presumably background
    for this checkpoint — confirm). It then bilinearly resizes the count to
    ``(height, width)`` and thresholds the result again with *threshold*.

    NOTE(review): the threshold is applied to raw logits, not probabilities,
    although the UI slider only spans 0.1-0.9; kept as-is, confirm intent.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # [classes, h, w]: the class axis is the one the [1:] slice below drops.
    logits = segmodel(**inputs).logits.squeeze(0)
    hits = (logits > threshold).to(logits.dtype)[1:].sum(dim=0)
    hits = interpolate(hits[None, None], size=(height, width), mode='bilinear')
    hits = hits[0, 0].detach().cpu().numpy()
    return np.where(hits > threshold, 1, 0)


@torch.inference_mode()
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
            guidance_scale, seed, eta, threshold, save_memory=False):
    """Colorize the gray version of an image and composite it via segmentation.

    Args:
        input_image: RGB numpy image from the Gradio upload, or None.
        prompt: text prompt; ``a_prompt`` (if non-empty) is appended after ", ".
        a_prompt: additional prompt text.
        n_prompt: negative prompt.
        num_samples: number of images to generate.
        image_resolution: target resolution handed to ``resize_image``.
        ddim_steps: number of denoising steps.
        strength: ControlNet conditioning scale ("Control Strength" in the UI).
        guidance_scale: classifier-free guidance scale.
        seed: RNG seed; -1 picks a random seed in [0, 65535].
        eta: passed through to the pipeline.
        threshold: segmentation threshold (shadows the module-level
            ``torch.nn.functional.threshold`` import inside this function).
        save_memory: unused; kept for interface compatibility.

    Returns:
        Flat list for the gallery: the gray input, the colorized samples,
        their masks (0/255), then the mask composites. All are uint8.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    # Slider values may arrive as floats; the pipeline and RNG need ints.
    num_samples = int(num_samples)
    seed = int(seed)

    img = resize_image(input_image, image_resolution)
    H, W = img.shape[:2]
    # Gray copy of the input, replicated to 3 channels so it blends with RGB results.
    gray_img = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY), cv2.COLOR_GRAY2RGB)
    control = Image.fromarray(img).convert('L')

    if a_prompt:
        prompt = prompt + ', ' + a_prompt
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    output = pipe(
        num_images_per_prompt=num_samples,
        prompt=prompt,
        image=control,
        negative_prompt=n_prompt,
        num_inference_steps=int(ddim_steps),
        guidance_scale=guidance_scale,
        generator=generator,
        eta=eta,
        # The ControlNet pipeline has no ``strength`` argument; the
        # "Control Strength" slider is the ControlNet conditioning scale.
        controlnet_conditioning_scale=float(strength),
        output_type='np',
    ).images
    # output_type='np' yields floats in [0, 1]. The old ``* 127.5 + 127.5`` is
    # the [-1, 1] mapping and compressed every pixel into [127, 255].
    output = (output * 255.0).round().clip(0, 255).astype(np.uint8)
    results = [LGB_TO_RGB(gray_img, sample) for sample in output]

    masks = [_segment_mask(result, threshold, H, W) for result in results]
    # Keep colorized pixels where mask == 1 and gray pixels elsewhere. Cast back
    # to uint8 because the int mask promotes the arithmetic to int64.
    final = [
        (gray_img * (1 - mask[:, :, None]) + result * mask[:, :, None]).astype(np.uint8)
        for result, mask in zip(results, masks)
    ]
    mask_img = [(mask * 255).astype(np.uint8) for mask in masks]
    gc.collect()
    return [gray_img] + results + mask_img + final
# Gradio UI. Components are created in display order inside the layout
# context managers, so statement order here defines the page layout.
block = gr.Blocks()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Gray Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=['upload'], type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                # Fixed at one image per run and hidden from the user.
                num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=768, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=4.0, step=0.1)
                threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, vivid colors')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
    # Order must match the positional parameters of process().
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
           eta, threshold]
    # NOTE(review): every request shares the single module-level ``pipe``.
    # diffusers schedulers keep per-call state (timesteps, step index), so
    # concurrent generations on it are assumed unsafe. Run one at a time unless
    # each request gets its own scheduler.
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=1)
block.queue(max_size=100)
block.launch(share=True)