incendies's picture
FLUX: use int params to match API Recorder snippet
c97aa5a
"""
Imageat Workflow Agent — All tools in one Gradio Space (ZeroGPU).
Single Space, single quota; no Daggr, no iframe. Pattern: spaces + @spaces.GPU + demo.launch().
"""
import base64
import os
import tempfile
import urllib.request
import gradio as gr
import spaces
from gradio_client import Client, handle_file
def _hf_token():
    """Return the Hugging Face token used when creating API clients.

    A non-empty ``HF_TOKEN`` environment variable takes priority.  Otherwise
    ``huggingface_hub.get_token()`` is consulted; if that import or call fails
    for any reason the function returns ``None`` instead of raising.
    """
    resolved = os.environ.get("HF_TOKEN")
    if not resolved:
        try:
            # Imported lazily so a missing huggingface_hub only disables the fallback.
            from huggingface_hub import get_token

            resolved = get_token()
        except Exception:
            resolved = None
    return resolved
def _url_to_path(url):
    """Download an image URL to a local temp file for Gradio display.

    Args:
        url: Candidate value.  Only non-empty strings starting with ``"http"``
            are downloaded; anything else (``None``, a local path, a non-string
            object) is returned unchanged.

    Returns:
        The path of the downloaded temp file, or ``url`` itself when it is not
        an HTTP(S) URL or when the download fails.  The temp file is created
        with ``delete=False`` so it outlives this call.
    """
    if not url or not isinstance(url, str) or not url.startswith("http"):
        return url

    # The suffix is guessed from the URL text; PNG is the fallback.
    ext = "png"
    if ".jpg" in url or ".jpeg" in url:
        ext = "jpg"
    elif ".webp" in url:
        ext = "webp"

    tmp_name = None
    try:
        req = urllib.request.Request(url, headers={"User-Agent": "Gradio-Imageat/1.0"})
        # Open the connection before creating the temp file so an unreachable
        # URL never leaves an orphaned file behind.
        with urllib.request.urlopen(req, timeout=60) as response:
            with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as tmp:
                tmp_name = tmp.name
                tmp.write(response.read())
        return tmp_name
    except Exception:
        # Deliberate best effort: on any failure hand the original URL back,
        # but first remove a partially written temp file.
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
        return url
def _image_to_path(image):
    """Reduce a Gradio image value to a single filepath string.

    Accepted shapes are a plain string (returned as-is), a dict with a truthy
    ``"path"`` entry, or a non-empty list/tuple whose first element is one of
    those two.  Every other input, including ``None``, yields ``None``.
    """

    def _single(candidate):
        # A string is already a path; a dict contributes its truthy "path".
        if isinstance(candidate, str):
            return candidate
        if isinstance(candidate, dict):
            return candidate.get("path") or None
        return None

    if isinstance(image, (list, tuple)):
        # Only the first element of a sequence is considered.
        return _single(image[0]) if image else None
    return _single(image)
def _path_for_api(path):
    """Turn a ``data:`` URL into a temp file path; pass other paths through.

    Args:
        path: A filepath string or a ``data:<mime>;base64,<payload>`` URL.

    Returns:
        ``None`` for empty or non-string input, the input itself when it is not
        a data URL, otherwise the path of a new temp file holding the decoded
        bytes.  The file is created with ``delete=False`` so it outlives this
        call.

    Raises:
        RuntimeError: If the data URL cannot be split, decoded or written.
    """
    if not path or not isinstance(path, str):
        return None
    if not path.startswith("data:"):
        return path
    try:
        header, b64 = path.split(",", 1)
        # MIME types are case-insensitive, so compare in lower case.
        mime = header.lower()
        ext = "png"
        if "jpeg" in mime or "jpg" in mime:
            ext = "jpg"
        elif "webp" in mime:
            ext = "webp"
        data = base64.b64decode(b64)
        # The context manager closes the handle even if the write fails.
        with tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False) as f:
            f.write(data)
        return f.name
    except Exception as e:
        raise RuntimeError(f"Could not decode data URL: {e}") from e
# ---------- 1. Background Removal ----------
def run_bg_removal(image):
    """Send an image to the background-removal Space and return the result.

    Args:
        image: Gradio image value; normalised through ``_image_to_path`` and
            ``_path_for_api`` before being uploaded.

    Returns:
        ``None`` when no usable input path can be derived.  Otherwise the last
        element of a non-empty list/tuple response (or the response itself),
        downloaded to a temp file first when it is an ``http`` URL string.

    Raises:
        RuntimeError: If creating the client or the remote call fails.
    """
    if image is None:
        return None
    source = _image_to_path(image)
    api_path = _path_for_api(source) if source else None
    if not api_path:
        return None

    try:
        remote = Client("hf-applications/background-removal", token=_hf_token())
        response = remote.predict(handle_file(api_path), api_name="/image")
    except Exception as e:
        raise RuntimeError(f"Background removal error: {e}") from e

    # The API may return (original, result); only the result is wanted.
    if isinstance(response, (list, tuple)) and response:
        cutout = response[-1]
    else:
        cutout = response

    if cutout and isinstance(cutout, str) and cutout.startswith("http"):
        return _url_to_path(cutout)
    return cutout
