# ZenCtrl / app.py
# Hugging Face Space source; verified commit 9a3cf0e ("Update app.py") by salso, 10.9 kB.
import os
import base64
import io
from typing import TypedDict
import requests
import gradio as gr
from PIL import Image
# Baseten endpoint settings, looked up once when this module is imported.
# Each value is None when the corresponding environment variable is unset.
BTEN_API_KEY = os.environ.get("API_KEY")
URL = os.environ.get("URL")
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and return the PNG bytes as a Base64 text string."""
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def ensure_image(img) -> Image.Image:
    """Coerce an image input into a ``PIL.Image.Image``.

    Accepted inputs:
      * a ``PIL.Image.Image`` (returned unchanged);
      * a filesystem path, either ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``);
      * a file dict carrying the path under ``"name"`` (older Gradio payloads)
        or ``"path"`` (newer Gradio ``FileData`` payloads).

    Raises:
        ValueError: If *img* is ``None``, a dict without a usable path, or of
            any other unsupported type.
    """
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, (str, os.PathLike)):
        return Image.open(img)
    if isinstance(img, dict):
        # "name" is checked first to keep the original lookup order; a missing
        # or empty value falls through instead of passing None to Image.open.
        for key in ("name", "path"):
            file_path = img.get(key)
            if file_path:
                return Image.open(file_path)
    raise ValueError(
        f"Cannot convert input of type {type(img).__name__} to a PIL Image."
    )
def call_baseten_generate(
    image: Image.Image,
    prompt: str,
    steps: int,
    strength: float,
    height: int,
    width: int,
    lora_name: str,
    remove_bg: bool,
    timeout: float | None = 300.0,
) -> Image.Image | None:
    """Send one generation request to the configured endpoint and decode the reply.

    Args:
        image: Input image; anything ``ensure_image`` accepts.
        prompt: Text prompt, sent as ``prompt``.
        steps: Step count, sent as ``steps``.
        strength: Strength value, sent as ``strength``.
        height: Requested height in pixels, sent as ``height``.
        width: Requested width in pixels, sent as ``width``.
        lora_name: Model identifier chosen in the UI, sent as ``lora_name``.
        remove_bg: State of the "Remove Background" option, sent as ``bgrm``.
        timeout: Seconds to wait for the HTTP response; ``None`` waits forever.

    Returns:
        The decoded ``generated_image`` from the JSON reply. Returns ``None`` in
        any of these cases, printing a message rather than raising:
          * no URL is configured;
          * the request or decoding fails;
          * the server answers with a status other than 200;
          * the reply has no ``generated_image``.

    Raises:
        ValueError: If ``image`` cannot be converted to a PIL image.
    """
    # Input conversion errors propagate to the caller; only configuration,
    # network and decoding failures below are reported as a None result.
    b64_image = image_to_base64(ensure_image(image))

    # Fall back to the environment at call time so values set after import work.
    api_key = BTEN_API_KEY or os.getenv("API_KEY")
    url = URL or os.getenv("URL")

    try:
        if not url:
            raise ValueError("The URL environment variable is not set.")
        payload = {
            "image": b64_image,
            "prompt": prompt,
            # Coerce in case the UI hands over floats such as 512.0.
            "steps": int(steps),
            "strength": float(strength),
            "height": int(height),
            "width": int(width),
            "lora_name": lora_name,
            "bgrm": remove_bg,
        }
        response = requests.post(
            url,
            headers={"Authorization": f"Api-Key {api_key}"},
            json=payload,
            # Without a timeout a stalled endpoint would block this call forever.
            timeout=timeout,
        )
        if response.status_code != 200:
            print(f"Error: HTTP {response.status_code}\n{response.text}")
            return None
        gen_b64 = response.json().get("generated_image")
        if not gen_b64:
            return None
        return Image.open(io.BytesIO(base64.b64decode(gen_b64)))
    except Exception as e:
        # UI boundary: report the failure and return None instead of raising.
        print(f"Error: {e}")
        return None
# ================== MODE CONFIG =====================
class Mode(TypedDict):
    """Initial UI configuration for one generation mode (one tab).

    All keys are required.
    """

    model: str  # model preselected in the "Model" dropdown
    prompt: str  # initial text of the prompt box
    default_strength: float  # initial value of the "Strength" slider
    default_height: int  # initial value of the "Height" slider
    default_width: int  # initial value of the "Width" slider
    models: list[str]  # choices offered by the "Model" dropdown
    remove_bg: bool  # initial state of the "Remove Background" checkbox
# Initial UI settings per mode, keyed by tab title. The UI builds one tab per
# entry, in this dict's iteration order.
MODE_DEFAULTS: dict[str, Mode] = {
    "Subject Generation": {
        "model": "subject_99000_512",
        "prompt": "A detailed portrait with soft lighting",
        "default_strength": 1.2,
        "default_height": 512,
        "default_width": 512,
        "models": ["zendsd_512_146000", "subject_99000_512", "zen_26000_512"],
        "remove_bg": True,
    },
    "Background Generation": {
        "model": "bg_canny_58000_1024",
        "prompt": "A vibrant background with dynamic lighting and textures",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["bgwlight_15000_1024", "bg_canny_58000_1024", "gen_back_7000_1024"],
        "remove_bg": True,
    },
    "Canny": {
        "model": "canny_21000_1024",
        "prompt": "A futuristic cityscape with neon lights",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["canny_21000_1024"],
        "remove_bg": True,
    },
    "Depth": {
        "model": "depth_9800_1024",
        "prompt": "A scene with pronounced depth and perspective",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["depth_9800_1024"],
        "remove_bg": True,
    },
    "Deblurring": {
        "model": "deblurr_1024_10000",
        # Previously a verbatim copy of the Depth prompt. Empty matches the
        # Deblurring presets, which all use an empty prompt.
        "prompt": "",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": ["deblurr_1024_10000"],
        "remove_bg": False,
    },
}
# ================== PRESET EXAMPLES =====================
def _preset_rows(stem: str, prompts: list[str]) -> list[list[str]]:
    """Build preset rows ``[input path, prompt, output path]`` numbered from 1.

    Files follow the pattern ``assets/<stem><n>.jpg`` / ``assets/<stem><n>_out.jpg``.
    """
    return [
        [f"assets/{stem}{number}.jpg", text, f"assets/{stem}{number}_out.jpg"]
        for number, text in enumerate(prompts, start=1)
    ]


# Preset table contents per mode, keyed by the same tab titles as MODE_DEFAULTS.
MODE_EXAMPLES = {
    "Subject Generation": _preset_rows("subj", [
        "Close-up portrait of a fruit bowl",
        "A penguin standing in snow",
        "A cat with glowing eyes",
        "A child playing with bubbles",
        "A stylish young man in neon lights",
        "Old man with a mysterious look",
    ]),
    "Background Generation": _preset_rows("bg", [
        "Modern living room with plants",
        "Fantasy forest background",
        "Futuristic cityscape",
        "Minimalist white studio",
        "Snowy mountain landscape",
        "Golden sunset over the sea",
    ]),
    "Canny": _preset_rows("canny", [
        "A neon cyberpunk city skyline",
        "A robot walking in the fog",
        "A futuristic vehicle parked under a bridge",
        "Sci-fi lab interior with glowing machinery",
        "A portrait of a woman outlined in neon",
        "Post-apocalyptic abandoned street",
    ]),
    "Depth": _preset_rows("depth", [
        "A narrow alleyway with deep perspective",
        "A mountain road vanishing into the distance",
        "A hallway with strong depth of field",
        "A misty forest path stretching far away",
        "A bridge over a deep canyon",
        "An underground tunnel with receding arches",
    ]),
    # Deblurring presets carry no prompt.
    "Deblurring": _preset_rows("deblur", [""] * 6),
}
# ================== UI =====================
# HTML banner shown at the top of the page: title plus GitHub, Hugging Face
# Space and Discord badge links. The Hugging Face badge label previously held
# mis-decoded bytes ("πŸ€—") instead of the 🤗 emoji.
header = """
<h1>🌍 ZenCtrl / FLUX</h1>
<div align="center" style="line-height: 1;">
<a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg"></a>
<a href="https://huggingface.co/spaces/fotographerai/ZenCtrl" target="_blank"><img src="https://img.shields.io/badge/🤗_HuggingFace-Space-ffbd45.svg"></a>
<a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord"></a>
</div>
"""
with gr.Blocks(title="🌍 ZenCtrl") as demo:
    gr.HTML(header)
    gr.Markdown("# ZenCtrl Demo")
    with gr.Tabs():
        # One tab per mode; every tab gets its own independent set of components.
        for mode, defaults in MODE_DEFAULTS.items():
            with gr.Tab(mode):
                gr.Markdown(f"### {mode} Mode")
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(label="Input Image", type="pil")
                        generate_button = gr.Button("Generate")
                        # NOTE(review): a nested gr.Blocks is used here only to
                        # group these two controls; gr.Group may be the intent.
                        with gr.Blocks():
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=defaults["models"],
                                value=defaults["model"],
                                interactive=True,
                            )
                            remove_bg_checkbox = gr.Checkbox(
                                label="Remove Background", value=defaults["remove_bg"]
                            )
                    with gr.Column(scale=2):
                        output_image = gr.Image(label="Generated Image", type="pil")
                        prompt_box = gr.Textbox(
                            label="Prompt", value=defaults["prompt"], lines=2
                        )
                        with gr.Accordion("Generation Parameters", open=False):
                            with gr.Row():
                                step_slider = gr.Slider(
                                    minimum=2, maximum=28, value=2, step=2, label="Steps"
                                )
                                strength_slider = gr.Slider(
                                    minimum=0.5,
                                    maximum=2.0,
                                    value=defaults["default_strength"],
                                    step=0.1,
                                    label="Strength",
                                )
                            with gr.Row():
                                height_slider = gr.Slider(
                                    minimum=512,
                                    maximum=1360,
                                    value=defaults["default_height"],
                                    step=1,
                                    label="Height",
                                )
                                width_slider = gr.Slider(
                                    minimum=512,
                                    maximum=1360,
                                    value=defaults["default_width"],
                                    step=1,
                                    label="Width",
                                )

                def on_generate_click(
                    model_name, prompt, steps, strength, height, width, remove_bg, image
                ):
                    """Forward the tab's current settings to call_baseten_generate.

                    Raises gr.Error when no image is uploaded or the request returns
                    no image, so the user sees a message instead of a blank output.
                    """
                    if image is None:
                        raise gr.Error("Please upload an input image first.")
                    result = call_baseten_generate(
                        image, prompt, steps, strength, height, width, model_name, remove_bg
                    )
                    if result is None:
                        raise gr.Error("Generation failed. See the server log for details.")
                    return result

                generate_button.click(
                    fn=on_generate_click,
                    inputs=[
                        model_dropdown,
                        prompt_box,
                        step_slider,
                        strength_slider,
                        height_slider,
                        width_slider,
                        remove_bg_checkbox,
                        input_image,
                    ],
                    outputs=[output_image],
                )

                # ---------------- Templates --------------------
                presets = MODE_EXAMPLES.get(mode, [])
                preset_table = gr.Dataset(
                    label="Presets (Input / Prompt / Output)",
                    headers=["Input", "Prompt", "Output"],
                    components=[input_image, prompt_box, output_image],
                    samples=presets,
                    samples_per_page=6,
                    # Deliver the clicked row's index; the handler looks the row up.
                    type="index",
                )

                # `rows` is bound as a default argument so each tab keeps its own
                # preset list (a plain closure would see only the last tab's list).
                def load_preset(index, rows=presets):
                    """Copy the clicked preset row into the input, prompt and output fields."""
                    return tuple(rows[index])

                # A gr.Dataset does nothing on its own; without this listener,
                # clicking a preset would not populate the components.
                preset_table.click(
                    fn=load_preset,
                    inputs=[preset_table],
                    outputs=[input_image, prompt_box, output_image],
                )
def _main() -> None:
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()