AnsenH's picture
Upload 14 files
dbf1d4a verified
raw
history blame
4.96 kB
import base64
import requests
from json import dumps, dump
from PIL import Image
from io import BytesIO
import time
from dotenv import load_dotenv
import os
# Populate os.environ from a .env file (python-dotenv) before reading API_TOKEN below.
load_dotenv()

# Prediction endpoint of the image-expansion / inpainting service (staging host).
endpoint = (
    'https://serving.hopter.staging.picc.co'
    '/api/v1/services/gen-ai-image-expansion/predictions'
)

# Bearer token sent with every request; os.getenv yields None when API_TOKEN is unset.
token = os.getenv('API_TOKEN')
def pil_to_b64(image: Image.Image) -> str:
    """Encode *image* as a PNG ``data:`` URL string.

    Args:
        image: PIL image to encode.

    Returns:
        ``'data:image/png;base64,'`` followed by the base64 text of the PNG bytes.
    """
    buffered = BytesIO()
    # PNG is lossless and Pillow's PNG encoder has no ``quality`` option,
    # so no quality setting is passed.
    image.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return 'data:image/png;base64,' + encoded
def b64_to_pil(b64_string):
    """Decode a base64 image string into a PIL image.

    If the string starts with ``data:image`` (a data URL such as the ones built by
    ``pil_to_b64``), everything up to and including ``;base64,`` is dropped first.
    Other strings are decoded as plain base64.

    Args:
        b64_string: base64 text, optionally prefixed with a ``data:image...;base64,`` header.

    Returns:
        The image returned by ``Image.open`` on an in-memory buffer of the decoded bytes.
    """
    payload = b64_string
    if payload.startswith('data:image'):
        payload = payload.split(';base64,', 1)[1]
    return Image.open(BytesIO(base64.b64decode(payload)))