# ---------- 2. Upscaler ----------
@spaces.GPU
def run_upscaler(image, model_selection="4xBHI_dat2_real"):
    """Upscale an image through the ``Phips/Upscaler`` Space.

    Args:
        image: Gradio image value; normalised through ``_image_to_path`` and
            ``_path_for_api`` before being uploaded.
        model_selection: Model name forwarded to the ``/upscale_image`` endpoint.

    Returns:
        ``None`` when no usable input path can be derived or no output can be
        located in the response; otherwise the output location, downloaded to a
        temp file first when it is an ``http`` URL string.

    Raises:
        RuntimeError: If the input file cannot be prepared or the remote call
            fails.
    """
    path = _image_to_path(image)
    if not path:
        return None
    path = _path_for_api(path)
    if not path:
        return None
    try:
        image_arg = handle_file(path)
    except Exception as e:
        raise RuntimeError(f"Upscaler: could not load image: {e}") from e
    try:
        client = Client("Phips/Upscaler", token=_hf_token())
        result = client.predict(image_arg, model_selection, api_name="/upscale_image")
    except Exception as e:
        raise RuntimeError(f"Upscaler API error: {e}") from e

    out = None
    if isinstance(result, (str, dict)):
        # A single-output response is the file itself; indexing it would pick a
        # character out of the path (str) or raise KeyError (dict).
        out = result
    elif isinstance(result, (list, tuple)):
        if len(result) >= 2 and result[1]:
            # Second output preferred when present.
            out = result[1]
        elif result and result[0]:
            first = result[0]
            # The first output may itself be a pair; take its second entry.
            if isinstance(first, (list, tuple)) and len(first) >= 2 and first[1]:
                out = first[1]
            elif isinstance(first, str):
                out = first

    if isinstance(out, dict):
        # Reduce a file dict to a location, as the generator tools in this file do.
        out = out.get("url") or out.get("path")
    if out and isinstance(out, str) and out.startswith("http"):
        return _url_to_path(out)
    return out
# ---------- 3. Z-Image Turbo ----------
@spaces.GPU
def run_z_image_turbo(prompt, height=1024, width=1024, seed=42):
    """Generate an image from a prompt via the ``Z-Image-Turbo`` Space.

    Args:
        prompt: Text prompt; a blank or whitespace-only prompt returns ``None``.
        height: Image height, sent to the API as a float.
        width: Image width, sent to the API as a float.
        seed: Seed, sent to the API as an int.

    Returns:
        The location of the generated image taken from the first element of the
        response (a dict's ``"url"`` or ``"path"``, or a plain string),
        downloaded to a temp file when it is an ``http`` URL; ``None`` when the
        response has no usable first element.

    Raises:
        RuntimeError: If argument conversion or the remote call fails.
    """
    if not prompt:
        return None
    cleaned = str(prompt).strip()
    if not cleaned:
        return None

    try:
        z_client = Client("hf-applications/Z-Image-Turbo", token=_hf_token())
        request = {
            "prompt": cleaned,
            "height": float(height),
            "width": float(width),
            "seed": int(seed),
        }
        result = z_client.predict(api_name="/generate_image", **request)
    except Exception as e:
        raise RuntimeError(f"Z-Image-Turbo API error: {e}") from e

    if not (result and len(result) >= 1 and result[0]):
        return None

    image_entry = result[0]
    if isinstance(image_entry, dict):
        location = image_entry.get("url") or image_entry.get("path")
    elif isinstance(image_entry, str):
        location = image_entry
    else:
        location = None

    if location and isinstance(location, str) and location.startswith("http"):
        return _url_to_path(location)
    return location
