File size: 4,962 Bytes
dbf1d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import base64
import requests
from json import dumps, dump
from PIL import Image
from io import BytesIO
import time
from dotenv import load_dotenv
import os

# Load variables from a .env file (if one is found) into the process
# environment before anything below reads it.
load_dotenv()

# Prediction URL of the image-expansion service (staging host).
endpoint = 'https://serving.hopter.staging.picc.co/api/v1/services/gen-ai-image-expansion/predictions'

# Bearer token sent with every prediction request; None when API_TOKEN is unset.
token = os.environ.get('API_TOKEN')

def pil_to_b64(image:Image.Image) -> str:
    """Encode ``image`` as a PNG data URL.

    Args:
        image: PIL image to serialise.

    Returns:
        The string ``data:image/png;base64,`` followed by the base64 text of
        the PNG bytes.
    """
    png_buffer = BytesIO()
    # NOTE(review): `quality` is not among Pillow's documented PNG save
    # options, so it looks like a no-op here - confirm before relying on it.
    image.save(png_buffer, format="PNG", quality=80)
    encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
    return f'data:image/png;base64,{encoded}'

def b64_to_pil(b64_string):
    """Decode a base64 string (optionally a data URL) into a PIL image.

    Args:
        b64_string: base64 text of an encoded image. If it starts with
            ``data:image`` everything up to and including ``;base64,`` is
            discarded first.

    Returns:
        The image opened by ``Image.open`` over an in-memory buffer of the
        decoded bytes.
    """
    payload = b64_string
    if payload.startswith('data:image'):
        # Keep only what follows the first ";base64," marker
        payload = payload.split(';base64,', 1)[1]
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes))

