Spaces:
Running
Running
| import base64 | |
| import io | |
| import os | |
| import zipfile | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Literal, cast | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| from gradio.components.image_editor import EditorValue | |
| from PIL import Image | |
def _require_env(name: str) -> str:
    """Return the environment variable *name*, raising if it is unset or empty.

    Raises:
        ValueError: ``"<name> is not set"`` when the variable is missing or "".
    """
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"{name} is not set")
    return value


# Both settings are mandatory: the module refuses to import (ValueError) when
# either is missing. PASSWORD is sent with every request to ENDPOINT.
PASSWORD = _require_env("PASSWORD")
ENDPOINT = _require_env("ENDPOINT")
def encode_image_as_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    The result carries no ``data:`` URI prefix; callers add it if needed.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
# (connect, read) timeouts in seconds for the inference request. Without a
# timeout ``requests`` would wait forever on an unresponsive endpoint; the read
# timeout is generous because one call may generate several images.
# NOTE(review): tune the read timeout to the endpoint's real worst-case latency.
_REQUEST_TIMEOUT = (10, 600)


def _to_png_data_uri(image: Image.Image) -> str:
    """Return *image* encoded as a ``data:image/png;base64,...`` URI string."""
    return "data:image/png;base64," + encode_image_as_base64(image)


def _unzip_images(payload: bytes) -> list[Image.Image]:
    """Decode every file entry of the ZIP archive *payload* as an RGB image.

    Entries are returned in archive order; directory entries are skipped.

    Raises:
        zipfile.BadZipFile: if *payload* is not a ZIP archive.
    """
    images: list[Image.Image] = []
    with zipfile.ZipFile(io.BytesIO(payload), "r") as zip_file:
        for info in zip_file.infolist():
            if info.is_dir():
                continue
            with zip_file.open(info) as file:
                # convert() loads the pixel data before the entry is closed.
                images.append(Image.open(file).convert("RGB"))
    return images


