# TextSSR demo script (Hugging Face Space) — region text synthesis and editing.
| import os | |
| import random | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler | |
| from diffusers.utils.torch_utils import randn_tensor | |
# Pick the compute device once at import time: prefer CUDA when available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| # Function definitions | |
def calculate_square(full_image, mask):
    """Pick a square crop window that covers the masked region.

    Args:
        full_image: H x W x C numpy array of the background image.
        mask: PIL image or array-like whose nonzero pixels mark the region
            of interest.

    Returns:
        [sx0, sy0, sx1, sy1]: corners of a square window whose side is
        L = max(box_w, box_h), clamped to the image size, that contains the
        mask's bounding box. On an axis where the box is smaller than L the
        placement is chosen uniformly at random among valid positions.

    Raises:
        ValueError: if the mask contains no nonzero pixels.
    """
    mask_array = np.array(mask)
    if mask_array.ndim == 2:
        gray = mask_array
    else:
        # Color masks are reduced to one channel before thresholding.
        gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)
    # Bounding rect of all nonzero pixels — numpy equivalent of
    # cv2.findNonZero + cv2.boundingRect, with an explicit empty-mask error
    # instead of an opaque OpenCV failure.
    ys, xs = np.nonzero(gray)
    if xs.size == 0:
        raise ValueError("mask contains no nonzero pixels")
    x, y = int(xs.min()), int(ys.min())
    w, h = int(xs.max()) - x + 1, int(ys.max()) - y + 1
    img_h, img_w = full_image.shape[0], full_image.shape[1]
    # Square side: large enough to cover the box, never larger than the image.
    L = min(img_w, img_h, max(w, h))
    if w < L:
        # random.randint is inclusive on both ends, so the upper bound is
        # min(x, img_w - L) itself. (The original added +1, which could
        # place the window one pixel past the valid range.)
        sx0 = random.randint(max(0, x + w - L), min(x, img_w - L))
        sx1 = sx0 + L
    else:
        sx0, sx1 = x, x + w
    if h < L:
        sy0 = random.randint(max(0, y + h - L), min(y, img_h - L))
        sy1 = sy0 + L
    else:
        sy0, sy1 = y, y + h
    return [sx0, sy0, sx1, sy1]
def generate_mask(trans_image, resolution, mask, location):
    """Crop and binarize the mask, then mask out the drawn region of the crop.

    Args:
        trans_image: (C, resolution, resolution) tensor of the cropped image.
        resolution: working square resolution.
        mask: PIL mask image covering the full picture.
        location: [x0, y0, x1, y1] crop window inside the full picture.

    Returns:
        (mask_t, masked_image, rect): the inverted binary mask (0 inside the
        drawn region, 1 outside), the crop with the drawn region zeroed, and
        the cv2.minAreaRect of the drawn region.
    """
    x0, y0, x1, y1 = location
    cropped = np.array(mask.convert("L"))[y0:y1, x0:x1]
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((resolution, resolution))
    ])
    mask_t = to_tensor(cropped)
    # Invert: bright (drawn) pixels become 0, background becomes 1.
    mask_t = torch.where(mask_t > 0.5, torch.tensor(0.0), torch.tensor(1.0))
    masked_image = trans_image * mask_t.expand_as(trans_image)
    # Transpose so np.where yields (x, y) point coordinates for cv2.
    binary = np.transpose(mask_t.squeeze().byte().cpu().numpy())
    zero_points = np.column_stack(np.where(binary == 0))
    rect = cv2.minAreaRect(zero_points)
    return mask_t, masked_image, rect
