Noename committed on
Commit
d59d585
·
verified ·
1 Parent(s): f7dbc32

Delete app_diff.py

Browse files
Files changed (1) hide show
  1. app_diff.py +0 -204
app_diff.py DELETED
@@ -1,204 +0,0 @@
1
- import os
2
- from PIL import Image
3
- import json
4
- import random
5
-
6
- import cv2
7
- import einops
8
- import gradio as gr
9
- import numpy as np
10
- import torch
11
-
12
- from pytorch_lightning import seed_everything
13
- from annotator.util import resize_image, HWC3
14
- from torch.nn.functional import threshold, normalize, interpolate
15
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
16
- from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
17
- from einops import rearrange, repeat
18
-
19
- import argparse
20
-
21
# --- Configuration -----------------------------------------------------------
# The argparse wiring was removed upstream; these constants replace the old
# --pretrained_model / --controlnet / --precision flags.
pretrained_model = 'runwayml/stable-diffusion-v1-5'
controlnet = 'checkpoint-36000/controlnet'  # path; rebound to the loaded model below
precision = 'bf16'

# --- Device selection --------------------------------------------------------
# Pick the compute device exactly once (the original computed it twice, at the
# top of the file and again here). xformers is optional and only meaningful on
# CUDA; probe for it with a guarded import rather than failing hard.
if torch.cuda.is_available():
    device = "cuda"
    try:
        import xformers  # noqa: F401  # availability probe only

        enable_xformers = True
    except ImportError:
        enable_xformers = False
else:
    enable_xformers = False
    device = "mps" if torch.backends.mps.is_available() else "cpu"

print(f"Using device: {device}")

# --- Precision ---------------------------------------------------------------
# Map the precision string to a torch dtype; unknown values fail loudly, as in
# the original if/elif chain.
_PRECISION_DTYPES = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}
try:
    torch_dtype = _PRECISION_DTYPES[precision]
except KeyError:
    raise ValueError(f"Invalid precision: {precision}") from None

# --- Model loading -----------------------------------------------------------
# Load the ControlNet weights and the Stable Diffusion pipeline that hosts
# them, then swap in the UniPC scheduler.
controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    pretrained_model, controlnet=controlnet, torch_dtype=torch_dtype
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Move the pipeline once (the original called .to(device) twice: here and
# again inside each hardware branch below).
pipe = pipe.to(device)

# --- Hardware-specific optimizations ----------------------------------------
if device == "cuda" and enable_xformers:
    pipe.enable_xformers_memory_efficient_attention()
    print("xformers optimization enabled")
elif device == "mps":
    pipe.enable_attention_slicing()
    print("Attention slicing enabled for Apple Silicon")
# CPU path: sequential offload / attention slicing were deliberately left
# disabled upstream (commented out in the original).

# Human-parsing segmentation model, used in process() to build person masks.
feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
85
-
86
-
87
def LGB_TO_RGB(gray_image, rgb_image):
    """Transplant the luminance of *gray_image* onto the colors of *rgb_image*.

    Both inputs are H x W x 3 RGB arrays (assumed uint8 — TODO confirm with
    callers). The result keeps the LAB a/b chroma of *rgb_image* while its
    L (lightness) channel is replaced by the grayscale version of
    *gray_image*, then converts back to RGB.
    """
    print("gray_image shape: ", gray_image.shape)
    print("rgb_image shape: ", rgb_image.shape)

    # Collapse the 3-channel gray input to a single luminance plane.
    luma = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY)

    # Work in LAB so lightness can be swapped independently of chroma.
    recolored = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
    recolored[:, :, 0] = luma
    return cv2.cvtColor(recolored, cv2.COLOR_LAB2RGB)
