|
|
import torch |
|
|
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import ImageDraw, ImageOps, Image, ImageFont |
|
|
from scipy.spatial import KDTree |
|
|
from webcolors import CSS3_HEX_TO_NAMES, hex_to_rgb, hex_to_name |
|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler: edit an image with InstructPix2Pix and lay it out in a template.

    The produced template is a fixed-size white canvas holding, top to bottom:
    the logo stacked above the edited image (with rounded corners), a centered
    word-wrapped punchline, and a rounded call-to-action button.
    """

    # Layout constants shared by create_template and concat_template_with_button.
    _TEMPLATE_SIZE = (645, 645)   # (width, height) of the white canvas
    _CONTENT_SIZE = (325, 425)    # logo + image block is resized to this (width, height)
    _CONTENT_TOP = 10             # y offset of the logo + image block on the canvas

    # Used when the request does not provide a button label (matches create_image_template's default).
    _DEFAULT_BUTTON_TEXT = "Call Action Text Here! >"

    # Overrides applied to the nearest CSS3 color name before it goes into the prompt.
    _COLOR_NAME_OVERRIDES = {
        'tomato': 'red',
        'chocolate': 'brown-black',
        'darkgoldenrod': 'yellow',
    }

    # Lazily built (names, KDTree) over the CSS3 palette; shared by all instances.
    _css3_lookup = None

    def __init__(self, path=""):
        """Load the InstructPix2Pix pipeline onto the GPU.

        Args:
            path: Model directory supplied by the endpoint runtime. Currently
                unused: the model is always loaded from the
                ``timbrooks/instruct-pix2pix`` repository.
        """
        model_id = "timbrooks/instruct-pix2pix"
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, safety_checker=None
        )
        self.pipe.to("cuda")
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)

    # ------------------------------------------------------------------ #
    # Request parsing helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _get_param(info, data, key, default=None):
        """Pop ``key`` from ``info``, falling back to ``data[key]`` and then ``default``.

        A value explicitly set to ``None`` is treated as missing.
        """
        value = info.pop(key, None)
        if value is None:
            value = data.get(key)
        return default if value is None else value

    @classmethod
    def _require_param(cls, info, data, key):
        """Like ``_get_param`` but raise ``ValueError`` when the parameter is absent."""
        value = cls._get_param(info, data, key)
        if value is None:
            raise ValueError(f"Missing required input parameter: '{key}'")
        return value

    @staticmethod
    def _hex_to_rgb_tuple(hex_color):
        """Parse a hex color (``#RRGGBB``, ``RRGGBB``, ``#RGB``; an ``AA`` alpha suffix is ignored).

        Returns:
            An ``(r, g, b)`` tuple of ints in 0..255.

        Raises:
            ValueError: if ``hex_color`` is not a valid hex color.
        """
        digits = str(hex_color).strip().lstrip("#")
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) not in (6, 8):
            raise ValueError(f"Invalid hex color: {hex_color!r}")
        try:
            return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError as err:
            raise ValueError(f"Invalid hex color: {hex_color!r}") from err

    @staticmethod
    def _load_image(source):
        """Return ``source`` if it is already a PIL image, else open it and return an in-memory copy.

        The copy is taken inside ``with`` so the underlying file handle is closed.
        """
        if isinstance(source, Image.Image):
            return source
        with Image.open(source) as opened:
            return opened.copy()

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload. Each parameter is looked up in ``data['inputs']``
                first and then at the top level of ``data``.
                Required: ``image`` and ``logo`` (base64-encoded image bytes),
                ``prompt`` (edit instruction), ``color_code`` (hex color).
                Optional (defaults): ``seed`` (unseeded), ``num_inference_steps`` (40),
                ``guidance_scale`` (7.5), ``image_guidance_scale`` (1.5),
                ``punchline_text`` ("Punchline Text"), ``punchline_text_max_width`` (550),
                ``punchline_text_color`` ("#008000"), ``spacing_image_text`` (0),
                ``button_color`` ("#008000"), ``button_text``, ``button_font``,
                ``button_font_scale`` (0.75), ``button_font_thickness`` (1),
                ``button_text_color`` ("#FFFFFF"),
                ``spacing_between_punchline_and_button`` (10).

        Returns:
            ``{'image': <base64-encoded PNG of the finished template>}``.

        Raises:
            ValueError: if a required parameter is missing or a color is malformed.
        """
        # Copy so popping parameters does not mutate the caller's payload.
        info = dict(data['inputs'])

        image_b64 = self._require_param(info, data, "image")
        logo_b64 = self._require_param(info, data, "logo")
        text_prompt = self._require_param(info, data, "prompt")
        color_code = self._require_param(info, data, "color_code")

        seed = self._get_param(info, data, "seed")
        num_inference_steps = self._get_param(info, data, "num_inference_steps", 40)
        guidance_scale = self._get_param(info, data, "guidance_scale", 7.5)
        # 1.5 is the pipeline's own default for image_guidance_scale.
        image_guidance_scale = self._get_param(info, data, "image_guidance_scale", 1.5)

        punchline_text = self._get_param(info, data, "punchline_text", "Punchline Text")
        punchline_text_max_width = self._get_param(info, data, "punchline_text_max_width", 550)
        punchline_text_color = self._get_param(info, data, "punchline_text_color", "#008000")
        spacing_image_text = self._get_param(info, data, "spacing_image_text", 0)

        button_color = self._get_param(info, data, "button_color", "#008000")
        button_text = self._get_param(info, data, "button_text", self._DEFAULT_BUTTON_TEXT)
        button_font = self._get_param(info, data, "button_font", cv2.FONT_HERSHEY_TRIPLEX)
        button_font_scale = self._get_param(info, data, "button_font_scale", 0.75)
        button_font_thickness = self._get_param(info, data, "button_font_thickness", 1)
        button_text_color = self._get_param(info, data, "button_text_color", "#FFFFFF")
        spacing_between_punchline_and_button = self._get_param(
            info, data, "spacing_between_punchline_and_button", 10
        )

        raw_image = Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB').resize((512, 512))
        logo = Image.open(BytesIO(base64.b64decode(logo_b64))).convert('RGB')

        result_prompt, negative_prompt = self.build_prompt(text_prompt, color_code)

        if seed is not None:
            torch.manual_seed(int(seed))
        edited_image = self.pipe(
            result_prompt,
            negative_prompt=negative_prompt,
            image=raw_image,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            image_guidance_scale=float(image_guidance_scale),
        ).images[0]

        # Images are passed in memory; shared files like ./1.png would race between concurrent requests.
        resulting_template = self.create_image_template(
            base_image_path=edited_image,
            logo_path=logo,
            punchline_text=punchline_text,
            punchline_text_color=punchline_text_color,
            punchline_text_max_width=punchline_text_max_width,
            spacing_image_text=spacing_image_text,
            button_color=button_color,
            button_text=button_text,
            button_font=button_font,
            button_font_scale=button_font_scale,
            button_font_thickness=button_font_thickness,
            button_text_color=button_text_color,
            spacing_between_punchline_and_button=spacing_between_punchline_and_button,
            corner_radius=30,
        )

        buffer = BytesIO()
        resulting_template.save(buffer, format="PNG")
        encoded_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return {'image': encoded_string}

    # ------------------------------------------------------------------ #
    # Prompt construction
    # ------------------------------------------------------------------ #

    def build_prompt(self, text_prompt, color_code):
        """Build the positive and negative prompts for the pipeline.

        Args:
            text_prompt: User edit instruction.
            color_code: Hex color; mapped to a CSS3 color name via ``hex_to_name``.

        Returns:
            ``(result_prompt, negative_prompt)`` strings.
        """
        color_name = self.hex_to_name(color_code)

        q_prompt = "No blur, high quality, no distortion over objects, text remains clear and undistorted. Do not modify or alter the logo."
        # Join with a space: plain concatenation fused the last word of the user
        # prompt with the quality prompt (e.g. "make it blueNo blur").
        base_prompts = " ".join(part for part in (str(text_prompt).strip(), q_prompt) if part)

        coloring_prompt = f" with a {color_name} color applied to only the designated area or key element in the picture by avoiding it becoming the dominant color of the image, leaving the text, logo, and shadows untouched."
        result_prompt = f"{base_prompts}{coloring_prompt}"

        negative_prompt = f'{color_name} shadows, worst quality, low quality, low res, blurry, watermark, cropped, jpeg artifacts, error, sketch ,duplicate, ugly, monochrome, horror, mutation, disgusting'

        return result_prompt, negative_prompt

    @classmethod
    def _css3_color_lookup(cls):
        """Return ``(names, KDTree)`` over the CSS3 palette, building it once per process."""
        if cls._css3_lookup is None:
            names = list(CSS3_HEX_TO_NAMES.values())
            rgb_values = [hex_to_rgb(color_hex) for color_hex in CSS3_HEX_TO_NAMES]
            cls._css3_lookup = (names, KDTree(rgb_values))
        return cls._css3_lookup

    def hex_to_name(self, hex_color):
        """Return the CSS3 color name nearest (Euclidean RGB distance) to ``hex_color``.

        A few names are replaced by simpler words via ``_COLOR_NAME_OVERRIDES``.

        Raises:
            ValueError: if ``hex_color`` is not a valid hex color.
        """
        names, kdt_db = self._css3_color_lookup()
        _, index = kdt_db.query(self._hex_to_rgb_tuple(hex_color))
        color_name = names[index]
        return self._COLOR_NAME_OVERRIDES.get(color_name, color_name)

    # ------------------------------------------------------------------ #
    # Drawing helpers
    # ------------------------------------------------------------------ #

    def draw_text(self, img, text, font=cv2.FONT_HERSHEY_TRIPLEX, pos=(20, 45), font_scale=1, font_thickness=1, text_color=(0, 0, 255)):
        """Draw ``text`` in place on the writable numpy image ``img``.

        ``pos`` is treated as the top-left corner: the baseline is placed
        ``text_h`` pixels below ``pos[1]``.

        Returns:
            ``(text_w, text_h)`` as reported by ``cv2.getTextSize``.
        """
        x, y = pos
        text_size, _ = cv2.getTextSize(text, font, font_scale, font_thickness)
        _, text_h = text_size
        cv2.putText(img, text, (x, y + text_h), font, font_scale, text_color, font_thickness)
        return text_size

    def smooth_corners(self, image, radius, alpha=255):
        """Return an RGBA copy of ``image`` with rounded, anti-aliased corners.

        The mask is drawn at 5x resolution and downscaled with LANCZOS to
        smooth the curved edges. Pixels outside the rounded rectangle are
        fully transparent.

        Args:
            image: PIL image to round.
            radius: Corner radius in pixels of ``image``.
            alpha: Opacity (0-255) of the kept area.
        """
        factor = 5
        radius = radius * factor
        size = (image.width, image.height)
        mx, my = size[0] * factor, size[1] * factor

        mask = Image.new('RGBA', (mx, my), (255, 255, 255, 0))

        # Quarter disc for the top-left corner; rotated copies serve the other three.
        # Uses ``alpha`` (not ``alpha + 55``) so corners match the body's opacity.
        corner = Image.new('RGBA', (radius, radius), (0, 0, 0, 0))
        ImageDraw.Draw(corner).pieslice((0, 0, radius * 2, radius * 2), 180, 270, fill=(0, 0, 0, alpha))

        placements = (
            (0, (0, 0)),
            (90, (0, my - radius)),
            (180, (mx - radius, my - radius)),
            (270, (mx - radius, 0)),
        )
        for angle, position in placements:
            rotated = corner.rotate(angle)
            mask.paste(rotated, position, rotated)

        # Two overlapping rectangles fill everything except the corner squares.
        draw = ImageDraw.Draw(mask)
        draw.rectangle([(radius, 0), (mx - radius, my)], fill=(0, 0, 0, alpha))
        draw.rectangle([(0, radius), (mx, my - radius)], fill=(0, 0, 0, alpha))

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        mask = mask.resize(size, Image.LANCZOS)

        result_image = Image.new('RGBA', size)
        result_image.paste(image, (0, 0), mask)
        return result_image

    def add_logo_to_image(self, base_image_path, logo_path, corner_radius=30):
        """Round the base image's corners and stack the logo above it on a white band.

        Args:
            base_image_path: Path or file object of the base image, or a PIL image.
            logo_path: Path or file object of the logo, or a PIL image.
            corner_radius: Corner radius in pixels applied to the base image.

        Returns:
            RGB image: a white band holding the centered logo (scaled to a third
            of the base image height, or narrower logos for very wide inputs),
            with the rounded base image below it.
        """
        base_image = self._load_image(base_image_path)
        logo = self._load_image(logo_path)
        image_width, image_height = base_image.size

        smoothed_image = self.smooth_corners(base_image, radius=corner_radius)

        # Composite over white so the cut-off corners show white rather than black.
        result_image = Image.new("RGBA", base_image.size)
        result_image.paste(Image.new("RGB", base_image.size, "white"), (0, 0))
        result_image.paste(smoothed_image, (0, 0), smoothed_image)

        # Scale the logo to a third of the base image height, keeping its aspect ratio.
        desired_height = image_height / 3
        scaling_factor = desired_height / logo.height
        new_width = max(1, int(logo.width * scaling_factor))
        new_height = int(desired_height)
        if new_width > image_width:
            # Very wide logo: fit it to the image width instead of cropping its sides.
            scaling_factor = image_width / logo.width
            new_width = image_width
            new_height = max(1, int(logo.height * scaling_factor))
        resized_logo = logo.resize((new_width, new_height))

        logo_band = Image.new("RGB", (image_width, resized_logo.height), "white")
        x_logo = (image_width - resized_logo.width) // 2
        # Respect transparency when the logo carries an alpha channel.
        logo_mask = resized_logo if resized_logo.mode in ("RGBA", "LA") else None
        logo_band.paste(resized_logo, (x_logo, 0), logo_mask)

        result_logo_image = Image.new("RGB", (image_width, image_height + logo_band.height))
        result_logo_image.paste(logo_band, (0, 0))
        result_logo_image.paste(result_image, (0, logo_band.height))
        return result_logo_image

    @staticmethod
    def _wrap_text(text, measure, max_width):
        """Greedily split ``text`` into lines whose measured width stays within ``max_width``.

        ``measure(s)`` must return ``(width, height)``. A single word wider than
        ``max_width`` gets a line of its own. Always returns at least one line
        (``[""]`` for empty text).
        """
        lines = []
        current_line = ""
        for word in text.split():
            candidate = f"{current_line} {word}" if current_line else word
            if current_line and measure(candidate)[0] > max_width:
                lines.append(current_line)
                current_line = word
            else:
                current_line = candidate
        lines.append(current_line)
        return lines

    def create_template(self, logo_and_image, punchline_text, punchline_text_color, punchline_text_max_width, spacing_image_text):
        """Place the logo+image block on the white canvas and draw the punchline below it.

        Args:
            logo_and_image: PIL image from ``add_logo_to_image``; resized to ``_CONTENT_SIZE``.
            punchline_text: Text to wrap and center.
            punchline_text_color: Color tuple, in the channel order of the RGB canvas.
            punchline_text_max_width: Maximum pixel width of a wrapped line.
            spacing_image_text: Extra vertical gap between the image block and the text.

        Returns:
            ``(template, punchline_text_height, line_count)`` where
            ``punchline_text_height`` is ``line_count`` times the height of the
            last drawn line.
        """
        content = logo_and_image.resize(self._CONTENT_SIZE)
        canvas = Image.new("RGB", self._TEMPLATE_SIZE, "white")
        y_content = self._CONTENT_TOP
        canvas.paste(content, ((canvas.width - content.width) // 2, y_content))

        font = cv2.FONT_HERSHEY_TRIPLEX
        font_scale = 1
        font_thickness = 2

        def measure(text):
            return cv2.getTextSize(text, font, font_scale, font_thickness)[0]

        lines = self._wrap_text(punchline_text, measure, punchline_text_max_width)

        # Baseline of the first line: below the content block, shifted by spacing_image_text.
        y = y_content + spacing_image_text + 2 * measure(punchline_text)[1] + content.height

        # np.array (a writable copy), not np.asarray: cv2.putText draws in place
        # and rejects the read-only array np.asarray returns for PIL images.
        canvas_np = np.array(canvas)
        line_size = (0, 0)
        for line in lines:
            line_size = measure(line)
            x = (canvas.width - line_size[0]) // 2
            cv2.putText(canvas_np, line, (int(x), int(y)), font, font_scale, punchline_text_color, thickness=font_thickness)
            y += 1.5 * line_size[1]

        punchline_text_height = len(lines) * line_size[1]
        return Image.fromarray(canvas_np), punchline_text_height, len(lines)

    def draw_button_with_text(self, button_color, button_text, button_font=cv2.FONT_HERSHEY_TRIPLEX, button_font_scale = 0.7, button_font_thickness=1, button_text_color=(0, 0, 255)):
        """Render a 300x80 rounded button filled with ``button_color`` and centered ``button_text``.

        Returns:
            RGB PIL image of the button (corners white).
        """
        background_image = Image.new("RGB", (300, 80), "white")
        rounded = self.smooth_corners(background_image, 20, alpha=255)

        # Grayscale: body is white (255), cut-off corners are black (0);
        # colorize maps black -> white and white -> button_color.
        button_image_pil = ImageOps.colorize(rounded.convert("L"), "white", button_color)

        button_width, button_height = button_image_pil.size
        (text_w, text_h), _ = cv2.getTextSize(button_text, button_font, button_font_scale, button_font_thickness)
        text_x = (button_width - text_w) // 2 if text_w <= button_width else 0
        text_y = (button_height - text_h) // 2

        # Writable copy: draw_text draws in place via cv2.putText.
        button_image_np = np.array(button_image_pil)
        self.draw_text(button_image_np, button_text, button_font, (text_x, text_y), button_font_scale, button_font_thickness, button_text_color)
        return Image.fromarray(button_image_np)

    def concat_template_with_button(self, template, button_image, line_count, spacing_between_punchline_and_button, punchline_text_height, spacing_image_text=0):
        """Paste ``button_image`` horizontally centered below the punchline.

        Args:
            template: Canvas from ``create_template``.
            button_image: Button from ``draw_button_with_text``.
            line_count: Number of punchline lines.
            spacing_between_punchline_and_button: Extra gap above the button.
            punchline_text_height: Height value returned by ``create_template``.
            spacing_image_text: Gap used in ``create_template`` between image and
                text; the button is shifted by the same amount so it does not
                overlap the punchline.

        Returns:
            A copy of ``template`` with the button pasted.
        """
        result_template = template.copy()
        button_width = button_image.width
        button_x = (template.width - button_width) // 2

        # Empirical layout: one 33px gap per line after the first, but at least one gap.
        gap_count = line_count - 1 if line_count > 1 else 1
        button_y = (
            self._CONTENT_TOP
            + self._CONTENT_SIZE[1]
            + spacing_image_text
            + punchline_text_height
            + gap_count * 33
            + spacing_between_punchline_and_button
        )

        result_template.paste(button_image, (button_x, int(button_y)))
        return result_template

    def create_image_template(self,
                              base_image_path,
                              logo_path,
                              punchline_text,
                              punchline_text_color="#008000",
                              punchline_text_max_width=550,
                              spacing_image_text=0,
                              button_color="#008000",
                              button_text="Call Action Text Here! >",
                              button_font=cv2.FONT_HERSHEY_TRIPLEX,
                              button_font_scale=0.7,
                              button_font_thickness=1,
                              button_text_color="#FFFFFF",
                              spacing_between_punchline_and_button=10,
                              corner_radius=30):
        """Compose logo, rounded image, punchline and button into the final template.

        ``base_image_path`` and ``logo_path`` may be paths, file objects, or PIL
        images. Colors are hex strings (see ``_hex_to_rgb_tuple``).

        Returns:
            The finished RGB template as a PIL image.

        Raises:
            ValueError: if a color string is malformed.
        """
        punchline_rgb = self._hex_to_rgb_tuple(punchline_text_color)
        button_rgb = self._hex_to_rgb_tuple(button_color)
        button_text_rgb = self._hex_to_rgb_tuple(button_text_color)

        result_image = self.add_logo_to_image(base_image_path, logo_path, corner_radius=corner_radius)

        template, punchline_text_height, line_count = self.create_template(
            result_image,
            punchline_text,
            punchline_rgb,
            punchline_text_max_width,
            spacing_image_text,
        )

        button_image = self.draw_button_with_text(
            button_rgb,
            button_text,
            button_font,
            button_font_scale,
            button_font_thickness,
            button_text_rgb,
        )

        return self.concat_template_with_button(
            template,
            button_image,
            line_count,
            spacing_between_punchline_and_button,
            punchline_text_height,
            spacing_image_text=spacing_image_text,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: constructing the handler loads the pipeline onto CUDA.
    handler = EndpointHandler(path='.')
|
|
|