import base64
import io
import logging
import os
from typing import Any, TypedDict

import gradio as gr
import requests
from PIL import Image

logger = logging.getLogger(__name__)

# Read Baseten configuration from environment variables.
BTEN_API_KEY = os.getenv("API_KEY")
URL = os.getenv("URL")

# Upper bound (seconds) on a single /predict request, so a stalled endpoint
# cannot block the Gradio worker indefinitely.
DEFAULT_REQUEST_TIMEOUT = 300.0


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL image to a base64-encoded PNG string."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def ensure_image(img: Any) -> Image.Image:
    """Coerce a Gradio image input into a PIL Image.

    Accepted inputs:
      * a ``PIL.Image.Image`` -- returned unchanged;
      * a ``str`` file path -- opened with ``Image.open``;
      * a ``dict`` holding a file path under ``"path"`` or ``"name"``
        (``"path"`` is what newer Gradio file payloads use; ``"name"`` is
        kept for older payloads).

    Raises:
        ValueError: if ``img`` is none of the above (e.g. ``None``).
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, str):
        return Image.open(img)
    if isinstance(img, dict):
        for key in ("path", "name"):
            if img.get(key):
                return Image.open(img[key])
    raise ValueError("Cannot convert input to a PIL Image.")


def call_baseten_generate(
    image: Image.Image,
    prompt: str,
    steps: int,
    strength: float,
    height: int,
    width: int,
    lora_name: str,
    remove_bg: bool,
    timeout: float = DEFAULT_REQUEST_TIMEOUT,
) -> Image.Image | None:
    """Call the Baseten /predict endpoint and return the generated image.

    Args:
        image: Source image (anything ``ensure_image`` accepts); sent as a
            base64 PNG under ``"image"``.
        prompt: Text prompt, sent as ``"prompt"``.
        steps: Sent as ``"steps"``.
        strength: Sent as ``"strength"``.
        height: Requested output height, sent as ``"height"``.
        width: Requested output width, sent as ``"width"``.
        lora_name: Model identifier, sent as ``"lora_name"``.
        remove_bg: Sent as ``"bgrm"``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The image decoded from the response's ``"generated_image"`` field, or
        ``None`` when the URL/API key is not configured, the request fails,
        the server answers with a non-200 status, or the field is missing.
        Failures are logged rather than raised so the UI shows no output.

    Raises:
        ValueError: from ``ensure_image`` when ``image`` is not usable.
    """
    image = ensure_image(image)
    payload = {
        "image": image_to_base64(image),
        "prompt": prompt,
        # Gradio sliders may deliver floats; send the types the
        # signature declares.
        "steps": int(steps),
        "strength": float(strength),
        "height": int(height),
        "width": int(width),
        "lora_name": lora_name,
        "bgrm": remove_bg,
    }

    # Re-read the environment at call time so values set after import are
    # still picked up.
    api_key = BTEN_API_KEY or os.getenv("API_KEY")
    url = URL or os.getenv("URL")
    if not url:
        logger.error("The URL environment variable is not set.")
        return None
    if not api_key:
        # Previously this sent "Api-Key None"; fail fast instead.
        logger.error("The API_KEY environment variable is not set.")
        return None

    try:
        response = requests.post(
            url,
            headers={"Authorization": f"Api-Key {api_key}"},
            json=payload,
            timeout=timeout,
        )
        if response.status_code != 200:
            logger.error("Error: HTTP %s\n%s", response.status_code, response.text)
            return None
        gen_b64 = response.json().get("generated_image")
        if not gen_b64:
            return None
        result = Image.open(io.BytesIO(base64.b64decode(gen_b64)))
        # Image.open is lazy; force the full decode here so a corrupt or
        # truncated payload is caught by this handler, not later in Gradio.
        result.load()
        return result
    except Exception:  # UI boundary: log with traceback, show no output.
        logger.exception("Baseten generation request failed")
        return None


class Mode(TypedDict):
    """Per-tab defaults for one generation mode."""

    model: str  # model preselected in the dropdown
    prompt: str  # initial prompt text
    default_strength: float
    default_height: int
    default_width: int
    models: list[str]  # dropdown choices
    remove_bg: bool  # initial state of the "Remove Background" checkbox


# Mode defaults for each tab.
MODE_DEFAULTS: dict[str, Mode] = {
    "Subject Generation": {
        "model": "subject_99000_512",
        "prompt": "A detailed portrait with soft lighting",
        "default_strength": 1.2,
        "default_height": 512,
        "default_width": 512,
        "models": [
            "zendsd_512_146000",
            "subject_99000_512",
            # "zen_pers_11000",
            "zen_26000_512",
        ],
        "remove_bg": True,
    },
    "Background Generation": {
        # NOTE(review): this default is commented out of "models" below;
        # _model_choices re-adds it so the dropdown stays valid. Confirm the
        # endpoint still serves it, or switch to a listed model.
        "model": "gen_back_3000_1024",
        "prompt": "A vibrant background with dynamic lighting and textures",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": [
            "bgwlight_15000_1024",
            # "rmgb_12000_1024",
            "bg_canny_58000_1024",
            # "gen_back_3000_1024",
            "gen_back_7000_1024",
            # "gen_bckgnd_18000_512",
            # "gen_bckgnd_18000_512",
            # "loose_25000_512",
            # "looser_23000_1024",
            # "looser_bg_gen_21000_1280",
            # "old_looser_46000_1024",
            # "relight_bg_gen_31000_1024",
        ],
        "remove_bg": True,
    },
    "Canny": {
        "model": "canny_21000_1024",
        "prompt": "A futuristic cityscape with neon lights",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["canny_21000_1024"],
        "remove_bg": True,
    },
    "Depth": {
        "model": "depth_9800_1024",
        "prompt": "A scene with pronounced depth and perspective",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": [
            "depth_9800_1024",
        ],
        "remove_bg": True,
    },
    "Deblurring": {
        # NOTE(review): default not in "models" (see the commented entry);
        # _model_choices re-adds it. Prompt appears copied from "Depth" --
        # confirm it is intended.
        "model": "slight_deblurr_18000",
        "prompt": "A scene with pronounced depth and perspective",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["deblurr_1024_10000"],  # "slight_deblurr_18000",
        "remove_bg": False,
    },
}


def _model_choices(mode_defaults: Mode) -> list[str]:
    """Return a mode's dropdown choices, always including its default model.

    Gradio's Dropdown does not accept a value outside its choices unless
    ``allow_custom_value=True``, so a default missing from ``models`` would
    make Generate fail until the user picks another model.
    """
    choices = list(mode_defaults["models"])
    if mode_defaults["model"] not in choices:
        choices.insert(0, mode_defaults["model"])
    return choices


# NOTE(review): the link labels below carry no URLs in this source; they look
# like stripped badges/links -- confirm the intended targets.
header = """
# 🌍 ZenCtrl / FLUX