class AnytextDataset():
    """Builds the model inputs (crop, mask, glyph renderings) for one edit.

    resolution: side length of the square crop fed to the VAE/UNet.
    ttf_size: side length of each per-character glyph tile.
    max_len: maximum number of characters rendered from the input text.
    """
    def __init__(
        self,
        resolution=256,
        ttf_size=64,
        max_len=25,
    ):
        self.resolution = resolution
        self.ttf_size = ttf_size
        self.max_len = max_len
        # Crop -> float tensor normalized to [-1, 1] at the working resolution.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resolution, resolution)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def get_input(self, image, mask, text):
        """Assemble all tensors the pipeline needs for one request.

        Args:
            image: PIL background image.
            mask: PIL mask marking the region to re-synthesize.
            text: text to render; truncated to `max_len` characters.

        Returns:
            dict with the normalized crop ('image'), binary mask ('mask'),
            masked crop ('masked_image'), per-character tiles ('ttf_img'),
            warped glyph canvas ('glyph'), truncated 'text', the raw
            'full_image' array, and the crop 'location'.
        """
        full_image = np.array(image.convert('RGB'))
        location = calculate_square(full_image, mask)
        crop_image = full_image[location[1]:location[3], location[0]:location[2]]
        trans_image = self.transform(crop_image)
        mask, masked_image, mask_rect = generate_mask(trans_image, self.resolution, mask, location)
        text = text[:self.max_len]
        draw_ttf = self.draw_text(text)
        glyph = self.draw_glyph(text, mask_rect)
        info = {
            "image": trans_image,
            'mask': mask,
            'masked_image': masked_image,
            'ttf_img': draw_ttf,
            'glyph': glyph,
            "text": text,
            "full_image": full_image,
            "location": location,
        }
        return info

    def draw_text(self, text, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render each character on its own ttf_size x ttf_size grayscale tile.

        Tiles beyond len(text) stay all-ones (blank). Each character is drawn
        with fill value `interval * i`, so the ink intensity encodes the
        character's position (0 = black for the first character, lighter for
        later ones) — presumably a positional encoding the model expects;
        TODO confirm against the training code.
        """
        R = self.ttf_size
        fs = int(0.8*R)
        interval = 128 // self.max_len
        img_tensor = torch.ones((self.max_len, R, R), dtype=torch.float)
        for i, char in enumerate(text):
            img = Image.new('L', (R, R), 255)
            draw = ImageDraw.Draw(img)
            font = ImageFont.truetype(font_path, fs)
            # NOTE(review): font.getsize was removed in Pillow 10 — this code
            # requires an older Pillow (or a port to getbbox/getlength).
            text_size = font.getsize(char)
            # Center the glyph on the tile.
            text_position = ((R - text_size[0]) // 2, (R - text_size[1]) // 2)
            draw.text(text_position, char, font=font, fill=interval*i)
            img_tensor[i] = torch.from_numpy(np.array(img)).float() / 255.0
        return img_tensor

    def draw_glyph(self, text, rect, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render `text` once and perspective-warp it into `rect` on white.

        Args:
            text: string rendered on a single line.
            rect: cv2.minAreaRect result ((cx, cy), (w, h), angle) locating
                the target region inside the resolution x resolution canvas.

        Returns:
            (3, resolution, resolution) float tensor in [0, 1].
        """
        resolution = self.resolution
        bg_img = np.ones((resolution, resolution, 3), dtype=np.uint8) * 255
        font = ImageFont.truetype(font_path, self.ttf_size)
        text_img = Image.new('RGB', font.getsize(text), (255, 255, 255))
        draw = ImageDraw.Draw(text_img)
        # Mid-gray ink so the warped text pixels can be recognized below.
        draw.text((0, 0), text, font=font, fill=(127, 127, 127))
        text_np = np.array(text_img)
        # NOTE(review): cv2.minAreaRect stores (width, height) in rect[1];
        # the names here read swapped (rec_h gets the width) — the 1.5x
        # aspect test below may be inverted; verify intent.
        rec_h, rec_w = rect[1]
        box = cv2.boxPoints(rect)
        if rec_h > rec_w * 1.5:
            # Rotate the corner order so the text baseline follows the
            # rectangle's long side for tall regions.
            box = [box[1], box[2], box[3], box[0]]
        dst_points = np.array(box, dtype=np.float32)
        # Source corners: the full rendered text image, in the same
        # (top-left, top-right, bottom-right, bottom-left) order.
        src_points = np.float32([[0, 0], [text_np.shape[1], 0], [text_np.shape[1], text_np.shape[0]], [0, text_np.shape[0]]])
        M = cv2.getPerspectiveTransform(src_points, dst_points)
        warped_text_img = cv2.warpPerspective(text_np, M, (resolution, resolution))
        # Copy onto the white canvas any pixel with at least one channel
        # equal to 127 (the ink value).
        mask = np.any(warped_text_img == [127, 127, 127], axis=-1)
        bg_img[mask] = warped_text_img[mask]
        bg_img = bg_img.astype(np.float32) / 255.0
        bg_img_tensor = torch.from_numpy(bg_img).permute(2, 0, 1)
        return bg_img_tensor
