# TextSSR / app.py
# Yesianrohn's picture
# Update app.py
# eb69848 verified
import os
import random
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.torch_utils import randn_tensor
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Function definitions
def calculate_square(full_image, mask):
    """Pick a square crop window that contains the non-zero region of ``mask``.

    The side of the square is the longer side of the mask's bounding box,
    clamped to the shorter side of ``full_image``. Along the axis where the
    bounding box is shorter than that side, the window is placed at a random
    offset that still contains the whole box and stays inside the image.
    Along an axis where the box is already at least as long as the side, the
    box extent itself is used (so the result is not square when the box is
    longer than the clamped side).

    Args:
        full_image: Array-like exposing ``shape`` as (height, width, ...).
        mask: Anything ``np.array`` can convert to a 2-D mask or a 3-channel
            RGB mask; non-zero pixels mark the region of interest.

    Returns:
        ``[x0, y0, x1, y1]`` with exclusive right/bottom edges.

    Raises:
        ValueError: If the mask has no non-zero pixel.
    """
    mask_array = np.array(mask)
    if mask_array.ndim == 2:
        gray = mask_array
    else:
        gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY)

    # Axis-aligned bounding box of the non-zero pixels (inclusive extents,
    # hence the "+ 1" on width and height).
    rows, cols = np.nonzero(gray)
    if cols.size == 0:
        raise ValueError("mask is empty: no non-zero pixels to build a crop around")
    x, y = int(cols.min()), int(rows.min())
    w = int(cols.max()) - x + 1
    h = int(rows.max()) - y + 1

    img_h, img_w = full_image.shape[0], full_image.shape[1]
    side = min(img_w, img_h, max(w, h))

    # random.randint includes BOTH endpoints, so the upper bound must be the
    # last valid offset itself: one more would either cut off the first
    # column/row of the mask or push the window past the image border.
    if w < side:
        sx0 = random.randint(max(0, x + w - side), min(x, img_w - side))
        sx1 = sx0 + side
    else:
        sx0, sx1 = x, x + w

    if h < side:
        sy0 = random.randint(max(0, y + h - side), min(y, img_h - side))
        sy1 = sy0 + side
    else:
        sy0, sy1 = y, y + h

    return [sx0, sy0, sx1, sy1]
def generate_mask(trans_image, resolution, mask, location):
    """Build the keep-mask, the masked crop and the mask's rotated rectangle.

    Args:
        trans_image: Tensor of the cropped image; the keep-mask is expanded to
            its shape and multiplied in.
        resolution: Side length the cropped mask is resized to.
        mask: PIL image; it is converted to mode "L" before cropping.
        location: Indexable ``[x0, y0, x1, y1]`` crop window.

    Returns:
        ``(mask, masked_image, rect)`` where ``mask`` is 0.0 where the resized
        mask exceeds 0.5 and 1.0 elsewhere, ``masked_image`` is
        ``trans_image`` multiplied by that mask, and ``rect`` is the result of
        ``cv2.minAreaRect`` over the zero-valued mask positions.
    """
    x0, y0 = location[0], location[1]
    x1, y1 = location[2], location[3]
    gray_crop = np.array(mask.convert("L"))[y0:y1, x0:x1]

    to_square_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((resolution, resolution)),
    ])
    resized = to_square_tensor(gray_crop)

    # Invert: the drawn region becomes 0 (to be filled in), the rest stays 1.
    keep = torch.where(resized > 0.5, torch.tensor(0.0), torch.tensor(1.0))
    masked_image = trans_image * keep.expand_as(trans_image)

    # Transposing first makes np.where yield (column, row) index pairs.
    keep_transposed = np.transpose(keep.squeeze().byte().cpu().numpy())
    first_idx, second_idx = np.where(keep_transposed == 0)
    region_points = np.column_stack((first_idx, second_idx))
    rect = cv2.minAreaRect(region_points)

    return keep, masked_image, rect
