"""Gradio demo for TextSSR: diffusion-based scene-text synthesis / editing.

The user uploads an image, sketches a mask over the region to edit, and
types the target text.  The script crops a square patch around the mask,
renders per-character tiles and a perspective-warped glyph image, runs a
latent-diffusion inpainting pipeline to synthesize the text into the
patch, and pastes the result back into the full image.
"""
import os
import random

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image, ImageDraw, ImageFont

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _text_size(font, text):
    """Return (width, height) of *text* rendered with *font*.

    ``ImageFont.getsize`` was removed in Pillow 10; fall back to
    ``getbbox`` there so the demo keeps working on modern Pillow.
    """
    if hasattr(font, "getsize"):
        return font.getsize(text)
    left, top, right, bottom = font.getbbox(text)
    return right - left, bottom - top


def calculate_square(full_image, mask):
    """Pick a square crop window ``[sx0, sy0, sx1, sy1]`` containing the mask.

    The square's side is the larger of the mask bbox sides (clamped to the
    image), and its position is randomly jittered while still covering the
    bbox and staying inside the image.
    """
    mask_array = np.array(mask)
    if mask_array.ndim == 2:
        gray = mask_array
    else:
        gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
    coords = cv2.findNonZero(gray)
    x, y, w, h = cv2.boundingRect(coords)
    side = max(w, h)
    side = min(full_image.shape[1], full_image.shape[0], side)
    if w < side:
        # BUGFIX: random.randint is inclusive on both ends; the original
        # "+ 1" on the upper bound could push the crop past the image edge
        # or past the mask bbox by one pixel.
        sx0 = random.randint(max(0, x + w - side), min(x, full_image.shape[1] - side))
        sx1 = sx0 + side
    else:
        sx0, sx1 = x, x + w
    if h < side:
        sy0 = random.randint(max(0, y + h - side), min(y, full_image.shape[0] - side))
        sy1 = sy0 + side
    else:
        sy0, sy1 = y, y + h
    return [sx0, sy0, sx1, sy1]


def generate_mask(trans_image, resolution, mask, location):
    """Crop and resize the user mask to the square patch and apply it.

    Returns ``(mask, masked_image, rect)`` where ``mask`` is 1 outside the
    painted region and 0 inside, ``masked_image`` is the patch with the
    painted region zeroed out, and ``rect`` is the ``cv2.minAreaRect`` of
    the painted region.
    """
    mask = np.array(mask.convert("L"))[location[1]:location[3], location[0]:location[2]]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((resolution, resolution)),
    ])
    mask = transform(mask)
    # Invert: painted pixels (> 0.5) become 0, background becomes 1.
    mask = torch.where(mask > 0.5, torch.tensor(0.0), torch.tensor(1.0))
    masked_image = trans_image * mask.expand_as(trans_image)
    mask_np = mask.squeeze().byte().cpu().numpy()
    # Transpose so np.where yields points in (x, y) order for minAreaRect.
    mask_np = np.transpose(mask_np)
    points = np.column_stack(np.where(mask_np == 0))
    rect = cv2.minAreaRect(points)
    return mask, masked_image, rect