def predict(
    model_type: Literal["schnell", "dev"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    subfolder: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 64,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
    """Send the room image, mask and furniture reference to the endpoint.

    Args:
        model_type: Model variant the endpoint should use.
        image_and_mask: Value of the ``gr.ImageMask`` editor. ``background`` is
            the room image and the alpha channel of ``layers[0]`` is the mask.
        furniture_reference: Reference image of the furniture to insert. It is
            not modified; a resized copy is sent.
        prompt, subfolder, crop: Forwarded unchanged to the endpoint.
        seed, num_inference_steps, max_dimension, margin, num_images_per_prompt:
            Forwarded to the endpoint as ``int``. Their meaning is defined
            server-side.

    Returns:
        The images decoded from the endpoint's ZIP response. Returns ``None``
        after showing a ``gr.Info`` message when an input is missing or when
        the request or response decoding fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if not furniture_reference:
        gr.Info("Please upload a furniture reference image")
        return None

    image_np = image_and_mask.get("background")
    # A missing or all-zero (black) background means no usable room image.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    # The first editor layer holds the brush strokes. Indexing channel 3
    # assumes an RGBA (H, W, 4) array; any non-transparent pixel is masked.
    layers = image_and_mask.get("layers") or []
    if not layers or layers[0] is None:
        gr.Info("Please mark the areas you want to remove")
        return None
    alpha_channel = cast(np.ndarray, layers[0])[:, :, 3]
    mask_np = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(image_np).convert("RGB")
    # Avoid sending oversized images to the API. thumbnail() resizes in place
    # and keeps the aspect ratio. The mask and room image are presumably the
    # same size, so they receive the same resize.
    mask_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    target_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    # Resize a copy so the caller's image object is not shrunk in place.
    furniture_reference = furniture_reference.copy()
    furniture_reference.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

    payload = {
        "model_type": model_type,
        "room_image_input": _to_png_data_uri(target_image),
        "room_image_mask": _to_png_data_uri(mask_image),
        "furniture_reference_image": _to_png_data_uri(furniture_reference),
        "prompt": prompt,
        "subfolder": subfolder,
        # Gradio sliders may deliver floats; the endpoint expects integers.
        "seed": int(seed),
        "num_inference_steps": int(num_inference_steps),
        "max_dimension": int(max_dimension),
        "condition_scale": 1.0,
        "margin": int(margin),
        "crop": crop,
        "num_images_per_prompt": int(num_images_per_prompt),
        "password": PASSWORD,
    }

    try:
        response = requests.post(
            ENDPOINT,
            headers={"accept": "application/json", "Content-Type": "application/json"},
            json=payload,
            timeout=_REQUEST_TIMEOUT,
        )
    except requests.RequestException:
        # Connection failures and timeouts are reported like HTTP errors.
        gr.Info("An error occurred during the generation")
        return None
    if response.status_code != 200:
        gr.Info("An error occurred during the generation")
        return None

    # On success the endpoint returns a ZIP archive of generated images.
    try:
        return _unzip_images(response.content)
    except zipfile.BadZipFile:
        gr.Info("An error occurred during the generation")
        return None
# Markdown heading shown at the top of the page.
intro_markdown = "\n# Furniture Blending Demo\n"

# Centre each layout column and cap its width. The selectors match the
# ``elem_id`` values given to the columns; ``#col-showcase`` appears unused by
# the current layout.
css = (
    "\n"
    "#col-left {\n"
    "margin: 0 auto;\n"
    "max-width: 430px;\n"
    "}\n"
    "#col-mid {\n"
    "margin: 0 auto;\n"
    "max-width: 430px;\n"
    "}\n"
    "#col-right {\n"
    "margin: 0 auto;\n"
    "max-width: 430px;\n"
    "}\n"
    "#col-showcase {\n"
    "margin: 0 auto;\n"
    "max-width: 1100px;\n"
    "}\n"
)
# Build the three-column Gradio UI (room + mask, furniture reference, results
# and controls) and wire its events to ``predict``.
with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)
    # The elem_id values select the per-column max-width rules in ``css``.
    with gr.Row() as content:
        # Left column: room image upload with a brush-drawn inpainting mask.
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px; height: 50px;">
<div>
🪟 Room image with inpainting mask ⬇️
</div>
</div>
""",
                max_height=50,
            )
            # Upload-only editor with no transform tools. The brush colour is
            # fixed to black; predict() builds the mask from the alpha channel
            # of the painted layer, so only the painted coverage matters.
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            # Example rooms: every *.png in ./examples/scenes. The path is
            # resolved against the working directory when the UI is built.
            image_and_mask_examples = gr.Examples(
                examples=list(Path("./examples/scenes").glob("*.png")),
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
        # Middle column: the furniture reference image (delivered as PIL RGB).
        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px; height: 50px;">
<div>
🪑 Furniture reference image ⬇️
</div>
</div>
""",
                max_height=50,
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            # Example furniture: every *.png in ./examples/objects
            # (relative to the working directory).
            furniture_examples = gr.Examples(
                examples=list(Path("./examples/objects").glob("*.png")),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[condition_image],
            )
        # Right column: result gallery, model choice, Run button and settings.
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px; height: 50px;">
<div>
🔥 Press Run ⬇️
</div>
</div>
""",
                max_height=50,
            )
            results = gr.Gallery(
                label="Result",
                format="png",
                file_types="image",
                show_label=False,
                columns=2,
                allow_preview=True,
                preview=True,
            )
            model_type = gr.Radio(
                choices=["schnell", "dev"],
                value="schnell",
                label="Model Type",
            )
            run_button = gr.Button("Run")
            # Parameters forwarded by predict() to the inference endpoint.
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt",
                    value="",
                )
                subfolder = gr.Textbox(
                    label="Subfolder",
                    value="",
                )
                # Upper bound is the int32 maximum.
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=np.iinfo(np.int32).max,
                    step=1,
                    value=0,
                )
                num_images_per_prompt = gr.Slider(
                    label="Number of images per prompt",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=4,
                )
                crop = gr.Checkbox(
                    label="Crop",
                    value=False,
                )
                margin = gr.Slider(
                    label="Margin",
                    minimum=0,
                    maximum=256,
                    step=16,
                    value=128,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    # Initial value 4 matches the initial "schnell" model
                    # choice; the change handler below keeps them in sync.
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=30,
                        step=2,
                        value=4,
                    )
    # Change the number of inference steps based on the model type
    # (4 for "schnell", 28 for any other choice).
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )
    # ``inputs`` must stay in the same order as predict()'s positional
    # parameters; the returned image list (or None) fills the gallery.
    run_button.click(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            subfolder,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    )
if __name__ == "__main__":
    # Start the server only when this file is executed as a script, so that
    # importing the module (tests, tooling) builds ``demo`` without launching.
    demo.launch()