Yesianrohn committed on
Commit
288b231
·
verified ·
1 Parent(s): 865f879

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +377 -375
app.py CHANGED
@@ -1,376 +1,378 @@
1
- import subprocess
2
- import sys
3
-
4
- subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
5
- subprocess.run([
6
- sys.executable, "-m", "pip", "install",
7
- "Pillow==9.5.0", "--user"
8
- ])
9
-
10
- import os
11
- import random
12
- import numpy as np
13
- import cv2
14
- import torch
15
- import torchvision.transforms as transforms
16
- from PIL import Image, ImageDraw, ImageFont
17
- import gradio as gr
18
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
19
- from diffusers.utils.torch_utils import randn_tensor
20
- from tqdm import tqdm
21
-
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
-
24
- # Function definitions
25
- def calculate_square(full_image, mask):
26
- mask_array = np.array(mask)
27
- if len(mask_array.shape) == 2:
28
- gray = mask_array
29
- else:
30
- gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
31
- coords = cv2.findNonZero(gray)
32
- x, y, w, h = cv2.boundingRect(coords)
33
- L = max(w, h)
34
- L = min(full_image.shape[1], full_image.shape[0] ,L)
35
- if w < L:
36
- sx0 = random.randint(max(0, x+w - L), min(x, full_image.shape[1] - L)+1)
37
- sx1 = sx0 + L
38
- else:
39
- sx0, sx1 = x, x+w
40
-
41
- if h < L:
42
- sy0 = random.randint(max(0, y+h - L), min(y, full_image.shape[0] - L)+1)
43
- sy1 = sy0 + L
44
- else:
45
- sy0, sy1 = y, y+h
46
-
47
- return [sx0, sy0, sx1, sy1]
48
-
49
- def generate_mask(trans_image, resolution, mask, location):
50
- mask = np.array(mask.convert("L"))[location[1]:location[3], location[0]:location[2]]
51
- transform = transforms.Compose([
52
- transforms.ToTensor(),
53
- transforms.Resize((resolution, resolution))
54
- ])
55
- mask = transform(mask)
56
- mask = torch.where(mask > 0.5, torch.tensor(0.0), torch.tensor(1.0))
57
- masked_image = trans_image * mask.expand_as(trans_image)
58
-
59
- mask_np = mask.squeeze().byte().cpu().numpy()
60
- mask_np = np.transpose(mask_np)
61
- points = np.column_stack(np.where(mask_np == 0))
62
- rect = cv2.minAreaRect(points)
63
-
64
- return mask, masked_image, rect
65
-
66
- class AnytextDataset():
67
- def __init__(
68
- self,
69
- resolution=256,
70
- ttf_size=64,
71
- max_len=25,
72
- ):
73
- self.resolution = resolution
74
- self.ttf_size = ttf_size
75
- self.max_len = max_len
76
- self.transform = transforms.Compose([
77
- transforms.ToTensor(),
78
- transforms.Resize((resolution, resolution)),
79
- transforms.Normalize(mean=(0.5,), std=(0.5,)),
80
- ])
81
-
82
- def get_input(self, image, mask, text):
83
- full_image = np.array(image.convert('RGB'))
84
- location = calculate_square(full_image, mask)
85
- crop_image = full_image[location[1]:location[3], location[0]:location[2]]
86
- trans_image = self.transform(crop_image)
87
- mask, masked_image, mask_rect = generate_mask(trans_image, self.resolution, mask, location)
88
- text = text[:self.max_len]
89
- draw_ttf = self.draw_text(text)
90
- glyph = self.draw_glyph(text, mask_rect)
91
- info = {
92
- "image": trans_image,
93
- 'mask': mask,
94
- 'masked_image': masked_image,
95
- 'ttf_img': draw_ttf,
96
- 'glyph': glyph,
97
- "text": text,
98
- "full_image": full_image,
99
- "location": location,
100
- }
101
- return info
102
-
103
- def draw_text(self, text, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
104
- R = self.ttf_size
105
- fs = int(0.8*R)
106
- interval = 128 // self.max_len
107
- img_tensor = torch.ones((self.max_len, R, R), dtype=torch.float)
108
- for i, char in enumerate(text):
109
- img = Image.new('L', (R, R), 255)
110
- draw = ImageDraw.Draw(img)
111
- font = ImageFont.truetype(font_path, fs)
112
- text_size = font.getsize(char)
113
- text_position = ((R - text_size[0]) // 2, (R - text_size[1]) // 2)
114
- draw.text(text_position, char, font=font, fill=interval*i)
115
- img_tensor[i] = torch.from_numpy(np.array(img)).float() / 255.0
116
- return img_tensor
117
-
118
- def draw_glyph(self, text, rect, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
119
- resolution = self.resolution
120
- bg_img = np.ones((resolution, resolution, 3), dtype=np.uint8) * 255
121
- font = ImageFont.truetype(font_path, self.ttf_size)
122
- text_img = Image.new('RGB', font.getsize(text), (255, 255, 255))
123
- draw = ImageDraw.Draw(text_img)
124
- draw.text((0, 0), text, font=font, fill=(127, 127, 127))
125
- text_np = np.array(text_img)
126
- rec_h, rec_w = rect[1]
127
- box = cv2.boxPoints(rect)
128
- if rec_h > rec_w * 1.5:
129
- box = [box[1], box[2], box[3], box[0]]
130
- dst_points = np.array(box, dtype=np.float32)
131
- src_points = np.float32([[0, 0], [text_np.shape[1], 0], [text_np.shape[1], text_np.shape[0]], [0, text_np.shape[0]]])
132
- M = cv2.getPerspectiveTransform(src_points, dst_points)
133
- warped_text_img = cv2.warpPerspective(text_np, M, (resolution, resolution))
134
- mask = np.any(warped_text_img == [127, 127, 127], axis=-1)
135
- bg_img[mask] = warped_text_img[mask]
136
- bg_img = bg_img.astype(np.float32) / 255.0
137
- bg_img_tensor = torch.from_numpy(bg_img).permute(2, 0, 1)
138
- return bg_img_tensor
139
-
140
- class StableDiffusionPipeline:
141
- def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel, scheduler: DDPMScheduler, device):
142
- self.vae = vae
143
- self.unet = unet
144
- self.scheduler = scheduler
145
- self.device = device
146
- self.vae.to(self.device)
147
- self.unet.to(self.device)
148
- self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
149
-
150
- @torch.no_grad()
151
- def __call__(
152
- self,
153
- prompt: torch.FloatTensor,
154
- glyph: torch.FloatTensor,
155
- masked_image: torch.FloatTensor,
156
- mask: torch.FloatTensor,
157
- num_inference_steps: int = 20,
158
- ):
159
- if masked_image is None:
160
- raise ValueError("masked_image input cannot be undefined.")
161
-
162
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
163
- timesteps = self.scheduler.timesteps
164
-
165
- vae_scale_factor = self.vae_scale_factor
166
- _, mask_height, mask_width = mask.size()
167
- mask = mask.unsqueeze(0)
168
- glyph = glyph.unsqueeze(0)
169
- masked_image = masked_image.unsqueeze(0)
170
- prompt = prompt.unsqueeze(0)
171
-
172
- mask = torch.nn.functional.interpolate(mask, size=[mask_width // vae_scale_factor, mask_height // vae_scale_factor])
173
-
174
- glyph_latents = self.vae.encode(glyph).latent_dist.sample() * self.vae.config.scaling_factor
175
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample() * self.vae.config.scaling_factor
176
-
177
- shape = (1, self.vae.config.latent_channels, mask_height // vae_scale_factor, mask_width // vae_scale_factor)
178
- latents = randn_tensor(shape, generator=torch.manual_seed(20), device=self.device) * self.scheduler.init_noise_sigma
179
-
180
- for t in tqdm(timesteps):
181
- latent_model_input = latents
182
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
183
- sample = torch.cat([latent_model_input, masked_image_latents, glyph_latents, mask], dim=1)
184
- noise_pred = self.unet(sample=sample, timestep=t, encoder_hidden_states=prompt, ).sample
185
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
186
-
187
- pred_latents = latents / self.vae.config.scaling_factor
188
- image_vae = self.vae.decode(pred_latents).sample
189
- image = (image_vae / 2 + 0.5).clamp(0, 1)
190
- return image, image_vae
191
-
192
- # Load models (adjust the paths to your model directories)
193
- vae = AutoencoderKL.from_pretrained("./model/vae")
194
- unet = UNet2DConditionModel.from_pretrained("./model/unet")
195
- noise_scheduler = DDPMScheduler.from_pretrained("./model/scheduler")
196
-
197
- # Create pipeline
198
- pipe = StableDiffusionPipeline(vae=vae, unet=unet, scheduler=noise_scheduler, device=device)
199
-
200
- # Create dataset
201
- dataset = AnytextDataset(
202
- resolution=256,
203
- ttf_size=64,
204
- max_len=25,
205
- )
206
-
207
- def edit_mask(mask, num_points=14):
208
- mask_array = np.array(mask)
209
- if len(mask_array.shape) > 2:
210
- mask_array = mask_array[:, :, 0] if mask_array.shape[2] >= 1 else mask_array
211
- binary_mask = (mask_array > 0).astype(np.uint8) * 255
212
- contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
213
-
214
- if not contours:
215
- return Image.fromarray(binary_mask)
216
- filled_mask = np.zeros_like(binary_mask)
217
- cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
218
- contours, hierarchy = cv2.findContours(filled_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
219
- if contours:
220
- largest_contour = max(contours, key=cv2.contourArea)
221
- epsilon = 0.01 * cv2.arcLength(largest_contour, True)
222
- approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
223
- attempts = 0
224
- max_attempts = 20
225
- while len(approx_contour) > num_points and attempts < max_attempts:
226
- epsilon *= 1.1
227
- approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
228
- attempts += 1
229
- attempts = 0
230
- while len(approx_contour) < num_points and epsilon > 0.0001 and attempts < max_attempts:
231
- epsilon *= 0.9
232
- approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
233
- attempts += 1
234
- new_mask = np.zeros_like(binary_mask)
235
- points = [tuple(pt[0]) for pt in approx_contour]
236
- img = Image.fromarray(new_mask)
237
- draw = ImageDraw.Draw(img)
238
- if points:
239
- draw.polygon(points, fill=255)
240
- return img
241
- else:
242
- return Image.fromarray(filled_mask)
243
-
244
- def process_image(image, mask, text, num_points, num_inference_steps):
245
- print(text)
246
-
247
- edited_mask = edit_mask(mask["mask"], num_points=num_points)
248
- img_with_outline = image.copy()
249
- draw = ImageDraw.Draw(img_with_outline)
250
-
251
- mask_np = np.array(edited_mask)
252
- contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
253
-
254
- if contours:
255
- largest_contour = max(contours, key=cv2.contourArea)
256
- points = [tuple(pt[0]) for pt in largest_contour]
257
- if len(points) >= 2:
258
- draw.line(points + [points[0]], fill=(255, 0, 0), width=3)
259
-
260
- input = dataset.get_input(image=image, mask=edited_mask, text=text)
261
-
262
- masked_image = input["masked_image"].to(device)
263
- mask = input["mask"].to(device)
264
- ttf_img = input["ttf_img"].to(device)
265
- glyph = input["glyph"].to(device)
266
- full_image = input["full_image"]
267
- location = input["location"]
268
-
269
- image_output, _ = pipe(
270
- prompt=ttf_img,
271
- glyph=glyph,
272
- masked_image=masked_image,
273
- mask=mask,
274
- num_inference_steps=num_inference_steps,
275
- )
276
-
277
- mask_np = mask.cpu().detach().numpy().astype(np.uint8)
278
- coords = np.column_stack(np.where(mask_np == 0))
279
- img = image_output[0]
280
- if coords.size > 0:
281
- y_min, x_min = coords[:, 1].min(), coords[:, 2].min()
282
- y_max, x_max = coords[:, 1].max(), coords[:, 2].max()
283
- cropped_output_image = img[:, y_min:y_max+1, x_min:x_max+1]
284
- else:
285
- cropped_output_image = img
286
- cropped_output_image_np = (cropped_output_image * 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
287
- cropped_output_image_pil = Image.fromarray(cropped_output_image_np)
288
-
289
- x_min, y_min, x_max, y_max = location[0], location[1], location[2], location[3]
290
- full_image_patch = full_image[y_min:y_max, x_min:x_max, :]
291
- resize_trans = transforms.Resize((full_image_patch.shape[0], full_image_patch.shape[1]))
292
- resize_mask = resize_trans(mask).cpu()
293
- resize_img = resize_trans(img).cpu()
294
-
295
- img_mask = torch.where(resize_mask < 0.5, torch.tensor(0.0), torch.tensor(1.0))
296
- img_mask = img_mask.expand_as(resize_img)
297
- full_image_patch_tensor = transforms.ToTensor()(full_image_patch).cpu()
298
- full_image_patch_tensor = full_image_patch_tensor * img_mask + resize_img * (1 - img_mask)
299
-
300
- full_image_tensor = transforms.ToTensor()(full_image).cpu()
301
- full_image_tensor[:, y_min:y_max, x_min:x_max] = full_image_patch_tensor
302
-
303
- full_image_np = full_image_tensor.permute(1, 2, 0).numpy()
304
- full_image_pil = Image.fromarray((full_image_np * 255).astype(np.uint8))
305
-
306
- return cropped_output_image_pil, full_image_pil, img_with_outline
307
-
308
- demo_1 = Image.open("./imgs/demo_1.jpg")
309
- demo_2 = Image.open("./imgs/demo_2.jpg")
310
-
311
- def update_image(sample):
312
- if sample == "Sample 1":
313
- return demo_1
314
- elif sample == "Sample 2":
315
- return demo_2
316
- else:
317
- return None
318
-
319
- with gr.Blocks() as iface:
320
- gr.Markdown("# TextSSR Demo")
321
- gr.Markdown("Upload an image, draw a mask on the image, and enter text content for region synthesis and image editing.")
322
-
323
- with gr.Row():
324
- with gr.Column():
325
- sample_choice = gr.Radio(choices=["Sample 1", "Sample 2"], label="Choose a Sample Background")
326
- input_image = gr.Image(type="pil", label="Input Image")
327
- mask_input = gr.Image(type="pil", label="Draw Mask on Image", tool="sketch", interactive=True)
328
- text_input = gr.Textbox(label="Text to Synthesize / Edit")
329
- outlined_image = gr.Image(type="pil", label="Original Image with Mask Outline")
330
-
331
- with gr.Row():
332
- num_points_slider = gr.Slider(
333
- minimum=4,
334
- maximum=20,
335
- value=14,
336
- step=1,
337
- label="Control Points",
338
- info="Adjust mask complexity (4-20 points)"
339
- )
340
-
341
- num_steps_slider = gr.Slider(
342
- minimum=5,
343
- maximum=50,
344
- value=20,
345
- step=1,
346
- label="Inference Steps",
347
- info="More steps = better quality but slower"
348
- )
349
-
350
- submit_btn = gr.Button("Process Image")
351
-
352
- with gr.Column():
353
- output_region = gr.Image(type="pil", label="Modified Region")
354
- output_full = gr.Image(type="pil", label="Modified Full Image")
355
-
356
- # Update input image based on the selected sample background
357
- sample_choice.change(
358
- update_image,
359
- inputs=[sample_choice],
360
- outputs=[input_image]
361
- )
362
-
363
- # Update mask when input image changes
364
- input_image.change(
365
- lambda image: image, # Pass through image to mask_input
366
- inputs=[input_image],
367
- outputs=[mask_input]
368
- )
369
- # Process image when submit button is clicked (updated to include num_points and num_inference_steps parameters)
370
- submit_btn.click(
371
- process_image,
372
- inputs=[input_image, mask_input, text_input, num_points_slider, num_steps_slider],
373
- outputs=[output_region, output_full, outlined_image]
374
- )
375
-
 
 
376
  iface.launch(server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1", share=False)
 
1
+ import subprocess
2
+ import sys
3
+
4
def _pip_install(*args):
    """Run ``pip install <args>`` with the current interpreter.

    The call is best-effort: a failing pip does not abort start-up, but the
    non-zero exit status is reported on stderr instead of being dropped, so a
    later ImportError can be traced back to the failed install.

    Returns the ``subprocess.CompletedProcess`` of the pip invocation.
    """
    result = subprocess.run([sys.executable, "-m", "pip", "install", *args])
    if result.returncode != 0:
        print(
            f"warning: 'pip install {' '.join(args)}' exited with status {result.returncode}",
            file=sys.stderr,
        )
    return result


# Provision the runtime dependencies before the heavy imports below run.
_pip_install("--upgrade", "pip")
_pip_install(
    "Pillow==9.5.0", "opencv-python==4.8.1.78", "torch==2.1.0",
    "torchvision==0.16.0", "gradio==4.13.0",
    "tqdm==4.66.1", "numpy==1.24.4", "--user",
)
11
+
12
+ import os
13
+ import random
14
+ import numpy as np
15
+ import cv2
16
+ import torch
17
+ import torchvision.transforms as transforms
18
+ from PIL import Image, ImageDraw, ImageFont
19
+ import gradio as gr
20
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
21
+ from diffusers.utils.torch_utils import randn_tensor
22
+ from tqdm import tqdm
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ # Function definitions
27
def calculate_square(full_image, mask):
    """Choose a crop window of ``full_image`` that fully contains the mask.

    The window is a square whose side is the longer side of the mask's
    bounding box, clamped to the image's shorter side. Along the axis where
    the mask is shorter than that side, the window is placed at a random
    offset that still covers the mask and stays inside the image.

    Args:
        full_image: array whose first two dimensions are (height, width).
        mask: 2-D mask (or an RGB mask, converted to gray with OpenCV);
            non-zero pixels mark the region of interest.

    Returns:
        ``[x0, y0, x1, y1]`` with exclusive right/bottom edges.

    Raises:
        ValueError: if the mask has no non-zero pixel.
    """
    mask_array = np.array(mask)
    if mask_array.ndim == 2:
        gray = mask_array
    else:
        gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)

    rows, cols = np.nonzero(gray)
    if rows.size == 0:
        raise ValueError("mask is empty: no non-zero pixels to build a crop around")

    # Inclusive bounding box of the non-zero pixels (same convention as
    # cv2.boundingRect on a point set).
    x, y = int(cols.min()), int(rows.min())
    w = int(cols.max()) - x + 1
    h = int(rows.max()) - y + 1

    img_h, img_w = full_image.shape[0], full_image.shape[1]
    side = min(img_w, img_h, max(w, h))

    def _place(start, extent, limit):
        """Return (lo, hi) of the window along one axis."""
        if extent >= side:
            return start, start + extent
        lowest = max(0, start + extent - side)   # window must still reach the mask's far edge
        highest = min(start, limit - side)       # ...start at/before the mask and fit in the image
        # random.randint is inclusive on BOTH ends, so no "+ 1" here: adding
        # one could start the window after the mask or run it off the image.
        origin = random.randint(lowest, highest)
        return origin, origin + side

    sx0, sx1 = _place(x, w, img_w)
    sy0, sy1 = _place(y, h, img_h)
    return [sx0, sy0, sx1, sy1]
50
+
51
def generate_mask(trans_image, resolution, mask, location):
    """Build the keep-mask, the masked crop and the mask's rotated bounding box.

    Args:
        trans_image: the already transformed crop tensor that will be masked.
        resolution: side length the cropped mask is resized to.
        mask: PIL image; it is converted to mode "L" and cropped to ``location``.
        location: ``[x0, y0, x1, y1]`` crop window in full-image coordinates.

    Returns:
        ``(mask, masked_image, rect)`` where ``mask`` is 0.0 inside the drawn
        region and 1.0 elsewhere, ``masked_image`` is ``trans_image`` multiplied
        by that mask, and ``rect`` is the ``cv2.minAreaRect`` of the zeroed
        pixels.
    """
    left, top, right, bottom = location[0], location[1], location[2], location[3]
    cropped = np.array(mask.convert("L"))[top:bottom, left:right]

    to_square_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((resolution, resolution))
    ])
    soft_mask = to_square_tensor(cropped)

    # Invert: pixels the user painted (> 0.5) become 0 so they are blanked out.
    keep = torch.where(soft_mask > 0.5, torch.tensor(0.0), torch.tensor(1.0))
    masked_image = trans_image * keep.expand_as(trans_image)

    # Transposing first makes np.where yield (x, y) pairs rather than (row, col).
    transposed = np.transpose(keep.squeeze().byte().cpu().numpy())
    hole_points = np.column_stack(np.where(transposed == 0))
    rect = cv2.minAreaRect(hole_points)

    return keep, masked_image, rect
67
+
68
class AnytextDataset():
    """Turn one (image, mask, text) request into the tensors the pipeline consumes.

    ``resolution`` is the side of the square crop handed to the model,
    ``ttf_size`` the side of each per-character image (and the glyph font
    size), and ``max_len`` the maximum number of characters kept from the text.
    """

    def __init__(
        self,
        resolution=256,
        ttf_size=64,
        max_len=25,
    ):
        self.resolution = resolution
        self.ttf_size = ttf_size
        self.max_len = max_len
        # array crop -> tensor -> square resize -> normalize with mean/std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resolution, resolution)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def get_input(self, image, mask, text):
        """Assemble every model input for one request.

        Args:
            image: PIL image to edit.
            mask: PIL mask image; non-zero pixels mark the region to edit.
            text: text to render; truncated to ``max_len`` characters.

        Returns:
            dict with the transformed crop (``image``), the keep-mask
            (``mask``), the masked crop (``masked_image``), the per-character
            images (``ttf_img``), the warped glyph image (``glyph``), the
            truncated ``text``, the RGB ``full_image`` array and the crop
            ``location`` ``[x0, y0, x1, y1]``.
        """
        full_image = np.array(image.convert('RGB'))
        location = calculate_square(full_image, mask)
        crop_image = full_image[location[1]:location[3], location[0]:location[2]]
        trans_image = self.transform(crop_image)
        mask, masked_image, mask_rect = generate_mask(trans_image, self.resolution, mask, location)
        text = text[:self.max_len]
        draw_ttf = self.draw_text(text)
        glyph = self.draw_glyph(text, mask_rect)
        info = {
            "image": trans_image,
            'mask': mask,
            'masked_image': masked_image,
            'ttf_img': draw_ttf,
            'glyph': glyph,
            "text": text,
            "full_image": full_image,
            "location": location,
        }
        return info

    @staticmethod
    def _text_size(font, text):
        """Return ``(width, height)`` of ``text`` as the legacy ``font.getsize`` did.

        ``getsize`` is gone from recent Pillow releases, so ``getbbox`` is
        preferred when the font object offers it. The legacy value is the ink
        width (``right - left``) and the height including the top offset
        (``bottom``). Fonts without ``getbbox`` fall back to ``getsize``.
        """
        if hasattr(font, "getbbox"):
            left, _top, right, bottom = font.getbbox(text)
            return int(right - left), int(bottom)
        width, height = font.getsize(text)
        return int(width), int(height)

    def draw_text(self, text, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render each character of ``text`` centred in its own ``ttf_size`` square.

        Returns a ``(max_len, ttf_size, ttf_size)`` float tensor; unused slots
        stay all-ones (white). Character ``i`` is drawn with gray level
        ``(128 // max_len) * i``, so each position gets a distinct intensity.
        Characters beyond ``max_len`` are ignored.
        """
        R = self.ttf_size
        fs = int(0.8*R)
        interval = 128 // self.max_len
        img_tensor = torch.ones((self.max_len, R, R), dtype=torch.float)

        chars = text[:self.max_len]  # guard direct callers against overflow
        if not chars:
            return img_tensor

        # The font is identical for every character: load it once, not per glyph.
        font = ImageFont.truetype(font_path, fs)
        for i, char in enumerate(chars):
            img = Image.new('L', (R, R), 255)
            draw = ImageDraw.Draw(img)
            text_size = self._text_size(font, char)
            text_position = ((R - text_size[0]) // 2, (R - text_size[1]) // 2)
            draw.text(text_position, char, font=font, fill=interval*i)
            img_tensor[i] = torch.from_numpy(np.array(img)).float() / 255.0
        return img_tensor

    def draw_glyph(self, text, rect, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render ``text`` in gray and warp it into the rotated rectangle ``rect``.

        Args:
            text: string to render on a single line.
            rect: rotated rectangle as returned by ``cv2.minAreaRect``.
            font_path: TrueType font file to render with.

        Returns:
            ``(3, resolution, resolution)`` float tensor in [0, 1]: white
            background with the warped gray text. Empty text yields a plain
            white canvas.
        """
        resolution = self.resolution
        if not text:
            # Nothing to render; a zero-width text image cannot be warped.
            return torch.ones((3, resolution, resolution), dtype=torch.float32)

        bg_img = np.ones((resolution, resolution, 3), dtype=np.uint8) * 255
        font = ImageFont.truetype(font_path, self.ttf_size)
        text_img = Image.new('RGB', self._text_size(font, text), (255, 255, 255))
        draw = ImageDraw.Draw(text_img)
        draw.text((0, 0), text, font=font, fill=(127, 127, 127))
        text_np = np.array(text_img)

        rec_h, rec_w = rect[1]
        box = cv2.boxPoints(rect)
        if rec_h > rec_w * 1.5:
            # Rotate the corner order by one for strongly elongated boxes so
            # the text runs along the long side.
            box = [box[1], box[2], box[3], box[0]]
        dst_points = np.array(box, dtype=np.float32)
        src_points = np.float32([[0, 0], [text_np.shape[1], 0], [text_np.shape[1], text_np.shape[0]], [0, text_np.shape[0]]])
        M = cv2.getPerspectiveTransform(src_points, dst_points)
        warped_text_img = cv2.warpPerspective(text_np, M, (resolution, resolution))

        # Copy only pixels that kept the exact text gray onto the white canvas.
        mask = np.any(warped_text_img == [127, 127, 127], axis=-1)
        bg_img[mask] = warped_text_img[mask]
        bg_img = bg_img.astype(np.float32) / 255.0
        bg_img_tensor = torch.from_numpy(bg_img).permute(2, 0, 1)
        return bg_img_tensor
141
+
142
class StableDiffusionPipeline:
    """Denoising loop that conditions a UNet on a masked image, a glyph image and a mask.

    The VAE and UNet are moved to ``device`` on construction.
    ``vae_scale_factor`` is the spatial down-sampling factor derived from the
    number of VAE blocks.
    """

    def __init__(self, vae: "AutoencoderKL", unet: "UNet2DConditionModel", scheduler: "DDPMScheduler", device):
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        self.vae.to(self.device)
        self.unet.to(self.device)
        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    @staticmethod
    def _resize_mask_to_latents(mask, height, width, factor):
        """Resize a batched mask to the latent grid ``(height // factor, width // factor)``.

        ``interpolate`` takes ``size`` as (height, width); that order must match
        the latent tensor the mask is later concatenated with.
        """
        return torch.nn.functional.interpolate(mask, size=(height // factor, width // factor))

    @torch.no_grad()
    def __call__(
        self,
        prompt: torch.FloatTensor,
        glyph: torch.FloatTensor,
        masked_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        num_inference_steps: int = 20,
        seed: int = 20,
    ):
        """Run the full denoising loop for a single (unbatched) sample.

        Args:
            prompt: conditioning tensor passed to the UNet as
                ``encoder_hidden_states`` (a batch dimension is added here).
            glyph: glyph image tensor, VAE-encoded before use.
            masked_image: masked crop tensor, VAE-encoded before use.
            mask: 3-D ``(C, H, W)`` mask tensor, resized to the latent grid.
            num_inference_steps: number of scheduler steps.
            seed: value passed to ``torch.manual_seed`` before the initial
                latents are drawn; defaults to the previously hard-coded 20.

        Returns:
            ``(image, image_vae)``: the decoded image rescaled and clamped to
            [0, 1], and the raw VAE output.

        Raises:
            ValueError: if ``masked_image`` is None.
        """
        if masked_image is None:
            raise ValueError("masked_image input cannot be undefined.")

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        vae_scale_factor = self.vae_scale_factor
        _, mask_height, mask_width = mask.size()
        mask = mask.unsqueeze(0)
        glyph = glyph.unsqueeze(0)
        masked_image = masked_image.unsqueeze(0)
        prompt = prompt.unsqueeze(0)

        mask = self._resize_mask_to_latents(mask, mask_height, mask_width, vae_scale_factor)

        glyph_latents = self.vae.encode(glyph).latent_dist.sample() * self.vae.config.scaling_factor
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample() * self.vae.config.scaling_factor

        shape = (1, self.vae.config.latent_channels, mask_height // vae_scale_factor, mask_width // vae_scale_factor)
        # torch.manual_seed reseeds the global RNG and returns its generator, so
        # the random draws that follow are reproducible for a given seed.
        latents = randn_tensor(shape, generator=torch.manual_seed(seed), device=self.device) * self.scheduler.init_noise_sigma

        for t in tqdm(timesteps):
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            # Channel-wise conditioning: noisy latents + masked image + glyph + mask.
            sample = torch.cat([latent_model_input, masked_image_latents, glyph_latents, mask], dim=1)
            noise_pred = self.unet(sample=sample, timestep=t, encoder_hidden_states=prompt, ).sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        pred_latents = latents / self.vae.config.scaling_factor
        image_vae = self.vae.decode(pred_latents).sample
        image = (image_vae / 2 + 0.5).clamp(0, 1)
        return image, image_vae
193
+
194
# Load the pretrained components from the Hugging Face Hub.
# A Hub repo id has the form "namespace/name"; a component stored in a
# sub-directory must be selected with `subfolder=`, not appended to the id.
# NOTE(review): assumes the repo keeps its weights under vae/, unet/ and
# scheduler/ as the previous path strings suggest; confirm the repo layout.
MODEL_REPO = "Yesianrohn/TextSSR"
vae = AutoencoderKL.from_pretrained(MODEL_REPO, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_REPO, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_REPO, subfolder="scheduler")

# Create pipeline
pipe = StableDiffusionPipeline(vae=vae, unet=unet, scheduler=noise_scheduler, device=device)

# Create dataset
dataset = AnytextDataset(
    resolution=256,
    ttf_size=64,
    max_len=25,
)
208
+
209
def edit_mask(mask, num_points=14):
    """Simplify a hand-drawn mask into one filled polygon of about ``num_points`` vertices.

    The mask is binarised (any value > 0 counts), its outer contours are
    filled, and the largest filled contour is approximated by a polygon. The
    approximation tolerance is first loosened, then tightened, for at most 20
    iterations each, to approach ``num_points`` vertices.

    Returns a PIL image with the polygon drawn in 255 on a 0 background. If no
    contour is found, the binarised (or filled) mask is returned instead.
    """
    raw = np.array(mask)
    if len(raw.shape) > 2 and raw.shape[2] >= 1:
        raw = raw[:, :, 0]  # only the first channel is inspected
    binary_mask = (raw > 0).astype(np.uint8) * 255

    def _outer_contours(img):
        found, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return found

    outlines = _outer_contours(binary_mask)
    if not outlines:
        return Image.fromarray(binary_mask)

    # Fill the interiors so holes inside a stroke do not affect the outline.
    filled_mask = np.zeros_like(binary_mask)
    cv2.drawContours(filled_mask, outlines, -1, 255, thickness=cv2.FILLED)
    outlines = _outer_contours(filled_mask)
    if not outlines:
        return Image.fromarray(filled_mask)

    biggest = max(outlines, key=cv2.contourArea)
    epsilon = 0.01 * cv2.arcLength(biggest, True)
    polygon = cv2.approxPolyDP(biggest, epsilon, True)

    max_attempts = 20
    # Too many vertices: loosen the tolerance.
    for _ in range(max_attempts):
        if len(polygon) <= num_points:
            break
        epsilon *= 1.1
        polygon = cv2.approxPolyDP(biggest, epsilon, True)
    # Too few vertices: tighten it again, but never below a tiny tolerance.
    for _ in range(max_attempts):
        if len(polygon) >= num_points or epsilon <= 0.0001:
            break
        epsilon *= 0.9
        polygon = cv2.approxPolyDP(biggest, epsilon, True)

    canvas = Image.fromarray(np.zeros_like(binary_mask))
    vertices = [tuple(vertex[0]) for vertex in polygon]
    if vertices:
        ImageDraw.Draw(canvas).polygon(vertices, fill=255)
    return canvas
245
+
246
def process_image(image, mask, text, num_points, num_inference_steps):
    """Gradio handler: synthesize ``text`` inside the drawn mask region.

    Args:
        image: PIL input image.
        mask: sketch value; a dict whose ``"mask"`` entry is the drawn mask.
        text: text to render in the masked region.
        num_points: target vertex count for the simplified mask polygon.
        num_inference_steps: number of denoising steps.

    Returns:
        ``(region, full, outline)`` PIL images: the generated output cropped to
        the mask's bounding box, the full image with the generated region
        composited in, and the input image with the mask outline in red.

    Raises:
        gr.Error: if the image or mask is missing, or the mask is empty.
    """
    if image is None:
        raise gr.Error("Please upload or select an input image first.")
    if not isinstance(mask, dict) or mask.get("mask") is None:
        raise gr.Error("Please draw a mask on the image first.")

    print(text)

    # Slider values may arrive as floats; downstream code expects integers.
    num_points = int(num_points)
    num_inference_steps = int(num_inference_steps)

    edited_mask = edit_mask(mask["mask"], num_points=num_points)
    edited_mask_np = np.array(edited_mask)
    if not edited_mask_np.any():
        raise gr.Error("The mask is empty. Draw over the region to edit and try again.")

    # Draw on an RGB copy so the red outline also works for grayscale,
    # palette or RGBA inputs.
    img_with_outline = image.convert("RGB")
    draw = ImageDraw.Draw(img_with_outline)
    contours, _ = cv2.findContours(edited_mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        points = [tuple(pt[0]) for pt in largest_contour]
        if len(points) >= 2:
            draw.line(points + [points[0]], fill=(255, 0, 0), width=3)

    model_inputs = dataset.get_input(image=image, mask=edited_mask, text=text)

    masked_image = model_inputs["masked_image"].to(device)
    mask_tensor = model_inputs["mask"].to(device)
    ttf_img = model_inputs["ttf_img"].to(device)
    glyph = model_inputs["glyph"].to(device)
    full_image = model_inputs["full_image"]
    location = model_inputs["location"]

    image_output, _ = pipe(
        prompt=ttf_img,
        glyph=glyph,
        masked_image=masked_image,
        mask=mask_tensor,
        num_inference_steps=num_inference_steps,
    )

    # Crop the generated image to the bounding box of the edited (zero) area.
    mask_np = mask_tensor.cpu().detach().numpy().astype(np.uint8)
    coords = np.column_stack(np.where(mask_np == 0))
    img = image_output[0]
    if coords.size > 0:
        y_min, x_min = coords[:, 1].min(), coords[:, 2].min()
        y_max, x_max = coords[:, 1].max(), coords[:, 2].max()
        cropped_output_image = img[:, y_min:y_max+1, x_min:x_max+1]
    else:
        cropped_output_image = img
    # Round before the uint8 cast; a bare cast truncates toward zero.
    cropped_output_image_np = (
        (cropped_output_image * 255).round().cpu().permute(1, 2, 0).numpy().astype(np.uint8)
    )
    cropped_output_image_pil = Image.fromarray(cropped_output_image_np)

    # Composite the generated pixels back into the original crop window.
    x_min, y_min, x_max, y_max = location[0], location[1], location[2], location[3]
    full_image_patch = full_image[y_min:y_max, x_min:x_max, :]
    resize_trans = transforms.Resize((full_image_patch.shape[0], full_image_patch.shape[1]))
    resize_mask = resize_trans(mask_tensor).cpu()
    resize_img = resize_trans(img).cpu()

    # 1.0 keeps the original pixel, 0.0 takes the generated one.
    img_mask = torch.where(resize_mask < 0.5, torch.tensor(0.0), torch.tensor(1.0))
    img_mask = img_mask.expand_as(resize_img)
    full_image_patch_tensor = transforms.ToTensor()(full_image_patch).cpu()
    full_image_patch_tensor = full_image_patch_tensor * img_mask + resize_img * (1 - img_mask)

    full_image_tensor = transforms.ToTensor()(full_image).cpu()
    full_image_tensor[:, y_min:y_max, x_min:x_max] = full_image_patch_tensor

    # Rounding keeps untouched pixels exact through the float round-trip.
    full_image_np = (full_image_tensor.permute(1, 2, 0) * 255).round().clamp(0, 255).numpy()
    full_image_pil = Image.fromarray(full_image_np.astype(np.uint8))

    return cropped_output_image_pil, full_image_pil, img_with_outline
309
+
310
+ demo_1 = Image.open("./imgs/demo_1.jpg")
311
+ demo_2 = Image.open("./imgs/demo_2.jpg")
312
+
313
def update_image(sample):
    """Return the demo image for the chosen sample label, or None if the label is unknown."""
    if sample not in ("Sample 1", "Sample 2"):
        return None
    return demo_1 if sample == "Sample 1" else demo_2
320
+
321
# Gradio front end: inputs in the left column, results in the right column.
# NOTE(review): the nesting of the slider row and the submit button inside the
# left column is assumed; confirm against the intended layout.
# NOTE(review): `tool="sketch"` on gr.Image may not be accepted by every Gradio
# major version; confirm against the Gradio version this app runs with.
with gr.Blocks() as iface:
    gr.Markdown("# TextSSR Demo")
    gr.Markdown("Upload an image, draw a mask on the image, and enter text content for region synthesis and image editing.")

    with gr.Row():
        with gr.Column():
            sample_choice = gr.Radio(
                choices=["Sample 1", "Sample 2"],
                label="Choose a Sample Background",
            )
            input_image = gr.Image(type="pil", label="Input Image")
            mask_input = gr.Image(
                type="pil",
                label="Draw Mask on Image",
                tool="sketch",
                interactive=True,
            )
            text_input = gr.Textbox(label="Text to Synthesize / Edit")
            outlined_image = gr.Image(type="pil", label="Original Image with Mask Outline")

            with gr.Row():
                num_points_slider = gr.Slider(
                    label="Control Points",
                    info="Adjust mask complexity (4-20 points)",
                    minimum=4,
                    maximum=20,
                    value=14,
                    step=1,
                )
                num_steps_slider = gr.Slider(
                    label="Inference Steps",
                    info="More steps = better quality but slower",
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=1,
                )

            submit_btn = gr.Button("Process Image")

        with gr.Column():
            output_region = gr.Image(type="pil", label="Modified Region")
            output_full = gr.Image(type="pil", label="Modified Full Image")

    # Choosing a sample background loads it into the input image.
    sample_choice.change(update_image, inputs=[sample_choice], outputs=[input_image])

    # Whenever the input image changes, copy it into the sketch canvas.
    input_image.change(lambda image: image, inputs=[input_image], outputs=[mask_input])

    # Run the pipeline with the current image, mask, text and slider values.
    submit_btn.click(
        process_image,
        inputs=[input_image, mask_input, text_input, num_points_slider, num_steps_slider],
        outputs=[output_region, output_full, outlined_image],
    )

# Listen on all interfaces only when GRADIO_LISTEN is set to a non-empty value.
listen_on_all = os.getenv('GRADIO_LISTEN', '') != ''
iface.launch(server_name='0.0.0.0' if listen_on_all else "127.0.0.1", share=False)