GitHub Repo HuggingFace Space Discord LP X
"""

defaults = MODE_DEFAULTS["Subject Generation"]

with gr.Blocks(title="🌍 ZenCtrl") as demo:
    gr.Markdown(header)
    gr.Markdown(
        """
# ZenCtrl Demo
[WIP] One Agent to Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning.

We are first releasing some of the task specific weights and will release the codes soon.

The goal is to unify all of the visual content generation tasks with a single LLM...

**Modes:**
- **Subject Generation:** Focuses on generating detailed subject portraits.
- **Background Generation:** Creates dynamic, vibrant backgrounds: You can generate part of the image from sketch while keeping part of it as it is.
- **Canny:** Emphasizes strong edge detection.
- **Depth:** Produces images with realistic depth and perspective.

For more details, shoot us a message on discord.
"""
    )
    with gr.Tabs():
        for mode in MODE_DEFAULTS:
            with gr.Tab(mode):
                defaults = MODE_DEFAULTS[mode]
                gr.Markdown(f"### {mode} Mode")
                gr.Markdown(f"**Default Model:** {defaults['model']}")

                with gr.Row():
                    with gr.Column(scale=2, min_width=370):
                        input_image = gr.Image(
                            label="Upload Image",
                            type="pil",
                            scale=3,
                            height=370,
                            min_width=100,
                        )
                        generate_button = gr.Button("Generate")
                        with gr.Blocks(title="Options"):
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=_model_choices(defaults),
                                value=defaults["model"],
                                interactive=True,
                            )
                            remove_bg_checkbox = gr.Checkbox(
                                label="Remove Background", value=defaults["remove_bg"]
                            )

                    with gr.Column(scale=2):
                        output_image = gr.Image(
                            label="Generated Image",
                            type="pil",
                            height=573,
                            scale=4,
                            min_width=100,
                        )

                gr.Markdown("#### Prompt")
                prompt_box = gr.Textbox(
                    label="Prompt", value=defaults["prompt"], lines=2
                )

                # Wrap generation parameters in an Accordion for collapsible view.
                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Row():
                        step_slider = gr.Slider(
                            minimum=2, maximum=28, value=2, step=2, label="Steps"
                        )
                        strength_slider = gr.Slider(
                            minimum=0.5,
                            maximum=2.0,
                            value=defaults["default_strength"],
                            step=0.1,
                            label="Strength",
                        )
                    with gr.Row():
                        height_slider = gr.Slider(
                            minimum=512,
                            maximum=1360,
                            value=defaults["default_height"],
                            step=1,
                            label="Height",
                        )
                        width_slider = gr.Slider(
                            minimum=512,
                            maximum=1360,
                            value=defaults["default_width"],
                            step=1,
                            label="Width",
                        )

                def on_generate_click(
                    model_name,
                    prompt,
                    steps,
                    strength,
                    height,
                    width,
                    remove_bg,
                    image,
                ):
                    """Reorder the UI inputs into call_baseten_generate's signature."""
                    return call_baseten_generate(
                        image,
                        prompt,
                        steps,
                        strength,
                        height,
                        width,
                        model_name,
                        remove_bg,
                    )

                # Handler is wired per tab to this tab's component instances.
                generate_button.click(
                    fn=on_generate_click,
                    inputs=[
                        model_dropdown,
                        prompt_box,
                        step_slider,
                        strength_slider,
                        height_slider,
                        width_slider,
                        remove_bg_checkbox,
                        input_image,
                    ],
                    outputs=[output_image],
                )

if __name__ == "__main__":
    demo.launch()