class AnytextDataset():
    """Prepare the tensors the pipeline needs from an image, a mask and a text.

    Args:
        resolution: Side length of the square crop / glyph canvas.
        ttf_size: Font size for the glyph image and tile size for the
            per-character images.
        max_len: Maximum number of characters kept from the text.
    """

    def __init__(
        self,
        resolution=256,
        ttf_size=64,
        max_len=25,
    ):
        self.resolution = resolution
        self.ttf_size = ttf_size
        self.max_len = max_len
        # Crop -> tensor -> square resize -> values mapped with mean/std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resolution, resolution)),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    @staticmethod
    def _text_size(font, text):
        """Return ``(width, height)`` of ``text`` across Pillow versions.

        ``ImageFont.getsize`` was removed in Pillow 10. The right/bottom
        coordinates of ``getbbox`` are the same values ``getsize`` used to
        return, so rendering is unchanged on older installations.
        """
        if hasattr(font, "getbbox"):
            _, _, right, bottom = font.getbbox(text)
            return int(right), int(bottom)
        return font.getsize(text)

    def get_input(self, image, mask, text):
        """Assemble all model inputs for one edit request.

        Args:
            image: PIL image; converted to RGB.
            mask: PIL mask image marking the region to synthesize.
            text: Text to render; truncated to ``max_len`` characters.

        Returns:
            Dict with keys ``image``, ``mask``, ``masked_image``, ``ttf_img``,
            ``glyph``, ``text``, ``full_image`` (RGB array) and ``location``
            (crop window ``[x0, y0, x1, y1]``).
        """
        full_image = np.array(image.convert('RGB'))
        location = calculate_square(full_image, mask)
        crop_image = full_image[location[1]:location[3], location[0]:location[2]]
        trans_image = self.transform(crop_image)
        mask, masked_image, mask_rect = generate_mask(trans_image, self.resolution, mask, location)
        text = text[:self.max_len]
        draw_ttf = self.draw_text(text)
        glyph = self.draw_glyph(text, mask_rect)
        info = {
            "image": trans_image,
            'mask': mask,
            'masked_image': masked_image,
            'ttf_img': draw_ttf,
            'glyph': glyph,
            "text": text,
            "full_image": full_image,
            "location": location,
        }
        return info

    def draw_text(self, text, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render each character of ``text`` onto its own square tile.

        Returns a float tensor of shape ``(max_len, ttf_size, ttf_size)`` with
        values in [0, 1]. Tiles beyond ``len(text)`` stay all ones (white).
        Character ``i`` is drawn with gray level ``(128 // max_len) * i``.
        Characters past ``max_len`` are ignored.
        """
        R = self.ttf_size
        fs = int(0.8 * R)
        interval = 128 // self.max_len
        img_tensor = torch.ones((self.max_len, R, R), dtype=torch.float)
        # The font does not depend on the character, so load it only once.
        font = ImageFont.truetype(font_path, fs)
        for i, char in enumerate(text[:self.max_len]):
            img = Image.new('L', (R, R), 255)
            draw = ImageDraw.Draw(img)
            text_size = self._text_size(font, char)
            # Centre the character inside the tile.
            text_position = ((R - text_size[0]) // 2, (R - text_size[1]) // 2)
            draw.text(text_position, char, font=font, fill=interval * i)
            img_tensor[i] = torch.from_numpy(np.array(img)).float() / 255.0
        return img_tensor

    def draw_glyph(self, text, rect, font_path="AlibabaPuHuiTi-3-85-Bold.ttf"):
        """Render ``text`` in gray and warp it into the rotated rectangle ``rect``.

        Args:
            text: Text to render as a single line.
            rect: Rotated rectangle in the format accepted by
                ``cv2.boxPoints`` (as produced by ``cv2.minAreaRect``).
            font_path: TrueType font file used for rendering.

        Returns:
            Float tensor of shape ``(3, resolution, resolution)`` in [0, 1]:
            white background with the warped text pasted on top. For empty
            text (or text that measures zero pixels) the plain white canvas
            is returned, because there is nothing to warp.
        """
        resolution = self.resolution
        blank = torch.ones((3, resolution, resolution), dtype=torch.float32)
        if not text:
            return blank

        bg_img = np.ones((resolution, resolution, 3), dtype=np.uint8) * 255
        font = ImageFont.truetype(font_path, self.ttf_size)
        text_size = self._text_size(font, text)
        if text_size[0] <= 0 or text_size[1] <= 0:
            return blank

        text_img = Image.new('RGB', text_size, (255, 255, 255))
        draw = ImageDraw.Draw(text_img)
        draw.text((0, 0), text, font=font, fill=(127, 127, 127))
        text_np = np.array(text_img)

        # NOTE(review): the two side lengths are compared to decide whether the
        # box corners must be rotated by one position so the text runs along
        # the long side; which side is "height" depends on the point order
        # used by the caller - confirm before changing.
        first_side, second_side = rect[1]
        box = cv2.boxPoints(rect)
        if first_side > second_side * 1.5:
            box = [box[1], box[2], box[3], box[0]]
        dst_points = np.array(box, dtype=np.float32)

        # Corners of the rendered text image: TL, TR, BR, BL.
        src_points = np.float32([
            [0, 0],
            [text_np.shape[1], 0],
            [text_np.shape[1], text_np.shape[0]],
            [0, text_np.shape[0]],
        ])
        M = cv2.getPerspectiveTransform(src_points, dst_points)
        warped_text_img = cv2.warpPerspective(text_np, M, (resolution, resolution))

        # Copy only pixels carrying the text gray value onto the white canvas.
        mask = np.any(warped_text_img == [127, 127, 127], axis=-1)
        bg_img[mask] = warped_text_img[mask]
        bg_img = bg_img.astype(np.float32) / 255.0
        bg_img_tensor = torch.from_numpy(bg_img).permute(2, 0, 1)
        return bg_img_tensor
class StableDiffusionPipeline:
    """Minimal denoising loop around a VAE, a conditional UNet and a scheduler.

    Args:
        vae: Autoencoder used to encode the conditioning images and decode the
            final latents.
        unet: Noise-prediction network.
        scheduler: Scheduler providing the timesteps and the update rule.
        device: Device both models are moved to.
    """

    def __init__(self, vae: AutoencoderKL, unet: UNet2DConditionModel, scheduler: DDPMScheduler, device):
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        self.vae.to(self.device)
        self.unet.to(self.device)
        # Spatial down-sampling factor between image space and latent space.
        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: torch.FloatTensor,
        glyph: torch.FloatTensor,
        masked_image: torch.FloatTensor,
        mask: torch.FloatTensor,
        num_inference_steps: int = 20,
    ):
        """Run the denoising loop for a single, unbatched sample.

        Args:
            prompt: Conditioning passed to the UNet as
                ``encoder_hidden_states`` (a batch dimension is added here).
            glyph: Glyph image; encoded with the VAE.
            masked_image: Image with the edit region blanked; encoded with
                the VAE.
            mask: 3-D mask tensor ``(channels, height, width)``; resized to
                the latent resolution.
            num_inference_steps: Number of scheduler steps.

        Returns:
            ``(image, image_vae)``: the decoded sample mapped to [0, 1] and
            the raw decoder output, both as float32 with a batch dimension.

        Raises:
            ValueError: If ``masked_image`` is ``None``.
        """
        if masked_image is None:
            raise ValueError("masked_image input cannot be undefined.")

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps
        vae_scale_factor = self.vae_scale_factor

        # The models may be loaded in half precision while callers hand in
        # float32 tensors; cast to the model dtypes to avoid a dtype mismatch.
        vae_dtype = self.vae.dtype
        unet_dtype = self.unet.dtype

        _, mask_height, mask_width = mask.size()
        latent_height = mask_height // vae_scale_factor
        latent_width = mask_width // vae_scale_factor

        mask = mask.unsqueeze(0).to(device=self.device, dtype=unet_dtype)
        glyph = glyph.unsqueeze(0).to(device=self.device, dtype=vae_dtype)
        masked_image = masked_image.unsqueeze(0).to(device=self.device, dtype=vae_dtype)
        prompt = prompt.unsqueeze(0).to(device=self.device, dtype=unet_dtype)

        # interpolate expects (height, width), matching the latent shape below.
        mask = torch.nn.functional.interpolate(mask, size=[latent_height, latent_width])

        scaling_factor = self.vae.config.scaling_factor
        glyph_latents = (self.vae.encode(glyph).latent_dist.sample() * scaling_factor).to(dtype=unet_dtype)
        masked_image_latents = (
            self.vae.encode(masked_image).latent_dist.sample() * scaling_factor
        ).to(dtype=unet_dtype)

        shape = (1, self.vae.config.latent_channels, latent_height, latent_width)
        # torch.manual_seed re-seeds the global RNG and returns it, so the
        # start noise is the same on every call. The noise is drawn in the
        # default dtype and cast afterwards so the values do not depend on
        # the model precision.
        latents = randn_tensor(shape, generator=torch.manual_seed(20), device=self.device)
        latents = latents.to(dtype=unet_dtype) * self.scheduler.init_noise_sigma

        for t in timesteps:
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            # Channel-wise concat: noisy latents, masked image, glyph, mask.
            sample = torch.cat([latent_model_input, masked_image_latents, glyph_latents, mask], dim=1)
            noise_pred = self.unet(sample=sample, timestep=t, encoder_hidden_states=prompt, ).sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        pred_latents = (latents / scaling_factor).to(dtype=vae_dtype)
        # Return float32 so downstream NumPy / CPU code works for any model dtype.
        image_vae = self.vae.decode(pred_latents).sample.float()
        image = (image_vae / 2 + 0.5).clamp(0, 1)
        return image, image_vae
# Load models (adjust the paths to your model directories).
# Half precision is requested only when a CUDA device is available.
dtype = torch.float32
if torch.cuda.is_available():
    dtype = torch.float16

vae = AutoencoderKL.from_pretrained(
    "Yesianrohn/TextSSR",
    subfolder="vae",
    torch_dtype=dtype,
)
unet = UNet2DConditionModel.from_pretrained(
    "Yesianrohn/TextSSR",
    subfolder="unet",
    torch_dtype=dtype,
)
noise_scheduler = DDPMScheduler.from_pretrained(
    "Yesianrohn/TextSSR",
    subfolder="scheduler",
)

# Inference pipeline shared by every request.
pipe = StableDiffusionPipeline(
    vae=vae,
    unet=unet,
    scheduler=noise_scheduler,
    device=device,
)

# Input builder shared by every request.
dataset = AnytextDataset(resolution=256, ttf_size=64, max_len=25)
def edit_mask(mask, num_points=14):
    """Simplify a drawn mask to a filled polygon with roughly ``num_points`` corners.

    The mask is binarised (any value above zero counts), its outer contours
    are filled, and the largest resulting contour is approximated by a
    polygon. The approximation tolerance is first loosened while the polygon
    has more than ``num_points`` vertices, then tightened while it has fewer,
    each for at most 20 rounds.

    Args:
        mask: Anything ``np.array`` can convert; for arrays with more than two
            dimensions only the first channel is used.
        num_points: Target number of polygon vertices.

    Returns:
        A PIL image with the polygon filled with 255 on a zero background.
        When no contour is found, the binarised (or filled) mask is returned
        as a PIL image instead.
    """
    pixels = np.array(mask)
    if pixels.ndim > 2 and pixels.shape[2] >= 1:
        pixels = pixels[:, :, 0]
    binary = (pixels > 0).astype(np.uint8) * 255

    outer, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not outer:
        return Image.fromarray(binary)

    # Fill the outer contours so holes inside the drawing disappear.
    filled = np.zeros_like(binary)
    cv2.drawContours(filled, outer, -1, 255, thickness=cv2.FILLED)

    shapes, _ = cv2.findContours(filled, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not shapes:
        return Image.fromarray(filled)

    biggest = max(shapes, key=cv2.contourArea)
    tolerance = 0.01 * cv2.arcLength(biggest, True)
    polygon = cv2.approxPolyDP(biggest, tolerance, True)

    max_attempts = 20

    # Too many vertices: loosen the tolerance step by step.
    for _ in range(max_attempts):
        if len(polygon) <= num_points:
            break
        tolerance *= 1.1
        polygon = cv2.approxPolyDP(biggest, tolerance, True)

    # Too few vertices: tighten the tolerance step by step.
    for _ in range(max_attempts):
        if len(polygon) >= num_points or tolerance <= 0.0001:
            break
        tolerance *= 0.9
        polygon = cv2.approxPolyDP(biggest, tolerance, True)

    canvas = Image.fromarray(np.zeros_like(binary))
    vertices = [tuple(vertex[0]) for vertex in polygon]
    if vertices:
        ImageDraw.Draw(canvas).polygon(vertices, fill=255)
    return canvas
def process_image(image, mask, text, num_points, num_inference_steps):
    """Gradio handler: synthesize ``text`` inside the drawn mask region.

    Args:
        image: Input PIL image.
        mask: Dict from the sketch component; its ``"mask"`` entry holds the
            drawn mask.
        text: Text to render into the masked region.
        num_points: Target vertex count for the simplified mask polygon.
        num_inference_steps: Number of denoising steps.

    Returns:
        ``(region, full, outline)`` PIL images: the generated region cropped
        to the mask's bounding box, the full image with the generated region
        blended in, and the input image with the mask outline drawn in red.

    Raises:
        gr.Error: If the image or mask is missing, or the mask is empty.
    """
    # Fail with a readable message instead of an opaque traceback.
    if image is None:
        raise gr.Error("Please provide an input image first.")
    if not isinstance(mask, dict) or mask.get("mask") is None:
        raise gr.Error("Please draw a mask on the image first.")

    # Sliders may deliver floats; the consumers expect whole numbers.
    edited_mask = edit_mask(mask["mask"], num_points=int(num_points))
    edited_mask_np = np.array(edited_mask)
    if not edited_mask_np.any():
        raise gr.Error("The mask is empty. Please draw over the region to edit.")

    # Draw the outline of the simplified mask on a copy of the input.
    img_with_outline = image.copy()
    draw = ImageDraw.Draw(img_with_outline)
    contours, _ = cv2.findContours(edited_mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        points = [tuple(pt[0]) for pt in largest_contour]
        if len(points) >= 2:
            draw.line(points + [points[0]], fill=(255, 0, 0), width=3)

    sample = dataset.get_input(image=image, mask=edited_mask, text=text)
    masked_image = sample["masked_image"].to(device)
    region_mask = sample["mask"].to(device)
    ttf_img = sample["ttf_img"].to(device)
    glyph = sample["glyph"].to(device)
    full_image = sample["full_image"]
    location = sample["location"]

    image_output, _ = pipe(
        prompt=ttf_img,
        glyph=glyph,
        masked_image=masked_image,
        mask=region_mask,
        num_inference_steps=int(num_inference_steps),
    )
    img = image_output[0]

    # Crop the generated square to the bounding box of the edit region
    # (mask value 0). Axis 0 of the mask is the channel axis.
    mask_np = region_mask.cpu().detach().numpy().astype(np.uint8)
    coords = np.column_stack(np.where(mask_np == 0))
    if coords.size > 0:
        y_min, x_min = coords[:, 1].min(), coords[:, 2].min()
        y_max, x_max = coords[:, 1].max(), coords[:, 2].max()
        cropped_output_image = img[:, y_min:y_max + 1, x_min:x_max + 1]
    else:
        cropped_output_image = img
    cropped_output_image_np = (cropped_output_image * 255).cpu().permute(1, 2, 0).numpy().astype(np.uint8)
    cropped_output_image_pil = Image.fromarray(cropped_output_image_np)

    # Blend the generated pixels back into the crop window of the full image.
    x_min, y_min, x_max, y_max = location[0], location[1], location[2], location[3]
    full_image_patch = full_image[y_min:y_max, x_min:x_max, :]
    resize_trans = transforms.Resize((full_image_patch.shape[0], full_image_patch.shape[1]))
    resize_mask = resize_trans(region_mask).cpu()
    resize_img = resize_trans(img).cpu()
    # 1 keeps the original pixel, 0 takes the generated one.
    img_mask = torch.where(resize_mask < 0.5, torch.tensor(0.0), torch.tensor(1.0))
    img_mask = img_mask.expand_as(resize_img)
    full_image_patch_tensor = transforms.ToTensor()(full_image_patch).cpu()
    full_image_patch_tensor = full_image_patch_tensor * img_mask + resize_img * (1 - img_mask)

    full_image_tensor = transforms.ToTensor()(full_image).cpu()
    full_image_tensor[:, y_min:y_max, x_min:x_max] = full_image_patch_tensor
    full_image_np = full_image_tensor.permute(1, 2, 0).numpy()
    full_image_pil = Image.fromarray((full_image_np * 255).astype(np.uint8))

    return cropped_output_image_pil, full_image_pil, img_with_outline
# Sample backgrounds returned by update_image for "Sample 1" / "Sample 2".
# They are opened at import time, so both files must exist relative to the
# working directory the app is started from.
demo_1 = Image.open("./imgs/demo_1.jpg")
demo_2 = Image.open("./imgs/demo_2.jpg")
def update_image(sample):
    """Return the demo image for a sample label, or ``None`` for any other value."""
    if sample == "Sample 1":
        return demo_1
    return demo_2 if sample == "Sample 2" else None
# Gradio UI: inputs and controls in the left column, results in the right one.
with gr.Blocks() as iface:
    gr.Markdown("# TextSSR Demo")
    gr.Markdown("Upload an image, draw a mask on the image, and enter text content for region synthesis and image editing.")
    with gr.Row():
        with gr.Column():
            sample_choice = gr.Radio(choices=["Sample 1", "Sample 2"], label="Choose a Sample Background")
            input_image = gr.Image(type="pil", label="Input Image")
            # Sketch canvas: its value is what process_image receives as `mask`.
            mask_input = gr.Image(type="pil", label="Draw Mask on Image", tool="sketch", interactive=True)
            text_input = gr.Textbox(label="Text to Synthesize / Edit")
            outlined_image = gr.Image(type="pil", label="Original Image with Mask Outline")
            with gr.Row():
                # Forwarded to process_image as num_points.
                num_points_slider = gr.Slider(
                    minimum=4,
                    maximum=20,
                    value=14,
                    step=1,
                    label="Control Points",
                    info="Adjust mask complexity (4-20 points)"
                )
                # Forwarded to process_image as num_inference_steps.
                num_steps_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=20,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )
            submit_btn = gr.Button("Process Image")
        with gr.Column():
            output_region = gr.Image(type="pil", label="Modified Region")
            output_full = gr.Image(type="pil", label="Modified Full Image")
    # Update input image based on the selected sample background
    sample_choice.change(
        update_image,
        inputs=[sample_choice],
        outputs=[input_image]
    )
    # Update mask when input image changes
    input_image.change(
        lambda image: image,  # Pass through image to mask_input
        inputs=[input_image],
        outputs=[mask_input]
    )
    # Process image when submit button is clicked; the three results map to
    # the region preview, the full edited image and the outline preview.
    submit_btn.click(
        process_image,
        inputs=[input_image, mask_input, text_input, num_points_slider, num_steps_slider],
        outputs=[output_region, output_full, outlined_image]
    )
# Start the web server; runs whenever this module is executed or imported.
iface.launch()