class StableDiffusionPipeline:
    """Minimal inpainting-style diffusion pipeline (VAE + UNet + scheduler).

    The UNet is conditioned on the per-character tile tensor (passed as
    encoder_hidden_states) and receives, concatenated along the channel
    axis: the noisy latents, the masked-image latents, the glyph latents,
    and the mask downsampled to latent resolution.
    """

    def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel, scheduler: DDPMScheduler, device):
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        self.vae.to(self.device)
        self.unet.to(self.device)
        # Spatial downsampling factor between pixel and latent space.
        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    def __call__(
        self,
        prompt: torch.FloatTensor,
        glyph: torch.FloatTensor,
        masked_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        num_inference_steps: int = 20,
    ):
        """Run the denoising loop and decode the resulting latents.

        Args:
            prompt: per-character tile tensor used as cross-attention
                conditioning (a batch dimension is added internally).
            glyph: (3, H, W) warped-glyph image.
            masked_image: (3, H, W) crop with the edit region zeroed out.
            mask: (1, H, W) binary mask (1 outside the edit region).
            num_inference_steps: number of scheduler timesteps.

        Returns:
            (image, image_vae): the decoded image clamped to [0, 1] and the
            raw VAE decoder output.

        Raises:
            ValueError: if masked_image is None.
        """
        if masked_image is None:
            raise ValueError("masked_image input cannot be undefined.")
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        vae_scale_factor = self.vae_scale_factor
        _, mask_height, mask_width = mask.size()
        # Add the batch dimension expected by the VAE and UNet.
        mask = mask.unsqueeze(0)
        glyph = glyph.unsqueeze(0)
        masked_image = masked_image.unsqueeze(0)
        prompt = prompt.unsqueeze(0)
        # Downsample the mask to latent resolution. F.interpolate takes
        # size=(H, W); the original passed (W, H), which only worked because
        # the crops are square — fixed to match the latent shape below.
        mask = torch.nn.functional.interpolate(
            mask, size=[mask_height // vae_scale_factor, mask_width // vae_scale_factor]
        )
        glyph_latents = self.vae.encode(glyph).latent_dist.sample() * self.vae.config.scaling_factor
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample() * self.vae.config.scaling_factor
        shape = (1, self.vae.config.latent_channels, mask_height // vae_scale_factor, mask_width // vae_scale_factor)
        # Fixed seed: generation is deterministic across calls by design.
        latents = randn_tensor(shape, generator=torch.manual_seed(20), device=self.device) * self.scheduler.init_noise_sigma
        for t in timesteps:
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            # Channel-concatenate all conditioning with the noisy latents.
            sample = torch.cat([latent_model_input, masked_image_latents, glyph_latents, mask], dim=1)
            noise_pred = self.unet(sample=sample, timestep=t, encoder_hidden_states=prompt).sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        pred_latents = latents / self.vae.config.scaling_factor
        image_vae = self.vae.decode(pred_latents).sample
        # Map decoder output from [-1, 1] to [0, 1].
        image = (image_vae / 2 + 0.5).clamp(0, 1)
        return image, image_vae
# Load models (adjust the paths to your model directories)
# Half precision on GPU keeps memory down; fall back to fp32 on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
vae = AutoencoderKL.from_pretrained("Yesianrohn/TextSSR", subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained("Yesianrohn/TextSSR", subfolder="unet", torch_dtype=dtype)
# The scheduler holds no weights, so it takes no torch_dtype.
noise_scheduler = DDPMScheduler.from_pretrained("Yesianrohn/TextSSR", subfolder="scheduler")
# Create pipeline
pipe = StableDiffusionPipeline(vae=vae, unet=unet, scheduler=noise_scheduler, device=device)
# Create dataset
dataset = AnytextDataset(
    resolution=256,
    ttf_size=64,
    max_len=25,
)
def edit_mask(mask, num_points=14):
    """Simplify a hand-drawn mask into a filled polygon of ~num_points vertices.

    The mask's contours are filled, the largest filled region's outline is
    approximated with cv2.approxPolyDP, and the approximation tolerance is
    adjusted iteratively (at most 20 steps each way) toward the target
    vertex count.

    Args:
        mask: PIL image or array-like; nonzero pixels are the drawn region.
        num_points: target number of polygon vertices.

    Returns:
        PIL 'L' image with the simplified polygon filled with 255 (or the
        binarized/filled mask if no contour is found).
    """
    mask_array = np.array(mask)
    if mask_array.ndim > 2:
        # Use the first channel. (The original guarded on shape[2] >= 1,
        # which is always true for a 3-D array — dead condition removed.)
        mask_array = mask_array[:, :, 0]
    binary_mask = (mask_array > 0).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return Image.fromarray(binary_mask)
    # Fill the outlines so interior holes don't fragment the region.
    filled_mask = np.zeros_like(binary_mask)
    cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
    contours, _ = cv2.findContours(filled_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return Image.fromarray(filled_mask)
    largest_contour = max(contours, key=cv2.contourArea)
    epsilon = 0.01 * cv2.arcLength(largest_contour, True)
    approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
    max_attempts = 20
    # Too many vertices: relax the tolerance until at or below target.
    attempts = 0
    while len(approx_contour) > num_points and attempts < max_attempts:
        epsilon *= 1.1
        approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
        attempts += 1
    # Too few vertices: tighten the tolerance to recover detail.
    attempts = 0
    while len(approx_contour) < num_points and epsilon > 0.0001 and attempts < max_attempts:
        epsilon *= 0.9
        approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
        attempts += 1
    new_mask = np.zeros_like(binary_mask)
    img = Image.fromarray(new_mask)
    draw = ImageDraw.Draw(img)
    points = [tuple(pt[0]) for pt in approx_contour]
    if points:
        draw.polygon(points, fill=255)
    return img
def process_image(image, mask, text, num_points, num_inference_steps):
    """Gradio callback: simplify the mask, run the pipeline, and composite.

    Args:
        image: PIL input image.
        mask: Gradio sketch dict; mask["mask"] holds the drawn PIL mask.
        text: text to synthesize into the masked region.
        num_points: polygon vertex budget passed to edit_mask.
        num_inference_steps: diffusion steps passed to the pipeline.

    Returns:
        (cropped region PIL image, composited full PIL image,
         input image with the simplified mask outline drawn in red).
    """
    print(text)
    edited_mask = edit_mask(mask["mask"], num_points=num_points)
    # Draw the simplified mask outline on a display copy of the input.
    img_with_outline = image.copy()
    draw = ImageDraw.Draw(img_with_outline)
    mask_np = np.array(edited_mask)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        points = [tuple(pt[0]) for pt in largest_contour]
        if len(points) >= 2:
            # Close the outline by appending the first point.
            draw.line(points + [points[0]], fill=(255, 0, 0), width=3)
    # NOTE(review): `input` shadows the builtin; rename when next touched.
    input = dataset.get_input(image=image, mask=edited_mask, text=text)
    masked_image = input["masked_image"].to(device)
    mask = input["mask"].to(device)
    ttf_img = input["ttf_img"].to(device)
    glyph = input["glyph"].to(device)
    full_image = input["full_image"]
    location = input["location"]
    image_output, _ = pipe(
        prompt=ttf_img,
        glyph=glyph,
        masked_image=masked_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
    )
    # Bounding box of the edited region (mask == 0) in the output tensor;
    # coords columns are (channel, row, col) for the (1, H, W) mask.
    mask_np = mask.cpu().detach().numpy().astype(np.uint8)
    coords = np.column_stack(np.where(mask_np == 0))
    img = image_output[0]
    if coords.size > 0:
        y_min, x_min = coords[:, 1].min(), coords[:, 2].min()
        y_max, x_max = coords[:, 1].max(), coords[:, 2].max()
        cropped_output_image = img[:, y_min:y_max+1, x_min:x_max+1]
    else:
        # Empty mask: fall back to returning the whole generated crop.
        cropped_output_image = img
    cropped_output_image_np = (cropped_output_image * 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
    cropped_output_image_pil = Image.fromarray(cropped_output_image_np)
    # Composite the generated crop back into the full image: keep original
    # pixels where the (resized) mask is 1, take generated pixels where 0.
    x_min, y_min, x_max, y_max = location[0], location[1], location[2], location[3]
    full_image_patch = full_image[y_min:y_max, x_min:x_max, :]
    resize_trans = transforms.Resize((full_image_patch.shape[0], full_image_patch.shape[1]))
    resize_mask = resize_trans(mask).cpu()
    resize_img = resize_trans(img).cpu()
    # Re-binarize after resizing (interpolation produces in-between values).
    img_mask = torch.where(resize_mask < 0.5, torch.tensor(0.0), torch.tensor(1.0))
    img_mask = img_mask.expand_as(resize_img)
    full_image_patch_tensor = transforms.ToTensor()(full_image_patch).cpu()
    full_image_patch_tensor = full_image_patch_tensor * img_mask + resize_img * (1 - img_mask)
    full_image_tensor = transforms.ToTensor()(full_image).cpu()
    full_image_tensor[:, y_min:y_max, x_min:x_max] = full_image_patch_tensor
    full_image_np = full_image_tensor.permute(1, 2, 0).numpy()
    full_image_pil = Image.fromarray((full_image_np * 255).astype(np.uint8))
    return cropped_output_image_pil, full_image_pil, img_with_outline
# Sample background images offered by the UI's radio selector.
demo_1 = Image.open("./imgs/demo_1.jpg")
demo_2 = Image.open("./imgs/demo_2.jpg")
def update_image(sample):
    """Map the sample-radio choice to its demo image; unknown choices give None."""
    if sample == "Sample 1":
        return demo_1
    if sample == "Sample 2":
        return demo_2
    return None
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as iface:
    gr.Markdown("# TextSSR Demo")
    gr.Markdown("Upload an image, draw a mask on the image, and enter text content for region synthesis and image editing.")
    with gr.Row():
        with gr.Column():
            # Input side: sample picker, image, sketchable mask, text prompt.
            sample_choice = gr.Radio(choices=["Sample 1", "Sample 2"], label="Choose a Sample Background")
            input_image = gr.Image(type="pil", label="Input Image")
            # NOTE(review): tool="sketch" is the Gradio 3.x sketch API; newer
            # Gradio replaced it with gr.ImageEditor — verify the pinned version.
            mask_input = gr.Image(type="pil", label="Draw Mask on Image", tool="sketch", interactive=True)
            text_input = gr.Textbox(label="Text to Synthesize / Edit")
            outlined_image = gr.Image(type="pil", label="Original Image with Mask Outline")
            with gr.Row():
                num_points_slider = gr.Slider(
                    minimum=4,
                    maximum=20,
                    value=14,
                    step=1,
                    label="Control Points",
                    info="Adjust mask complexity (4-20 points)"
                )
                num_steps_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=20,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )
            submit_btn = gr.Button("Process Image")
        with gr.Column():
            # Output side: the regenerated crop and the composited full image.
            output_region = gr.Image(type="pil", label="Modified Region")
            output_full = gr.Image(type="pil", label="Modified Full Image")
    # Update input image based on the selected sample background
    sample_choice.change(
        update_image,
        inputs=[sample_choice],
        outputs=[input_image]
    )
    # Update mask when input image changes
    input_image.change(
        lambda image: image,  # Pass through image to mask_input
        inputs=[input_image],
        outputs=[mask_input]
    )
    # Process image when submit button is clicked (updated to include num_points and num_inference_steps parameters)
    submit_btn.click(
        process_image,
        inputs=[input_image, mask_input, text_input, num_points_slider, num_steps_slider],
        outputs=[output_region, output_full, outlined_image]
    )
iface.launch()