def resize_image(image, max_height=768):
    """Return *image* resized to exactly ``max_height`` pixels tall, keeping its aspect ratio.

    Despite its name, ``max_height`` is a target rather than a cap: images shorter
    than it are scaled up as well.

    Args:
        image: object with ``width``/``height`` attributes and a ``resize(size)``
            method (a PIL image in this module).
        max_height: target height in pixels.

    Returns:
        ``image.resize((new_width, max_height))``. Here ``new_width`` is
        ``width * max_height / height`` rounded down, and never less than 1.
    """
    new_height = int(max_height)
    # Exact integer arithmetic. Going through a float scale factor
    # (max_height / height * height) can give e.g. 767.999..., which int() truncates
    # to one pixel short of the canvas height used by the rest of the pipeline.
    new_width = max(1, int(image.width * new_height // image.height))
    return image.resize((new_width, new_height))
def prepare_init_image_mask(images: "list[Image.Image]", alpha_gradient_width=80, init_image_height=768):
    """Lay *images* side by side and build inpainting inputs for every seam between them.

    Steps:
      1. Paste the images left to right onto a transparent RGBA canvas. The canvas is
         ``init_image_height`` tall and as wide as the sum of the image widths.
      2. At every seam, replace the alpha of the ``alpha_gradient_width`` columns on
         each side with a linear ramp. Alpha is 0 on the seam side, from the rotated
         ``Image.linear_gradient``.
      3. Build a mask that is black everywhere except for a white band
         ``2 * alpha_gradient_width`` wide, centred on each seam.
      4. For every seam, crop a window reaching ``init_image_height // 2`` pixels to
         each side of the seam, from both the canvas and the mask.

    Args:
        images: sequence of RGBA images, presumably already ``init_image_height``
            tall (as ``process_images_and_inpaint`` prepares them). TODO confirm for
            other callers.
        alpha_gradient_width: width in pixels of each alpha ramp; the mask band is
            twice this wide.
        init_image_height: canvas height; also sets the width of the crop window.

    Returns:
        A tuple ``(init_image, mask, init_image_mask_pair, init_image_patch_x_coord)``:
          - ``init_image_mask_pair`` holds one ``(image_crop, mask_crop)`` tuple per seam.
          - ``init_image_patch_x_coord`` holds the left x coordinate of each crop.
          - The crop's left edge is not clamped, so it can be negative when an image
            is narrower than half the window. The right edge is clamped to the canvas.
    """
    height = init_image_height
    grad_w = alpha_gradient_width

    # Left x coordinate of every image on the canvas.
    starts = []
    x = 0
    for im in images:
        starts.append(x)
        x += im.width
    total_width = x
    # A seam is where one image ends and the next begins: the start of every image but the first.
    seams = starts[1:]

    # 1. Paste the input images side by side.
    init_image = Image.new('RGBA', (total_width, height))
    for im, left in zip(images, starts):
        init_image.paste(im, (left, 0))

    # 2. Linear alpha ramps on both sides of each seam. The ramps do not depend on
    #    the seam, so they are built once.
    #    Order: right ramp of image i, then left ramp of image i+1, moving left to right.
    #    Where ramps overlap (images narrower than 2 * alpha_gradient_width), the later
    #    ramp's alpha wins.
    fade_out = Image.linear_gradient('L').rotate(-90).resize((grad_w, height))
    fade_in = Image.linear_gradient('L').rotate(90).resize((grad_w, height))
    for seam in seams:
        for left, ramp in ((seam - grad_w, fade_out), (seam, fade_in)):
            patch = init_image.crop((left, 0, left + grad_w, height))
            patch.putalpha(ramp)
            init_image.paste(patch, (left, 0))

    # 3. Inpainting mask: a white band 2 * alpha_gradient_width wide, centred on each seam.
    mask = Image.new('RGBA', (total_width, height), (0, 0, 0))
    band = Image.new('RGBA', (grad_w * 2, height), (255, 255, 255))
    for seam in seams:
        mask.paste(band, (seam - grad_w, 0))

    # 4. One (image, mask) crop per seam, reaching height // 2 pixels to each side.
    half = height // 2
    init_image_mask_pair = []
    init_image_patch_x_coord = []
    for seam in seams:
        box = (seam - half, 0, min(total_width, seam + half), height)
        init_image_mask_pair.append((init_image.crop(box), mask.crop(box)))
        init_image_patch_x_coord.append(seam - half)
    return init_image, mask, init_image_mask_pair, init_image_patch_x_coord
def attach_images_with_loc(inpainted_results, init_image_patch_x_coord, full_init_img):
    """Paste each inpainted patch back onto the full canvas at its x offset.

    Args:
        inpainted_results: patches to paste, in the same order as
            ``init_image_patch_x_coord``.
        init_image_patch_x_coord: left x coordinate of each patch. The top edge is
            always y = 0.
        full_init_img: canvas to paste onto. It is modified in place.

    Returns:
        ``full_init_img`` itself (the same object), after all pastes.
    """
    # Patches are pasted in order, so a later patch overwrites an earlier one where
    # they overlap. zip() stops at the shorter of the two sequences.
    for patch, x in zip(inpainted_results, init_image_patch_x_coord):
        full_init_img.paste(patch, (x, 0))
    return full_init_img
def inpainting_api_call(input_image, input_mask, token, endpoint, timeout=None):
    """Send one image/mask pair to the inpainting service and return the result.

    Args:
        input_image: PIL image to inpaint; sent as a PNG data URL.
        input_mask: PIL mask image; converted to single-channel ``'L'`` before sending.
        token: bearer token placed in the ``Authorization`` header.
        endpoint: URL the JSON request body is POSTed to.
        timeout: optional ``requests`` timeout in seconds. The default ``None``
            waits indefinitely, as before.

    Returns:
        The PIL image decoded from ``output.inpainted_image_b64`` in the JSON response.

    Raises:
        requests.HTTPError: if the service responds with a 4xx/5xx status.
    """
    body = {
        "input": {
            "initial_image_b64": pil_to_b64(input_image),
            "mask_image_b64": pil_to_b64(input_mask.convert('L')),
        }
    }
    # perf_counter is monotonic, so the reported duration is unaffected by wall-clock changes.
    start = time.perf_counter()
    resp_inpaint = requests.post(
        endpoint,
        data=dumps(body),
        headers={"Authorization": f"Bearer {token}"},
        timeout=timeout,
    )
    print(f"Execution time: {time.perf_counter() - start}")
    # Report the HTTP status instead of failing with an opaque KeyError/JSON error below.
    resp_inpaint.raise_for_status()
    return b64_to_pil(resp_inpaint.json()['output']['inpainted_image_b64'])
def process_images_and_inpaint(images, alpha_gradient_width=100, init_image_height=768):
    """Stitch base64 images side by side and inpaint the seams between them.

    Args:
        images: iterable of base64 image strings (optionally ``data:image`` URLs),
            decoded with ``b64_to_pil``.
        alpha_gradient_width: width in pixels of the alpha ramps and half-width of the
            mask band, forwarded to ``prepare_init_image_mask``.
        init_image_height: height every image is resized to, and the canvas height.

    Returns:
        The stitched image, with every seam patch replaced by the service's inpainted
        result, as a PNG data URL string.

    Notes:
        Uses the module-level ``token`` and ``endpoint``. Makes one blocking API
        request per seam, one after another.
    """
    # Resize to init_image_height (not resize_image's default) so each image fills the
    # full height of the canvas that prepare_init_image_mask allocates.
    pil_images = [
        resize_image(b64_to_pil(b64), max_height=init_image_height).convert("RGBA")
        for b64 in images
    ]
    full_init_img, _full_mask, patch_pairs, patch_x_coords = prepare_init_image_mask(
        pil_images, alpha_gradient_width, init_image_height
    )
    results = [inpainting_api_call(patch, mask, token, endpoint) for patch, mask in patch_pairs]
    return pil_to_b64(attach_images_with_loc(results, patch_x_coords, full_init_img))