99
-
100
-
101
@torch.inference_mode()
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
            guidance_scale, seed, eta, threshold, save_memory=False):
    """Colorize a (grayscale) input image with the ControlNet pipeline.

    Returns a flat list of numpy images for the Gradio gallery:
    [gray input] + generated results + segmentation masks (0/255) + final
    composites (color kept only where a person is detected).

    NOTE(review): the `threshold` parameter shadows the
    `torch.nn.functional.threshold` imported at the top of the file (that
    import is unused here as a result). `save_memory` is accepted but never
    read.
    """
    # torch.no_grad() is redundant under @torch.inference_mode(), but harmless.
    with torch.no_grad():
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        print("img shape: ", img.shape)
        if C == 3:
            # Force a grayscale conditioning image, replicated to 3 channels.
            # NOTE(review): uses COLOR_BGR2GRAY on what Gradio delivers as RGB
            # — channel weighting may be swapped; confirm intended.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        # HWC uint8 -> 1xCxHxW float in [0, 1] for the ControlNet input.
        control = torch.from_numpy(img).to(device).float()
        control = control / 255.0
        control = rearrange(control, 'h w c -> 1 c h w')
        # control = repeat(control, 'b c h w -> b c h w', b=num_samples)
        # control = rearrange(control, 'b h w c -> b c h w')

        if a_prompt:
            prompt = prompt + ', ' + a_prompt

        # -1 means "random seed"; seed both lightning and the torch generator.
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        generator = torch.Generator(device=device).manual_seed(seed)
        # Generate images
        # NOTE(review): `strength` and `eta` are forwarded as-is; confirm the
        # StableDiffusionControlNetPipeline __call__ actually accepts them
        # (eta only applies to DDIM-family schedulers).
        output = pipe(
            num_images_per_prompt=num_samples,
            prompt=prompt,
            image=control.to(device),
            negative_prompt=n_prompt,
            num_inference_steps=ddim_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            eta=eta,
            strength=strength,
            output_type='np',

        ).images

        # output = einops.rearrange(output, 'b c h w -> b h w c')
        # NOTE(review): this rescaling assumes pipeline output in [-1, 1];
        # diffusers' output_type='np' typically returns floats in [0, 1], in
        # which case the result is compressed into [128, 255] — confirm.
        output = (output * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

        results = [output[i] for i in range(num_samples)]
        # Re-impose the input's luminance on each generated image.
        results = [LGB_TO_RGB(img, result) for result in results]

        # Convert each image in `results` into a person mask via the
        # human-parsing Segformer.
        masks = []
        for result in results:
            inputs = feature_extractor(images=result, return_tensors="pt")
            outputs = segmodel(**inputs)
            logits = outputs.logits
            logits = logits.squeeze(0)
            # Binarize per-class logits, then merge all non-background
            # classes (index 0 is assumed background — TODO confirm against
            # the model's label map) into a single mask.
            thresholded = torch.zeros_like(logits)
            thresholded[logits > threshold] = 1
            mask = thresholded[1:, :, :].sum(dim=0)
            mask = mask.unsqueeze(0).unsqueeze(0)
            # Resize the low-res logits mask back to the image size.
            mask = interpolate(mask, size=(H, W), mode='bilinear')
            mask = mask.detach().numpy()
            mask = np.squeeze(mask)
            # Re-binarize after bilinear interpolation blurred the edges.
            mask = np.where(mask > threshold, 1, 0)
            masks.append(mask)

        # Composite: where the mask is 0 keep the grayscale input `img`
        # (already a 3-channel RGB), elsewhere keep the colorized result.
        final = [img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]

        # Scale the 0/1 masks to 0/255 so they render as images.

        mask_img = [mask * 255 for mask in masks]
        return [img] + results + mask_img + final
171
-
172
-
173
# --- Gradio UI --------------------------------------------------------------
# NOTE(review): .queue() is called here AND again as block.queue(max_size=100)
# below — the second call likely supersedes this one; confirm intended.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Gray Image")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(sources=['upload'], type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                # Hidden, pinned to 1: the pipeline call supports batching but
                # the UI deliberately exposes a single sample.
                num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
                # num_samples = 1
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                # guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                # Drives both the logits binarization and the post-resize
                # re-binarization inside process().
                threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
                # -1 = pick a random seed per run (see process()).
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
    # Positional argument list: order MUST match the process() signature
    # (input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
    # ddim_steps, strength, guidance_scale, seed, eta, threshold).
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
           eta, threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)

block.queue(max_size=100)
block.launch(share=True)