# ---------- 4. FLUX.2 Klein 9B ----------
@spaces.GPU
def run_flux_klein(prompt, mode_choice="Distilled (4 steps)", seed=0, randomize_seed=True, width=1024, height=1024):
    """Generate an image from a prompt via the ``FLUX.2-klein-9B`` Space.

    Args:
        prompt: Text prompt; a blank or whitespace-only prompt returns ``None``.
        mode_choice: Any value whose string form contains ``"Distilled"``
            selects the distilled mode (4 steps, guidance 1); everything else
            selects the base mode (50 steps, guidance 3.5).
        seed: Seed, sent as an int (``None`` becomes 0).
        randomize_seed: Forwarded as a bool.
        width: Image width, sent as an int (``None`` becomes 1024).
        height: Image height, sent as an int (``None`` becomes 1024).

    Returns:
        The location of the generated image taken from the first element of the
        response (a dict's ``"url"`` or ``"path"``, or a plain string),
        downloaded to a temp file when it is an ``http`` URL; ``None`` when the
        response has no usable first element.

    Raises:
        RuntimeError: If argument conversion or the remote call fails.
    """
    if not prompt:
        return None
    if not str(prompt).strip():
        return None

    try:
        flux = Client("black-forest-labs/FLUX.2-klein-9B", token=_hf_token())
        # Match API Recorder snippet exactly: int for seed/width/height/steps/guidance
        if "Distilled" in str(mode_choice):
            mode, num_steps, guidance = "Distilled (4 steps)", 4, 1
        else:
            mode, num_steps, guidance = "Base (50 steps)", 50, 3.5
        result = flux.predict(
            prompt=str(prompt).strip(),
            input_images=[],
            mode_choice=mode,
            seed=0 if seed is None else int(seed),
            randomize_seed=bool(randomize_seed),
            width=1024 if width is None else int(width),
            height=1024 if height is None else int(height),
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            prompt_upsampling=False,
            api_name="/generate",
        )
    except Exception as e:
        raise RuntimeError(f"FLUX.2-klein-9B API error: {e}") from e

    if not (result and len(result) >= 1 and result[0]):
        return None

    image_entry = result[0]
    if isinstance(image_entry, dict):
        location = image_entry.get("url") or image_entry.get("path")
    elif isinstance(image_entry, str):
        location = image_entry
    else:
        location = None

    if location and isinstance(location, str) and location.startswith("http"):
        return _url_to_path(location)
    return location
# ---------- Gradio UI: one app, all tools in tabs ----------
# Option lists shared with the widgets below.
_UPSCALER_MODELS = ["4xBHI_dat2_real", "4xNomos8kDAT", "4xHFA2k", "2xEvangelion_dat2"]
_FLUX_MODES = ["Distilled (4 steps)", "Base (50 steps)"]


def _size_slider(label):
    """Create a 512-1024 slider (step 64, default 1024) with the given label."""
    return gr.Slider(512, 1024, value=1024, step=64, label=label)


# One Blocks app; each tool lives in its own tab as a gr.Interface.
with gr.Blocks(title="Imageat Workflow Agent") as demo:
    gr.Markdown("# Imageat Workflow — All tools on one Space (ZeroGPU)")
    with gr.Tabs():
        with gr.TabItem("Background Removal"):
            gr.Interface(
                fn=run_bg_removal,
                inputs=gr.Image(label="Input Image", type="filepath"),
                outputs=gr.Image(label="Background Removed"),
                title="Background Removal",
            )

        with gr.TabItem("Upscaler"):
            gr.Interface(
                fn=run_upscaler,
                inputs=[
                    gr.Image(label="Input Image", type="filepath"),
                    gr.Dropdown(
                        label="Model",
                        choices=_UPSCALER_MODELS,
                        value="4xBHI_dat2_real",
                    ),
                ],
                outputs=gr.Image(label="Upscaled Image"),
                title="Upscaler",
            )

        with gr.TabItem("Z-Image Turbo"):
            # Input order matches run_z_image_turbo(prompt, height, width, seed).
            gr.Interface(
                fn=run_z_image_turbo,
                inputs=[
                    gr.Textbox(label="Prompt", lines=3),
                    _size_slider("Height"),
                    _size_slider("Width"),
                    gr.Number(value=42, label="Seed", precision=0),
                ],
                outputs=gr.Image(label="Generated Image"),
                title="Z-Image Turbo",
            )

        with gr.TabItem("FLUX.2 Klein 9B"):
            # Input order matches run_flux_klein(prompt, mode_choice, seed,
            # randomize_seed, width, height).
            gr.Interface(
                fn=run_flux_klein,
                inputs=[
                    gr.Textbox(label="Prompt", lines=3),
                    gr.Radio(
                        choices=_FLUX_MODES,
                        value="Distilled (4 steps)",
                        label="Mode",
                    ),
                    gr.Number(value=0, label="Seed", precision=0),
                    gr.Checkbox(value=True, label="Randomize seed"),
                    _size_slider("Width"),
                    _size_slider("Height"),
                ],
                outputs=gr.Image(label="Generated Image"),
                title="FLUX.2 Klein 9B",
            )

demo.launch(theme=gr.themes.Soft())