class AnytextDataset():
    """Prepares model inputs (patch, masks, glyph renderings) from a PIL
    image, a user-drawn mask, and the target text."""

    def __init__(
        self,
        resolution=256,
        ttf_size=64,
        max_len=25,
    ):
        self.resolution = resolution  # side of the square patch fed to the model
        self.ttf_size = ttf_size      # side of each per-character glyph tile
        self.max_len = max_len        # maximum number of characters kept
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resolution, resolution)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),  # map to [-1, 1]
        ])

    def get_input(self, image, mask, text):
        """Build the dict of tensors consumed by the diffusion pipeline."""
        full_image = np.array(image.convert('RGB'))
        location = calculate_square(full_image, mask)
        crop_image = full_image[location[1]:location[3], location[0]:location[2]]
        trans_image = self.transform(crop_image)
        mask, masked_image, mask_rect = generate_mask(trans_image, self.resolution, mask, location)
        text = text[:self.max_len]
        draw_ttf = self.draw_text(text)
        glyph = self.draw_glyph(text, mask_rect)
        return {
            "image": trans_image,
            "mask": mask,
            "masked_image": masked_image,
            "ttf_img": draw_ttf,
            "glyph": glyph,
            "text": text,
            "full_image": full_image,
            "location": location,
        }

    def draw_text(self, text, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render one grayscale tile per character; returns (max_len, R, R).

        Character *i* is drawn centred with gray level ``i * 128 // max_len``
        so position information survives in the tile intensity.
        """
        R = self.ttf_size
        fs = int(0.8 * R)
        interval = 128 // self.max_len
        img_tensor = torch.ones((self.max_len, R, R), dtype=torch.float)
        # Loading the font is loop-invariant; hoisted out of the loop.
        font = ImageFont.truetype(font_path, fs)
        for i, char in enumerate(text):
            img = Image.new('L', (R, R), 255)
            draw = ImageDraw.Draw(img)
            text_w, text_h = _text_size(font, char)
            text_position = ((R - text_w) // 2, (R - text_h) // 2)
            draw.text(text_position, char, font=font, fill=interval * i)
            img_tensor[i] = torch.from_numpy(np.array(img)).float() / 255.0
        return img_tensor

    def draw_glyph(self, text, rect, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render *text* and perspective-warp it into the mask rectangle.

        Returns a (3, resolution, resolution) float tensor in [0, 1].
        """
        resolution = self.resolution
        bg_img = np.ones((resolution, resolution, 3), dtype=np.uint8) * 255
        font = ImageFont.truetype(font_path, self.ttf_size)
        text_img = Image.new('RGB', _text_size(font, text), (255, 255, 255))
        draw = ImageDraw.Draw(text_img)
        draw.text((0, 0), text, font=font, fill=(127, 127, 127))
        text_np = np.array(text_img)
        rec_h, rec_w = rect[1]
        box = cv2.boxPoints(rect)
        # For tall (near-vertical) rectangles rotate the corner order so the
        # text baseline follows the long side.
        if rec_h > rec_w * 1.5:
            box = [box[1], box[2], box[3], box[0]]
        dst_points = np.array(box, dtype=np.float32)
        src_points = np.float32([
            [0, 0],
            [text_np.shape[1], 0],
            [text_np.shape[1], text_np.shape[0]],
            [0, text_np.shape[0]],
        ])
        M = cv2.getPerspectiveTransform(src_points, dst_points)
        warped_text_img = cv2.warpPerspective(text_np, M, (resolution, resolution))
        # Copy pixels where any channel equals 127 (the text fill value).
        text_mask = np.any(warped_text_img == [127, 127, 127], axis=-1)
        bg_img[text_mask] = warped_text_img[text_mask]
        bg_img = bg_img.astype(np.float32) / 255.0
        return torch.from_numpy(bg_img).permute(2, 0, 1)


class StableDiffusionPipeline:
    """Minimal inpainting-style diffusion pipeline (VAE + UNet + scheduler)."""

    def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel,
                 scheduler: DDPMScheduler, device):
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        self.vae.to(self.device)
        self.unet.to(self.device)
        # Spatial downscaling of the VAE (2^(num_blocks - 1)).
        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: torch.FloatTensor,
        glyph: torch.FloatTensor,
        masked_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        num_inference_steps: int = 20,
    ):
        """Denoise from pure noise conditioned on glyph/mask/masked image.

        Returns ``(image, image_vae)``: the decoded image clamped to [0, 1]
        and the raw VAE decoder output.
        """
        if masked_image is None:
            raise ValueError("masked_image input cannot be undefined.")
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        vae_scale_factor = self.vae_scale_factor
        # BUGFIX: match the model weights' dtype (fp16 on CUDA, fp32 on CPU);
        # the original fed fp32 tensors into fp16 weights on GPU.
        model_dtype = self.vae.dtype
        _, mask_height, mask_width = mask.size()
        mask = mask.unsqueeze(0)
        glyph = glyph.unsqueeze(0).to(device=self.device, dtype=model_dtype)
        masked_image = masked_image.unsqueeze(0).to(device=self.device, dtype=model_dtype)
        prompt = prompt.unsqueeze(0).to(device=self.device, dtype=model_dtype)
        # BUGFIX: F.interpolate expects size as (H, W); the original passed
        # (W, H), which only worked because the patch is square.
        mask = torch.nn.functional.interpolate(
            mask,
            size=[mask_height // vae_scale_factor, mask_width // vae_scale_factor],
        ).to(device=self.device, dtype=model_dtype)
        glyph_latents = self.vae.encode(glyph).latent_dist.sample() * self.vae.config.scaling_factor
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample() * self.vae.config.scaling_factor
        shape = (1, self.vae.config.latent_channels,
                 mask_height // vae_scale_factor, mask_width // vae_scale_factor)
        # Fixed seed keeps the demo deterministic across runs.
        latents = randn_tensor(shape, generator=torch.manual_seed(20),
                               device=self.device, dtype=model_dtype)
        latents = latents * self.scheduler.init_noise_sigma
        for t in timesteps:
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            # Conditioning is concatenated along the channel dimension.
            sample = torch.cat([latent_model_input, masked_image_latents,
                                glyph_latents, mask], dim=1)
            noise_pred = self.unet(sample=sample, timestep=t,
                                   encoder_hidden_states=prompt).sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        pred_latents = latents / self.vae.config.scaling_factor
        image_vae = self.vae.decode(pred_latents).sample
        image = (image_vae / 2 + 0.5).clamp(0, 1)
        return image, image_vae


# Load models (adjust the paths to your model directories)
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
vae = AutoencoderKL.from_pretrained("Yesianrohn/TextSSR", subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained("Yesianrohn/TextSSR", subfolder="unet", torch_dtype=dtype)
noise_scheduler = DDPMScheduler.from_pretrained("Yesianrohn/TextSSR", subfolder="scheduler")

# Create pipeline
pipe = StableDiffusionPipeline(vae=vae, unet=unet, scheduler=noise_scheduler, device=device)

# Create dataset
dataset = AnytextDataset(
    resolution=256,
    ttf_size=64,
    max_len=25,
)


def edit_mask(mask, num_points=14):
    """Simplify the sketched mask into a filled polygon.

    Keeps the largest connected region and approximates its contour with
    roughly *num_points* vertices (best effort within a bounded number of
    epsilon adjustments).
    """
    mask_array = np.array(mask)
    if mask_array.ndim > 2:
        mask_array = mask_array[:, :, 0]
    binary_mask = (mask_array > 0).astype(np.uint8) * 255
    contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return Image.fromarray(binary_mask)
    # Fill all strokes so interior holes disappear, then re-detect contours.
    filled_mask = np.zeros_like(binary_mask)
    cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
    contours, hierarchy = cv2.findContours(filled_mask, cv2.RETR_EXTERNAL,
                                           cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        epsilon = 0.01 * cv2.arcLength(largest_contour, True)
        approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
        # Grow epsilon until the polygon has at most num_points vertices.
        attempts = 0
        max_attempts = 20
        while len(approx_contour) > num_points and attempts < max_attempts:
            epsilon *= 1.1
            approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
            attempts += 1
        # Shrink epsilon to recover vertices if we over-simplified.
        attempts = 0
        while (len(approx_contour) < num_points and epsilon > 0.0001
               and attempts < max_attempts):
            epsilon *= 0.9
            approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
            attempts += 1
        new_mask = np.zeros_like(binary_mask)
        points = [tuple(pt[0]) for pt in approx_contour]
        img = Image.fromarray(new_mask)
        draw = ImageDraw.Draw(img)
        if points:
            draw.polygon(points, fill=255)
        return img
    return Image.fromarray(filled_mask)


def process_image(image, mask, text, num_points, num_inference_steps):
    """Run the full edit: simplify the mask, synthesize text, paste back.

    Returns (cropped synthesized region, edited full image, original image
    with the mask outline drawn in red).
    """
    print(text)
    edited_mask = edit_mask(mask["mask"], num_points=num_points)
    # Draw the simplified mask outline on a copy of the input for preview.
    img_with_outline = image.copy()
    draw = ImageDraw.Draw(img_with_outline)
    mask_np = np.array(edited_mask)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        points = [tuple(pt[0]) for pt in largest_contour]
        if len(points) >= 2:
            draw.line(points + [points[0]], fill=(255, 0, 0), width=3)
    # Renamed from "input", which shadowed the builtin.
    model_inputs = dataset.get_input(image=image, mask=edited_mask, text=text)
    masked_image = model_inputs["masked_image"].to(device)
    mask = model_inputs["mask"].to(device)
    ttf_img = model_inputs["ttf_img"].to(device)
    glyph = model_inputs["glyph"].to(device)
    full_image = model_inputs["full_image"]
    location = model_inputs["location"]
    image_output, _ = pipe(
        prompt=ttf_img,
        glyph=glyph,
        masked_image=masked_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
    )
    mask_np = mask.cpu().detach().numpy().astype(np.uint8)
    coords = np.column_stack(np.where(mask_np == 0))
    img = image_output[0].float()  # fp32 for the CPU compositing below
    if coords.size > 0:
        # coords columns are (channel, y, x) of the (1, H, W) mask.
        y_min, x_min = coords[:, 1].min(), coords[:, 2].min()
        y_max, x_max = coords[:, 1].max(), coords[:, 2].max()
        cropped_output_image = img[:, y_min:y_max + 1, x_min:x_max + 1]
    else:
        cropped_output_image = img
    cropped_output_image_np = (cropped_output_image * 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
    cropped_output_image_pil = Image.fromarray(cropped_output_image_np)
    # Paste the synthesized patch back into the full-resolution image.
    x_min, y_min, x_max, y_max = location
    full_image_patch = full_image[y_min:y_max, x_min:x_max, :]
    resize_trans = transforms.Resize((full_image_patch.shape[0], full_image_patch.shape[1]))
    resize_mask = resize_trans(mask).cpu()
    resize_img = resize_trans(img).cpu()
    img_mask = torch.where(resize_mask < 0.5, torch.tensor(0.0), torch.tensor(1.0))
    img_mask = img_mask.expand_as(resize_img)
    full_image_patch_tensor = transforms.ToTensor()(full_image_patch).cpu()
    # Keep the background where img_mask == 1, use the synthesis elsewhere.
    full_image_patch_tensor = full_image_patch_tensor * img_mask + resize_img * (1 - img_mask)
    full_image_tensor = transforms.ToTensor()(full_image).cpu()
    full_image_tensor[:, y_min:y_max, x_min:x_max] = full_image_patch_tensor
    full_image_np = full_image_tensor.permute(1, 2, 0).numpy()
    full_image_pil = Image.fromarray((full_image_np * 255).astype(np.uint8))
    return cropped_output_image_pil, full_image_pil, img_with_outline


demo_1 = Image.open("./imgs/demo_1.jpg")
demo_2 = Image.open("./imgs/demo_2.jpg")


def update_image(sample):
    """Map the sample-background radio choice onto the matching image."""
    if sample == "Sample 1":
        return demo_1
    if sample == "Sample 2":
        return demo_2
    return None


with gr.Blocks() as iface:
    gr.Markdown("# TextSSR Demo")
    gr.Markdown("Upload an image, draw a mask on the image, and enter text content for region synthesis and image editing.")
    with gr.Row():
        with gr.Column():
            sample_choice = gr.Radio(choices=["Sample 1", "Sample 2"],
                                     label="Choose a Sample Background")
            input_image = gr.Image(type="pil", label="Input Image")
            mask_input = gr.Image(type="pil", label="Draw Mask on Image",
                                  tool="sketch", interactive=True)
            text_input = gr.Textbox(label="Text to Synthesize / Edit")
            outlined_image = gr.Image(type="pil", label="Original Image with Mask Outline")
            with gr.Row():
                num_points_slider = gr.Slider(
                    minimum=4, maximum=20, value=14, step=1,
                    label="Control Points",
                    info="Adjust mask complexity (4-20 points)"
                )
                num_steps_slider = gr.Slider(
                    minimum=1, maximum=50, value=20, step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )
            submit_btn = gr.Button("Process Image")
        with gr.Column():
            output_region = gr.Image(type="pil", label="Modified Region")
            output_full = gr.Image(type="pil", label="Modified Full Image")

    # Update input image based on the selected sample background
    sample_choice.change(
        update_image,
        inputs=[sample_choice],
        outputs=[input_image]
    )

    # Update mask when input image changes
    input_image.change(
        lambda image: image,  # Pass through image to mask_input
        inputs=[input_image],
        outputs=[mask_input]
    )

    # Process image when submit button is clicked (updated to include
    # num_points and num_inference_steps parameters)
    submit_btn.click(
        process_image,
        inputs=[input_image, mask_input, text_input, num_points_slider, num_steps_slider],
        outputs=[output_region, output_full, outlined_image]
    )

iface.launch()