Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- .gitattributes +4 -0
- README.md +8 -8
- backend/main.py +23 -0
- backend/models.py +17 -0
- backend/utils.py +120 -0
- frontend/app.py +71 -0
- frontend/example/img_1.jpg +0 -0
- frontend/example/img_2.jpg +3 -0
- frontend/example/img_3.jpg +3 -0
- frontend/example/img_4.jpg +3 -0
- frontend/example/output.png +3 -0
- frontend/utils.py +119 -0
- requirements.txt +8 -0
- start_server.sh +2 -0
- start_web_app.sh +2 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
frontend/example/img_2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
frontend/example/img_3.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
frontend/example/img_4.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
frontend/example/output.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
---
|
| 2 |
title: Memory Carousel
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
|
|
|
|
| 9 |
pinned: false
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: Memory Carousel
|
| 3 |
+
emoji: 🎞️
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.34.0
|
| 8 |
+
python_version: 3.11.3
|
| 9 |
+
app_file: frontend/app.py
|
| 10 |
pinned: false
|
| 11 |
+
license: apache-2.0
|
| 12 |
+
---
|
|
|
backend/main.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from models import ImageInpaintingRequest, ImageInpaintingResponse
|
| 3 |
+
from utils import process_images_and_inpaint
|
| 4 |
+
from typing import List
|
| 5 |
+
import uvicorn
|
| 6 |
+
|
| 7 |
+
app = FastAPI()
|
| 8 |
+
|
| 9 |
+
@app.get("/")
|
| 10 |
+
def hello():
|
| 11 |
+
return {"message":"Test FastAPI"}
|
| 12 |
+
|
| 13 |
+
@app.post("/inpaint", response_model=ImageInpaintingResponse)
|
| 14 |
+
async def inpaint_images(request: ImageInpaintingRequest):
|
| 15 |
+
try:
|
| 16 |
+
inpainted_image_b64 = process_images_and_inpaint(request.images, request.alpha_gradient_width, request.init_image_height)
|
| 17 |
+
return ImageInpaintingResponse(inpainted_image=inpainted_image_b64)
|
| 18 |
+
except Exception as e:
|
| 19 |
+
print(e)
|
| 20 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
backend/models.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ImageInpaintingRequest(BaseModel):
|
| 6 |
+
images: List[str]
|
| 7 |
+
alpha_gradient_width: int
|
| 8 |
+
init_image_height: int
|
| 9 |
+
|
| 10 |
+
class Config:
|
| 11 |
+
arbitrary_types_allowed = True
|
| 12 |
+
|
| 13 |
+
class ImageInpaintingResponse(BaseModel):
|
| 14 |
+
inpainted_image: str
|
| 15 |
+
|
| 16 |
+
class Config:
|
| 17 |
+
arbitrary_types_allowed = True
|
backend/utils.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import requests
|
| 3 |
+
from json import dumps, dump
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
import time
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
endpoint = 'https://serving.hopter.staging.picc.co/api/v1/services/gen-ai-image-expansion/predictions'
|
| 12 |
+
token = os.getenv('API_TOKEN')
|
| 13 |
+
|
| 14 |
+
def pil_to_b64(image:Image.Image) -> str:
|
| 15 |
+
buffered = BytesIO()
|
| 16 |
+
image.save(buffered, format="PNG", quality=80)
|
| 17 |
+
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 18 |
+
prefix = 'data:image/png;base64,'
|
| 19 |
+
return prefix + img_str
|
| 20 |
+
|
| 21 |
+
def b64_to_pil(b64_string):
|
| 22 |
+
# Remove the Base64 prefix if present
|
| 23 |
+
if b64_string.startswith('data:image'):
|
| 24 |
+
b64_string = b64_string.split(';base64,', 1)[1]
|
| 25 |
+
# Decode the Base64 string to bytes
|
| 26 |
+
image_bytes = base64.b64decode(b64_string)
|
| 27 |
+
# Create a BytesIO object and load the image bytes
|
| 28 |
+
image_buffer = BytesIO(image_bytes)
|
| 29 |
+
image = Image.open(image_buffer)
|
| 30 |
+
return image
|
| 31 |
+
|
| 32 |
+
def resize_image(image, max_height=768):
|
| 33 |
+
scale = max_height/image.height
|
| 34 |
+
return image.resize((int(image.width * scale), int(image.height * scale)))
|
| 35 |
+
|
| 36 |
+
def prepare_init_image_mask(images: [Image.Image], alpha_gradient_width=80, init_image_height=768): # type: ignore
|
| 37 |
+
total_width = sum([ im.width for im in images])
|
| 38 |
+
init_image = Image.new('RGBA', (total_width,init_image_height))
|
| 39 |
+
|
| 40 |
+
# Paste input images on init_image
|
| 41 |
+
x_coord = 0
|
| 42 |
+
for im in images:
|
| 43 |
+
init_image.paste(im, (x_coord, 0))
|
| 44 |
+
x_coord += im.width
|
| 45 |
+
|
| 46 |
+
# Add linear alpha gradient
|
| 47 |
+
x_coord = 0
|
| 48 |
+
is_right_patch = True
|
| 49 |
+
i = 0
|
| 50 |
+
while i <= len(images) - 1:
|
| 51 |
+
im = images[i]
|
| 52 |
+
if i == len(images) - 1 and is_right_patch:
|
| 53 |
+
break
|
| 54 |
+
if is_right_patch:
|
| 55 |
+
alpha = Image.linear_gradient('L').rotate(-90).resize((alpha_gradient_width, init_image_height))
|
| 56 |
+
tmp_img = init_image.crop((x_coord+im.width - alpha_gradient_width, 0, x_coord+im.width, init_image_height))
|
| 57 |
+
tmp_img.putalpha(alpha)
|
| 58 |
+
init_image.paste(tmp_img, (x_coord+im.width - alpha_gradient_width, 0))
|
| 59 |
+
x_coord += im.width
|
| 60 |
+
i += 1
|
| 61 |
+
is_right_patch = False
|
| 62 |
+
else:
|
| 63 |
+
alpha = Image.linear_gradient('L').rotate(90).resize((alpha_gradient_width, init_image_height))
|
| 64 |
+
tmp_img = init_image.crop((x_coord, 0, x_coord+alpha_gradient_width, init_image_height))
|
| 65 |
+
tmp_img.putalpha(alpha)
|
| 66 |
+
init_image.paste(tmp_img, (x_coord, 0))
|
| 67 |
+
is_right_patch = True
|
| 68 |
+
|
| 69 |
+
# Generate inpainting mask
|
| 70 |
+
mask = Image.new('RGBA', (total_width, init_image_height), (0, 0, 0))
|
| 71 |
+
x_coord = 0
|
| 72 |
+
for im in images[:-1]:
|
| 73 |
+
mask_patch = Image.new('RGBA', (alpha_gradient_width*2, init_image_height), (255, 255, 255))
|
| 74 |
+
mask.paste(mask_patch, (x_coord + im.width - alpha_gradient_width, 0))
|
| 75 |
+
x_coord += im.width
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Crop init_image and mask into batches
|
| 79 |
+
x_coord = 0
|
| 80 |
+
init_image_mask_pair = []
|
| 81 |
+
init_image_patch_x_coord = []
|
| 82 |
+
|
| 83 |
+
for im in images[:-1]:
|
| 84 |
+
crop_start_x = x_coord + im.width - init_image_height // 2
|
| 85 |
+
crop_end_x = x_coord + im.width + init_image_height // 2
|
| 86 |
+
tmp_img = init_image.crop((crop_start_x, 0, min(total_width, crop_end_x), init_image_height))
|
| 87 |
+
tmp_mask = mask.crop((crop_start_x, 0, min(total_width, crop_end_x), init_image_height))
|
| 88 |
+
init_image_mask_pair.append((tmp_img, tmp_mask))
|
| 89 |
+
init_image_patch_x_coord.append(crop_start_x)
|
| 90 |
+
x_coord += im.width
|
| 91 |
+
return init_image, mask, init_image_mask_pair, init_image_patch_x_coord
|
| 92 |
+
|
| 93 |
+
def attach_images_with_loc(inpainted_results, init_image_patch_x_coord, full_init_img):
|
| 94 |
+
full_init_img = full_init_img
|
| 95 |
+
for im, loc in zip(inpainted_results, init_image_patch_x_coord):
|
| 96 |
+
full_init_img.paste(im, (loc, 0))
|
| 97 |
+
return full_init_img
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def inpainting_api_call(input_image, input_mask, token, endpoint):
|
| 101 |
+
body = {
|
| 102 |
+
"input": {
|
| 103 |
+
"initial_image_b64": pil_to_b64(input_image),
|
| 104 |
+
"mask_image_b64": pil_to_b64(input_mask.convert('L'))
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
json_data = dumps(body)
|
| 109 |
+
start = time.time()
|
| 110 |
+
resp_inpaint = requests.post(endpoint, data=json_data, headers={"Authorization": f"Bearer {token}"})
|
| 111 |
+
print(f"Execution time: {time.time() - start}")
|
| 112 |
+
return b64_to_pil(resp_inpaint.json()['output']['inpainted_image_b64'])
|
| 113 |
+
|
| 114 |
+
def process_images_and_inpaint(images, alpha_gradient_width=100, init_image_height=768):
|
| 115 |
+
images = [ resize_image(b64_to_pil(im)).convert("RGBA") for im in images ]
|
| 116 |
+
full_init_img, full_mask, init_image_mask_pair, init_image_patch_x_coord = prepare_init_image_mask(images, alpha_gradient_width, init_image_height)
|
| 117 |
+
results = [ inpainting_api_call(im, mask, token, endpoint) for im, mask in init_image_mask_pair]
|
| 118 |
+
attached_image = pil_to_b64(attach_images_with_loc(results, init_image_patch_x_coord, full_init_img))
|
| 119 |
+
return attached_image
|
| 120 |
+
|
frontend/app.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import requests
|
| 3 |
+
from utils import resize_image, pil_to_b64, b64_to_pil, process_images_and_inpaint
|
| 4 |
+
|
| 5 |
+
USE_FASTAPI = False
|
| 6 |
+
FAST_API_ENDPOINT = 'http://127.0.0.1:5000/inpaint'
|
| 7 |
+
|
| 8 |
+
def run_inpainting(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height, USE_FASTAPI):
|
| 9 |
+
images = []
|
| 10 |
+
for img in [img_1, img_2, img_3, img_4]:
|
| 11 |
+
if img is not None:
|
| 12 |
+
images.append(pil_to_b64(resize_image(img, init_image_height)))
|
| 13 |
+
if USE_FASTAPI:
|
| 14 |
+
return call_inpainting_api(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height)
|
| 15 |
+
else:
|
| 16 |
+
return process_images_and_inpaint(images, alpha_gradient_width, init_image_height)
|
| 17 |
+
|
| 18 |
+
def call_inpainting_api(img_1, img_2, img_3, img_4, alpha_gradient_width, init_image_height):
|
| 19 |
+
images = []
|
| 20 |
+
for img in [img_1, img_2, img_3, img_4]:
|
| 21 |
+
if img is not None:
|
| 22 |
+
images.append(pil_to_b64(resize_image(img, init_image_height)))
|
| 23 |
+
response = requests.post(FAST_API_ENDPOINT, json={
|
| 24 |
+
"images": images,
|
| 25 |
+
"alpha_gradient_width": alpha_gradient_width,
|
| 26 |
+
"init_image_height": init_image_height
|
| 27 |
+
})
|
| 28 |
+
if response.status_code == 200:
|
| 29 |
+
return b64_to_pil(response.json()["inpainted_image"])
|
| 30 |
+
else:
|
| 31 |
+
return "Error calling inpainting API"
|
| 32 |
+
|
| 33 |
+
TITLE = """<h2 align="center"> 🎞️ Memory Carousel </h2>"""
|
| 34 |
+
|
| 35 |
+
# Define the Gradio interface
|
| 36 |
+
with gr.Blocks() as demo:
|
| 37 |
+
gr.HTML(TITLE)
|
| 38 |
+
with gr.Column():
|
| 39 |
+
with gr.Row():
|
| 40 |
+
input_image_1 = gr.Image(type='pil', label="First image")
|
| 41 |
+
input_image_2 = gr.Image(type='pil', label="Second image")
|
| 42 |
+
with gr.Row():
|
| 43 |
+
input_image_3 = gr.Image(type='pil', label="Third image(optional)")
|
| 44 |
+
input_image_4 = gr.Image(type='pil', label="Fourth image(optional)")
|
| 45 |
+
with gr.Row():
|
| 46 |
+
alpha_gradient_width = gr.Number(value=100, label="Alpha Gradient Width")
|
| 47 |
+
init_image_height = gr.Number(value=768, label="Init Image Height")
|
| 48 |
+
generate_button = gr.Button("Generate")
|
| 49 |
+
output = gr.Image(type='pil')
|
| 50 |
+
|
| 51 |
+
example_list = gr.Examples(
|
| 52 |
+
examples=[['./example/img_1.jpg', './example/img_2.jpg', './example/img_3.jpg', './example/img_4.jpg', 100, 768]],
|
| 53 |
+
inputs=[
|
| 54 |
+
input_image_1,
|
| 55 |
+
input_image_2,
|
| 56 |
+
input_image_3,
|
| 57 |
+
input_image_4,
|
| 58 |
+
alpha_gradient_width,
|
| 59 |
+
init_image_height
|
| 60 |
+
],
|
| 61 |
+
outputs=[output],
|
| 62 |
+
fn=call_inpainting_api,
|
| 63 |
+
cache_examples=True,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
generate_button.click(
|
| 67 |
+
fn=call_inpainting_api,
|
| 68 |
+
inputs=[input_image_1, input_image_2, input_image_3, input_image_4, alpha_gradient_width, init_image_height],
|
| 69 |
+
outputs=[output]
|
| 70 |
+
)
|
| 71 |
+
demo.launch()
|
frontend/example/img_1.jpg
ADDED
|
frontend/example/img_2.jpg
ADDED
|
Git LFS Details
|
frontend/example/img_3.jpg
ADDED
|
Git LFS Details
|
frontend/example/img_4.jpg
ADDED
|
Git LFS Details
|
frontend/example/output.png
ADDED
|
Git LFS Details
|
frontend/utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import requests
|
| 3 |
+
from json import dumps, dump
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
import time
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
endpoint = 'https://serving.hopter.staging.picc.co/api/v1/services/gen-ai-image-expansion/predictions'
|
| 12 |
+
token = os.getenv('API_TOKEN')
|
| 13 |
+
|
| 14 |
+
def pil_to_b64(image:Image.Image) -> str:
|
| 15 |
+
buffered = BytesIO()
|
| 16 |
+
image.save(buffered, format="PNG", quality=80)
|
| 17 |
+
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
| 18 |
+
prefix = 'data:image/png;base64,'
|
| 19 |
+
return prefix + img_str
|
| 20 |
+
|
| 21 |
+
def b64_to_pil(b64_string):
|
| 22 |
+
# Remove the Base64 prefix if present
|
| 23 |
+
if b64_string.startswith('data:image'):
|
| 24 |
+
b64_string = b64_string.split(';base64,', 1)[1]
|
| 25 |
+
# Decode the Base64 string to bytes
|
| 26 |
+
image_bytes = base64.b64decode(b64_string)
|
| 27 |
+
# Create a BytesIO object and load the image bytes
|
| 28 |
+
image_buffer = BytesIO(image_bytes)
|
| 29 |
+
image = Image.open(image_buffer)
|
| 30 |
+
return image
|
| 31 |
+
|
| 32 |
+
def resize_image(image, max_height=768):
|
| 33 |
+
scale = max_height/image.height
|
| 34 |
+
return image.resize((int(image.width * scale), int(image.height * scale)))
|
| 35 |
+
|
| 36 |
+
def prepare_init_image_mask(images: [Image.Image], alpha_gradient_width=80, init_image_height=768): # type: ignore
|
| 37 |
+
total_width = sum([ im.width for im in images])
|
| 38 |
+
init_image = Image.new('RGBA', (total_width,init_image_height))
|
| 39 |
+
|
| 40 |
+
# Paste input images on init_image
|
| 41 |
+
x_coord = 0
|
| 42 |
+
for im in images:
|
| 43 |
+
init_image.paste(im, (x_coord, 0))
|
| 44 |
+
x_coord += im.width
|
| 45 |
+
|
| 46 |
+
# Add linear alpha gradient
|
| 47 |
+
x_coord = 0
|
| 48 |
+
is_right_patch = True
|
| 49 |
+
i = 0
|
| 50 |
+
while i <= len(images) - 1:
|
| 51 |
+
im = images[i]
|
| 52 |
+
if i == len(images) - 1 and is_right_patch:
|
| 53 |
+
break
|
| 54 |
+
if is_right_patch:
|
| 55 |
+
alpha = Image.linear_gradient('L').rotate(-90).resize((alpha_gradient_width, init_image_height))
|
| 56 |
+
tmp_img = init_image.crop((x_coord+im.width - alpha_gradient_width, 0, x_coord+im.width, init_image_height))
|
| 57 |
+
tmp_img.putalpha(alpha)
|
| 58 |
+
init_image.paste(tmp_img, (x_coord+im.width - alpha_gradient_width, 0))
|
| 59 |
+
x_coord += im.width
|
| 60 |
+
i += 1
|
| 61 |
+
is_right_patch = False
|
| 62 |
+
else:
|
| 63 |
+
alpha = Image.linear_gradient('L').rotate(90).resize((alpha_gradient_width, init_image_height))
|
| 64 |
+
tmp_img = init_image.crop((x_coord, 0, x_coord+alpha_gradient_width, init_image_height))
|
| 65 |
+
tmp_img.putalpha(alpha)
|
| 66 |
+
init_image.paste(tmp_img, (x_coord, 0))
|
| 67 |
+
is_right_patch = True
|
| 68 |
+
|
| 69 |
+
# Generate inpainting mask
|
| 70 |
+
mask = Image.new('RGBA', (total_width, init_image_height), (0, 0, 0))
|
| 71 |
+
x_coord = 0
|
| 72 |
+
for im in images[:-1]:
|
| 73 |
+
mask_patch = Image.new('RGBA', (alpha_gradient_width*2, init_image_height), (255, 255, 255))
|
| 74 |
+
mask.paste(mask_patch, (x_coord + im.width - alpha_gradient_width, 0))
|
| 75 |
+
x_coord += im.width
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Crop init_image and mask into batches
|
| 79 |
+
x_coord = 0
|
| 80 |
+
init_image_mask_pair = []
|
| 81 |
+
init_image_patch_x_coord = []
|
| 82 |
+
|
| 83 |
+
for im in images[:-1]:
|
| 84 |
+
crop_start_x = x_coord + im.width - init_image_height // 2
|
| 85 |
+
crop_end_x = x_coord + im.width + init_image_height // 2
|
| 86 |
+
tmp_img = init_image.crop((crop_start_x, 0, min(total_width, crop_end_x), init_image_height))
|
| 87 |
+
tmp_mask = mask.crop((crop_start_x, 0, min(total_width, crop_end_x), init_image_height))
|
| 88 |
+
init_image_mask_pair.append((tmp_img, tmp_mask))
|
| 89 |
+
init_image_patch_x_coord.append(crop_start_x)
|
| 90 |
+
x_coord += im.width
|
| 91 |
+
return init_image, mask, init_image_mask_pair, init_image_patch_x_coord
|
| 92 |
+
|
| 93 |
+
def attach_images_with_loc(inpainted_results, init_image_patch_x_coord, full_init_img):
|
| 94 |
+
full_init_img = full_init_img
|
| 95 |
+
for im, loc in zip(inpainted_results, init_image_patch_x_coord):
|
| 96 |
+
full_init_img.paste(im, (loc, 0))
|
| 97 |
+
return full_init_img
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def inpainting_api_call(input_image, input_mask, token, endpoint):
|
| 101 |
+
body = {
|
| 102 |
+
"input": {
|
| 103 |
+
"initial_image_b64": pil_to_b64(input_image),
|
| 104 |
+
"mask_image_b64": pil_to_b64(input_mask.convert('L'))
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
json_data = dumps(body)
|
| 109 |
+
start = time.time()
|
| 110 |
+
resp_inpaint = requests.post(endpoint, data=json_data, headers={"Authorization": f"Bearer {token}"})
|
| 111 |
+
print(f"Execution time: {time.time() - start}")
|
| 112 |
+
return b64_to_pil(resp_inpaint.json()['output']['inpainted_image_b64'])
|
| 113 |
+
|
| 114 |
+
def process_images_and_inpaint(images, alpha_gradient_width=100, init_image_height=768):
|
| 115 |
+
images = [ resize_image(b64_to_pil(im)).convert("RGBA") for im in images ]
|
| 116 |
+
full_init_img, full_mask, init_image_mask_pair, init_image_patch_x_coord = prepare_init_image_mask(images, alpha_gradient_width, init_image_height)
|
| 117 |
+
results = [ inpainting_api_call(im, mask, token, endpoint) for im, mask in init_image_mask_pair]
|
| 118 |
+
attached_image = pil_to_b64(attach_images_with_loc(results, init_image_patch_x_coord, full_init_img))
|
| 119 |
+
return attached_image
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
python-multipart
|
| 4 |
+
Pillow
|
| 5 |
+
requests
|
| 6 |
+
python-dotenv
|
| 7 |
+
gradio
|
| 8 |
+
gunicorn
|
start_server.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd backend
|
| 2 |
+
uvicorn --port 5000 --host 127.0.0.1 main:app --reload
|
start_web_app.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cd frontend
|
| 2 |
+
python app.py
|