def resize_image(image, max_height=768):
    """Scale ``image`` so that its height is exactly ``max_height``.

    Despite the parameter name this is not only an upper bound: images
    shorter than ``max_height`` are scaled up as well. The aspect ratio is
    kept; the width is rounded down but never below one pixel.

    Args:
        image: object exposing ``width``, ``height`` and ``resize((w, h))``
            (a PIL image in this module).
        max_height: target height in pixels.

    Returns:
        Whatever ``image.resize`` returns for the computed ``(width, height)``.
    """
    target_height = int(max_height)
    # Integer arithmetic on purpose: going through a float scale factor
    # (height * (max_height / height)) can land a hair below max_height and
    # truncate to max_height - 1, e.g. for a 1250 px tall source.
    target_width = int(image.width * max_height // image.height)
    # A zero width would be rejected by the resize call, so keep one column.
    target_width = max(1, target_width)
    return image.resize((target_width, target_height))

def prepare_init_image_mask(images: [Image.Image], alpha_gradient_width=80, init_image_height=768):  # type: ignore
    """Build the side-by-side canvas, its inpainting mask and per-seam crops.

    The images are pasted left to right onto one RGBA canvas. Around every
    seam between two neighbours, the alpha channel of the
    ``alpha_gradient_width`` pixels on each side is replaced by a linear
    gradient (per Pillow's ``linear_gradient``/``rotate`` semantics this
    should fade towards transparent at the seam - confirm visually). The mask
    is black with a white band of ``2 * alpha_gradient_width`` pixels centred
    on every seam.

    Args:
        images: sequence of PIL images to lay out, in order.
        alpha_gradient_width: width in pixels of the gradient strip on each
            side of a seam.
        init_image_height: height of the canvas, mask and crops.

    Returns:
        A 4-tuple ``(init_image, mask, init_image_mask_pair,
        init_image_patch_x_coord)``: the full canvas, the full mask, a list
        of ``(canvas_crop, mask_crop)`` tuples (one per seam) and the list of
        x positions where each crop starts on the canvas. The crop window
        spans ``init_image_height // 2`` pixels to either side of the seam;
        its right edge is clipped to the canvas width, its left edge is not
        clamped and can therefore be negative.
    """
    canvas_width = sum(picture.width for picture in images)
    canvas_size = (canvas_width, init_image_height)
    init_image = Image.new('RGBA', canvas_size)

    # Lay the inputs out left to right and remember where each one ends
    right_edges = []
    cursor = 0
    for picture in images:
        init_image.paste(picture, (cursor, 0))
        cursor += picture.width
        right_edges.append(cursor)
    # Every right edge except the last one is a seam between two neighbours
    seams = right_edges[:-1]

    # Replace the alpha channel on both sides of every seam with a gradient
    strip_size = (alpha_gradient_width, init_image_height)
    for seam in seams:
        # Strip ending at the seam (right border of the image on the left)
        left_box = (seam - alpha_gradient_width, 0, seam, init_image_height)
        gradient = Image.linear_gradient('L').rotate(-90).resize(strip_size)
        strip = init_image.crop(left_box)
        strip.putalpha(gradient)
        init_image.paste(strip, (left_box[0], 0))

        # Strip starting at the seam (left border of the image on the right)
        right_box = (seam, 0, seam + alpha_gradient_width, init_image_height)
        gradient = Image.linear_gradient('L').rotate(90).resize(strip_size)
        strip = init_image.crop(right_box)
        strip.putalpha(gradient)
        init_image.paste(strip, (right_box[0], 0))

    # Inpainting mask: white band across each seam on a black background
    mask = Image.new('RGBA', canvas_size, (0, 0, 0))
    band_size = (alpha_gradient_width*2, init_image_height)
    for seam in seams:
        white_band = Image.new('RGBA', band_size, (255, 255, 255))
        mask.paste(white_band, (seam - alpha_gradient_width, 0))

    # Cut one window per seam out of both the canvas and the mask
    half_window = init_image_height // 2
    init_image_mask_pair = []
    init_image_patch_x_coord = []
    for seam in seams:
        window_left = seam - half_window
        window = (window_left, 0, min(canvas_width, seam + half_window), init_image_height)
        canvas_crop = init_image.crop(window)
        mask_crop = mask.crop(window)
        init_image_mask_pair.append((canvas_crop, mask_crop))
        init_image_patch_x_coord.append(window_left)

    return init_image, mask, init_image_mask_pair, init_image_patch_x_coord

def attach_images_with_loc(inpainted_results, init_image_patch_x_coord, full_init_img):
    """Paste each inpainted patch back onto the full canvas.

    Args:
        inpainted_results: patches to paste, in order.
        init_image_patch_x_coord: x position for each patch; patches are
            always placed at y = 0. Pairing stops at the shorter of the two
            sequences.
        full_init_img: canvas to paste onto. It is modified in place.

    Returns:
        The same ``full_init_img`` object after all pastes.
    """
    placements = zip(inpainted_results, init_image_patch_x_coord)
    for patch, x_offset in placements:
        full_init_img.paste(patch, (x_offset, 0))
    return full_init_img


def inpainting_api_call(input_image, input_mask, token, endpoint, timeout=None):
    """Send one image/mask pair to the inpainting service and return the result.

    Args:
        input_image: PIL image to inpaint; sent as a PNG data URL.
        input_mask: PIL mask image; converted to mode ``L`` before encoding.
        token: bearer token placed in the ``Authorization`` header.
        endpoint: URL the request is POSTed to.
        timeout: forwarded to ``requests.post``. ``None`` (the default) waits
            indefinitely, as before; pass seconds (or a ``(connect, read)``
            tuple) to bound the call.

    Returns:
        The PIL image decoded from ``output.inpainted_image_b64`` in the JSON
        response.

    Raises:
        requests.HTTPError: if the service answers with a 4xx/5xx status.
        requests.RequestException: on connection problems or timeouts.
    """
    body = {
        "input": {
            "initial_image_b64": pil_to_b64(input_image),
            "mask_image_b64": pil_to_b64(input_mask.convert('L')),
        }
    }

    json_data = dumps(body)
    start = time.time()
    resp_inpaint = requests.post(
        endpoint,
        data=json_data,
        headers={"Authorization": f"Bearer {token}"},
        timeout=timeout,
    )
    print(f"Execution time: {time.time() - start}")
    # Fail with the real HTTP error instead of an opaque KeyError/JSON error
    # when the service rejects the request or is down.
    resp_inpaint.raise_for_status()
    return b64_to_pil(resp_inpaint.json()['output']['inpainted_image_b64'])

def process_images_and_inpaint(images, alpha_gradient_width=100, init_image_height=768):
    """Stitch base64 images side by side and inpaint the seams between them.

    Each input is decoded, resized to ``init_image_height`` and converted to
    RGBA; the per-seam crops produced by ``prepare_init_image_mask`` are sent
    one by one to the inpainting service (using the module-level ``token``
    and ``endpoint``) and the results are pasted back onto the full canvas.

    Args:
        images: iterable of base64-encoded images, with or without a
            ``data:image...;base64,`` prefix.
        alpha_gradient_width: width in pixels of the alpha gradient applied
            on each side of a seam.
        init_image_height: height in pixels of the working canvas.

    Returns:
        The stitched result as a PNG data URL string.
    """
    # Resize to the canvas height actually requested; relying on
    # resize_image's own default would leave the images and the canvas with
    # different heights whenever init_image_height is not 768.
    prepared = [
        resize_image(b64_to_pil(encoded), init_image_height).convert("RGBA")
        for encoded in images
    ]
    full_init_img, _full_mask, patch_pairs, patch_x_coords = prepare_init_image_mask(
        prepared, alpha_gradient_width, init_image_height
    )
    # One service call per seam, in seam order
    results = [
        inpainting_api_call(patch, patch_mask, token, endpoint)
        for patch, patch_mask in patch_pairs
    ]
    stitched = attach_images_with_loc(results, patch_x_coords, full_init_img)
    return pil_to_b64(stitched)