from share import *
import config
import os
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

from PIL import Image

# os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
device = "cpu"

# Build the ControlNet + Stable Diffusion v1.5 model on CPU, then load the fine-tuned
# SD weights followed by the ControlNet weights (strict=False because each checkpoint
# only covers part of the combined state dict).
model = create_model('./models/cldm_v15_cpu.yaml').cpu()
sd_model_path = "./models/sks_crack_ppl.ckpt"
controlnet_path = "./models/sks_crack_controlnet.pth"
model.load_state_dict(load_state_dict(sd_model_path, location='cpu'), strict=False)
model.load_state_dict(load_state_dict(controlnet_path, location='cpu'), strict=False)

# model = model.cuda()
ddim_sampler = DDIMSampler(model)
# Default crack mask displayed in the upload tab on startup.
init_mask = Image.open("379.png").convert("L")


def model_sample(mask,
                 prompt="sks crack, pavement cracks, HDR, Asphalt road, mudded",
                 a_prompt="",
                 n_prompt="",
                 num_samples=1, ddim_steps=50, guess_mode=False, strength=1.0, scale=7.0, seed=-1, eta=0.0):
    """Generate num_samples crack images from a uint8 numpy crack mask; returns a list of HWC uint8 arrays."""

    with torch.no_grad():
        # Ensure a 3-channel uint8 mask and resize it so both sides are multiples of 64.
        mask = HWC3(mask)
        mask = resize_image(mask, 512)
        H, W, C = mask.shape

        # The mask is the ControlNet conditioning: normalise to [0, 1], replicate across
        # the batch, and move channels first (b, c, h, w).
        control = torch.from_numpy(mask.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        # Conditioning: the control (mask) tensor plus the prompt embedding. In guess mode
        # the control is omitted from the unconditional branch, so classifier-free guidance
        # also amplifies the control signal.
        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)  # latent shape: 4 channels at 1/8 of the image resolution

        # Per-block ControlNet strength (13 blocks). In guess mode the strength decays
        # geometrically towards the shallow blocks (about 0.1 x strength at the shallowest),
        # following the upstream ControlNet demo; otherwise a constant strength is used.
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)


        # Decode the sampled latents with the VAE and map from [-1, 1] back to uint8 pixels.
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
        
    return results
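
# A minimal, hypothetical example of calling model_sample() directly, without the Gradio
# UI. The file name "example_mask.png" and the parameter values are illustrative
# assumptions, not part of the original script; the snippet is kept commented out so it
# never runs when the demo is launched.
#
#   mask = np.array(Image.open("example_mask.png").convert("L"))
#   images = model_sample(mask, num_samples=1, ddim_steps=30, seed=42)
#   for idx, img in enumerate(images):
#       Image.fromarray(img).save(f"crack_{idx}.png")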

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Crack Diffusion")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Tabs(elem_id="mode_img2img"):
                    with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="numpy", tool="editor", image_mode="L", value=init_mask).style(height=480)
                        init_run_button = gr.Button("Run Init")
                    with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
                        sketch_img = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="canvas", interactive=True, type="numpy", tool="color-sketch", image_mode="L").style(height=480)
                        sketch_run_button = gr.Button("Run Sketch")
                    prompt = gr.Textbox(label="Prompt", value="sks crack")
            with gr.Row():
                with gr.Accordion("Advanced options", open=False):
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                    # Note: image_resolution and detect_resolution are displayed but not wired
                    # into model_sample, which always resizes the mask to a fixed 512 resolution.
                    image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                    strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                    guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                    detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    a_prompt = gr.Textbox(label="Added Prompt", value='')
                    n_prompt = gr.Textbox(label="Negative Prompt", value='')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
 
    # Inputs are passed positionally and must match the parameter order of model_sample.
    init_ips = [init_img, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    sketch_ips = [sketch_img, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    init_run_button.click(fn=model_sample, inputs=init_ips, outputs=[result_gallery])
    sketch_run_button.click(fn=model_sample, inputs=sketch_ips, outputs=[result_gallery])

block.launch()