ICGenAIShare07 commited on
Commit
1e038da
·
verified ·
1 Parent(s): 2880446

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Module-level setup: load SD 1.5 components, attach trained ControlNet
# weights downloaded from the Hub, and build the inference pipeline once
# at import time (standard pattern for a Gradio Space).
import os  # NOTE(review): unused in the visible code — kept for safety; confirm before removing
import gradio as gr
import torch
import spaces

from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import ControlNetModel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

# Local files you must upload to the Space (same folder as this app.py)
from pipeline import build_controlnet_pipe
from prepare_laion import CannyCFG, canny_auto_median_bilateral

# -----------------------------
# Config
# -----------------------------
BASE_MODEL = "sd-legacy/stable-diffusion-v1-5"
WEIGHTS_REPO = "mvp-lab/ControlNet_Weight"
WEIGHTS_FILENAME = "diffusion_pytorch_model_1.safetensors"

# Download (cached) and get the local path to the ControlNet weight file.
CONTROLNET_PATH = hf_hub_download(
    repo_id=WEIGHTS_REPO,
    filename=WEIGHTS_FILENAME,
    repo_type="model"
)

# For ZeroGPU, keep dtype float32 for safety/compatibility.
DTYPE = torch.float32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the frozen SD 1.5 sub-modules individually so the ControlNet can be
# derived from the exact UNet architecture below.
vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE)
unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=DTYPE)
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=DTYPE)

# Inference only: freeze all base-model parameters.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

# Build a ControlNet matching the UNet (3-channel conditioning = RGB edge map)
# and load the fine-tuned weights. strict=False tolerates key mismatches;
# the counts are printed so a bad checkpoint is visible in the Space logs.
controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)
state = load_file(CONTROLNET_PATH)
missing, unexpected = controlnet.load_state_dict(state, strict=False)
print(f"[ControlNet] missing={len(missing)}, unexpected={len(unexpected)}")

# Assemble the full text+ControlNet pipeline (helper defined in pipeline.py).
pipe = build_controlnet_pipe(
    base_model_name=BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    device=DEVICE,
    weight_dtype=DTYPE,
    use_unipc=True,
)
@torch.inference_mode()
def run_pipeline(
    input_image: Image.Image,
    prompt: str,
    negative_prompt: str = "",
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 1,
    controlnet_conditioning_scale: float = 1.0,
    resolution: int = 512,
    return_canny: bool = False,
    seed: int = 0,
):
    """Run Canny-conditioned ControlNet generation on a single input image.

    Args:
        input_image: Source image; converted to RGB and resized to
            ``resolution`` x ``resolution`` before edge extraction.
        prompt: Positive text prompt, replicated across all samples.
        negative_prompt: Negative text prompt, replicated across all samples.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        num_images: Number of samples to generate (must be >= 1).
        controlnet_conditioning_scale: Strength of the ControlNet conditioning.
        resolution: Square output size in pixels.
        return_canny: If True, also return the Canny conditioning image.
        seed: Base RNG seed; sample ``i`` uses ``seed + i``. The default of 0
            reproduces the previous fixed seeding (0, 1, ..., num_images-1).

    Returns:
        A list of PIL images, or ``(images, canny)`` when ``return_canny``.

    Raises:
        ValueError: If ``input_image`` is None or ``num_images`` < 1.
    """
    if input_image is None:
        raise ValueError("input_image is None")
    if num_images < 1:
        raise ValueError("num_images must be >= 1")

    # Resize input to the target square resolution.
    img_rgb = input_image.convert("RGB").resize((resolution, resolution))

    # Compute the Canny conditioning image (RGB) with the same parameters
    # used when preparing the training data.
    canny_cfg = CannyCFG(sigma=0.33, d=7, sigma_color=50, sigma_space=50)
    canny = canny_auto_median_bilateral(img_rgb, canny_cfg).convert("RGB")

    # One generator per sample so results are deterministic yet distinct.
    generators = [
        torch.Generator(device=DEVICE).manual_seed(seed + i)
        for i in range(num_images)
    ]

    images = pipe(
        prompt=[prompt] * num_images,
        negative_prompt=[negative_prompt] * num_images,
        image=[canny] * num_images,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        height=int(resolution),
        width=int(resolution),
        generator=generators,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    ).images

    if return_canny:
        return images, canny
    return images
104
+
105
+ @spaces.GPU
106
+ def generate_image(
107
+ input_image,
108
+ positive_prompt,
109
+ negative_prompt,
110
+ guidance_scale,
111
+ num_inference_steps,
112
+ num_images,
113
+ controlnet_conditioning_scale,
114
+ resolution,
115
+ ):
116
+
117
+ if input_image is None:
118
+ raise gr.Error("Please upload an input image.")
119
+
120
+ # If Gradio passes numpy, convert defensively (even though type="pil" should give PIL)
121
+ if not isinstance(input_image, Image.Image):
122
+ input_image = Image.fromarray(input_image)
123
+
124
+ imgs, canny = run_pipeline(
125
+ input_image=input_image,
126
+ prompt=positive_prompt,
127
+ negative_prompt=negative_prompt,
128
+ guidance_scale=float(guidance_scale),
129
+ num_inference_steps=int(num_inference_steps),
130
+ num_images=int(num_images),
131
+ controlnet_conditioning_scale=float(controlnet_conditioning_scale),
132
+ resolution=int(resolution),
133
+ return_canny=True,
134
+ )
135
+ return canny, imgs
136
+
# ----------- demo -----------
# Two-column layout: inputs/controls on the left, Canny preview and the
# generated-image gallery on the right.
with gr.Blocks() as demo:
    gr.Markdown("## ControlNet (Canny) Demo")
    gr.Markdown("Upload an image and write prompt(s). The model generates images conditioned on Canny edges.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                image_mode="RGB",
            )

            positive_prompt = gr.Textbox(
                label="Positive Prompt",
                value="",
                lines=2,
                placeholder="Brief description of image",
            )

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="",
                lines=2,
                # Fixed grammar in the user-facing hint ("an blurry" -> "a blurry").
                placeholder="e.g., a blurry image with deformed structure",
            )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0, maximum=15.0, value=7.5, step=0.1
                )
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=10, maximum=80, value=50, step=1
                )

            with gr.Row():
                num_images = gr.Slider(
                    label="Number of Images",
                    minimum=1, maximum=6, value=1, step=1
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="ControlNet Conditioning Scale",
                    minimum=0.0, maximum=2.0, value=1.0, step=0.05
                )

            resolution = gr.Dropdown(
                label="Resolution",
                choices=[256, 384, 512, 640, 768, 1024],
                value=512,
            )

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=1):
            # Edge map shown alongside results so users can see the conditioning.
            canny_preview = gr.Image(label="Canny edges image", type="pil")
            gallery = gr.Gallery(label="Generated Images", columns=2, rows=2, height=420)

    # Wire the button to the GPU-decorated handler; outputs map to
    # (canny_preview, gallery) in that order.
    run_btn.click(
        fn=generate_image,
        inputs=[
            input_image,
            positive_prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
            num_images,
            controlnet_conditioning_scale,
            resolution,
        ],
        outputs=[canny_preview, gallery],
    )


if __name__ == "__main__":
    demo.launch()