# Page header captured along with this file (not part of the program):
#   BiliSakura - commit 9dfc93c (verified):
#   "Add FlowMatchEuler/Heun schedulers to all swappable model dropdowns"
#   File size: 18.5 kB
"""Gradio demo for BiliSakura visual generative foundation models on Hugging Face ZeroGPU."""
from __future__ import annotations
try:
    import spaces
except ImportError:
    # Local development without the Spaces runtime: substitute an object whose
    # ``GPU`` attribute can be used as a decorator but does nothing.
    class _SpacesStub:
        """No-op replacement exposing only the ``GPU`` decorator."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Return the decorated function unchanged.

            Supports both the bare form (``@spaces.GPU``) and the
            parameterised form (``@spaces.GPU(size=..., duration=...)``);
            every keyword argument is accepted and ignored.
            """
            used_bare = bool(args) and callable(args[0])
            if used_bare:
                return args[0]
            # Parameterised use: hand back an identity decorator.
            return lambda fn: fn

    spaces = _SpacesStub()  # type: ignore[assignment]
import traceback
from typing import Any
import gradio as gr
import torch
from model_catalog import (
COLLECTIONS,
MODEL_LABELS,
get_profile_by_label,
parse_model_label,
scheduler_choices_for_profile,
uses_native_scheduler,
)
from model_loader import (
PIPELINE_MANAGER,
_to_float,
_to_int,
current_scheduler_name,
default_class_label_for_pipe,
run_inference,
scheduler_options_for_profile,
)
# Model selected when the UI is first built: the first entry of MODEL_LABELS,
# plus its resolved profile.
DEFAULT_MODEL = MODEL_LABELS[0]
DEFAULT_PROFILE = get_profile_by_label(DEFAULT_MODEL)
# Collections for which the "Advanced: guidance interval" accordion is shown.
INTERVAL_COLLECTIONS = {"iMF-diffusers", "NiT-diffusers", "PixelFlow-diffusers"}
# Collection for which the "Advanced: pMF args" accordion is shown.
PMF_COLLECTION = "pMF-diffusers"
def _model_info_markdown(profile, pipe=None) -> str:
extras = profile.extra_call_kwargs
extra_lines = ""
if extras:
extra_lines = "\n".join(f"- `{key}`: `{value}`" for key, value in extras.items())
extra_lines = f"\n\n**Default extra args**\n{extra_lines}"
scheduler_line = (
f"`{current_scheduler_name(pipe)}`"
if pipe is not None
else "`checkpoint` (load model to see)"
)
return (
f"**Hub repo:** [`{profile.hub_model_id}`]({profile.hub_model_url})\n\n"
f"- dtype: `{profile.dtype}`\n"
f"- default resolution: `{profile.default_height}x{profile.default_width}`\n"
f"- checkpoint scheduler: {scheduler_line}\n"
f"- GPU size: `{profile.gpu_size}`"
f"{extra_lines}"
)
def _interval_defaults(profile) -> tuple[float, float]:
extras = profile.extra_call_kwargs
if "guidance_interval_start" in extras:
return float(extras["guidance_interval_start"]), float(extras["guidance_interval_end"])
interval = extras.get("guidance_interval", (0.0, 0.7))
return float(interval[0]), float(interval[1])
def _build_extra_kwargs(
profile,
guidance_interval_start: float,
guidance_interval_end: float,
guidance_interval_min: float,
guidance_interval_max: float,
noise_scale: float,
) -> dict[str, Any]:
if profile.collection == "iMF-diffusers":
return {
"guidance_interval_start": guidance_interval_start,
"guidance_interval_end": guidance_interval_end,
}
if profile.collection in {"NiT-diffusers", "PixelFlow-diffusers"}:
return {"guidance_interval": (guidance_interval_start, guidance_interval_end)}
if profile.collection == PMF_COLLECTION:
return {
"guidance_interval_min": guidance_interval_min,
"guidance_interval_max": guidance_interval_max,
"noise_scale": noise_scale,
}
return dict(profile.extra_call_kwargs)
def _scheduler_config(profile, pipe=None):
    """Build the ``gr.update`` payload for the scheduler dropdown.

    Args:
        profile: Model profile used to look up scheduler options.
        pipe: Loaded pipeline, or ``None`` when nothing is loaded yet.

    Returns:
        A Gradio update. For a native-scheduler profile with no pipeline the
        dropdown is locked to the single ``"checkpoint"`` choice; otherwise it
        carries the choices/default from ``scheduler_options_for_profile`` and
        is interactive only when the profile does not use a native scheduler.
    """
    choices, default = scheduler_options_for_profile(profile, pipe)
    native = uses_native_scheduler(profile)
    model_loaded = pipe is not None

    if native and not model_loaded:
        return gr.update(
            choices=["checkpoint"],
            value="checkpoint",
            interactive=False,
            info="Uses the checkpoint scheduler (custom, not swappable until loaded)",
        )

    if model_loaded:
        hint = "Checkpoint scheduler by default; pick another built-in diffusers scheduler to swap"
    else:
        hint = "Defaults to checkpoint scheduler after Load model; optional built-in swap"

    return gr.update(
        choices=choices,
        value=default,
        interactive=not native,
        info=hint,
    )
def _config_from_profile(profile, pipe=None):
    """Produce the updates that reset every config widget for ``profile``.

    The tuple order matches the ``config_outputs`` list wired up in
    ``build_demo``: model info Markdown, class label, seed, steps, guidance
    scale, scheduler, height, width, interval start/end, pMF min/max, noise
    scale, and finally the two accordions' visibility.

    Args:
        profile: Model profile supplying the defaults.
        pipe: Loaded pipeline, forwarded to the info and scheduler helpers.

    Returns:
        A 15-element tuple: one Markdown string followed by ``gr.update``
        objects.
    """
    start, end = _interval_defaults(profile)
    defaults = profile.extra_call_kwargs
    interval_visible = profile.collection in INTERVAL_COLLECTIONS
    pmf_visible = profile.collection == PMF_COLLECTION

    def _value(new_value):
        return gr.update(value=new_value)

    # pMF-specific sliders fall back to fixed defaults when the profile
    # does not define them.
    pmf_fallbacks = (
        ("guidance_interval_min", 0.2),
        ("guidance_interval_max", 0.6),
        ("noise_scale", 4.0),
    )
    pmf_updates = [
        _value(float(defaults.get(key, fallback))) for key, fallback in pmf_fallbacks
    ]

    return (
        _model_info_markdown(profile, pipe),
        _value(profile.default_class_label),
        _value(profile.default_seed),
        gr.update(value=profile.default_steps, maximum=profile.max_steps),
        _value(profile.default_guidance),
        _scheduler_config(profile, pipe),
        _value(profile.default_height or profile.infer_resolution()),
        _value(profile.default_width or profile.infer_resolution()),
        _value(start),
        _value(end),
        *pmf_updates,
        gr.update(visible=interval_visible, open=interval_visible),
        gr.update(visible=pmf_visible, open=pmf_visible),
    )
def on_model_change(model_label: str):
    """Return the widget updates for the newly selected model label.

    No pipeline is passed along, so the updates describe the unloaded state.
    """
    selected_profile = get_profile_by_label(model_label)
    return _config_from_profile(selected_profile)
def _coerce_inference_inputs(
    class_label: Any,
    seed: Any,
    num_steps: Any,
    guidance_scale: Any,
    height: Any,
    width: Any,
    guidance_interval_start: Any,
    guidance_interval_end: Any,
    guidance_interval_min: Any,
    guidance_interval_max: Any,
    noise_scale: Any,
) -> tuple[str, int, int, float, int, int, float, float, float, float, float]:
    """Normalise raw UI values into the types used for inference.

    The class label becomes a stripped string (a falsy value yields ``""``).
    Every numeric field goes through ``_to_int`` / ``_to_float`` with the
    default written next to it below.

    Returns:
        The coerced values, in the same order as the parameters.
    """
    label_text = str(class_label or "").strip()
    seed_value = _to_int(seed, default=42)
    steps_value = _to_int(num_steps, default=50)
    guidance_value = _to_float(guidance_scale, default=4.0)
    height_value = _to_int(height, default=256)
    width_value = _to_int(width, default=256)
    interval_start = _to_float(guidance_interval_start, default=0.0)
    interval_end = _to_float(guidance_interval_end, default=0.7)
    interval_min = _to_float(guidance_interval_min, default=0.2)
    interval_max = _to_float(guidance_interval_max, default=0.6)
    noise_value = _to_float(noise_scale, default=4.0)
    return (
        label_text,
        seed_value,
        steps_value,
        guidance_value,
        height_value,
        width_value,
        interval_start,
        interval_end,
        interval_min,
        interval_max,
        noise_value,
    )
def _gpu_duration(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
) -> int:
    """Compute the ``duration`` value passed to ``spaces.GPU`` for generation.

    The signature mirrors ``_generate_on_gpu`` because this function is handed
    to that function's ``spaces.GPU`` decorator as ``duration``. Only
    ``model_label`` and ``num_steps`` affect the result.

    The estimate is ``steps * 0.6 + 30``, floored at 45 for profiles whose
    ``gpu_size`` is ``"large"`` and at 90 for all others, and capped at 300.
    When ``profile.steps_are_list`` is truthy the step count is treated as at
    least 40.

    Returns:
        The duration budget as an int.
    """
    profile = get_profile_by_label(model_label)
    steps = _to_int(num_steps, default=profile.default_steps)
    if profile.steps_are_list:
        steps = max(steps, 40)

    floor = 90 if profile.gpu_size != "large" else 45
    estimate = steps * 0.6 + 30
    return int(min(300, max(floor, estimate)))
def _load_model_core(model_label: str) -> tuple[str, str | None]:
    """Load the pipeline for ``model_label`` and move it to CUDA.

    Args:
        model_label: Catalog label, split into collection and variant by
            ``parse_model_label``.

    Returns:
        ``(message, suggested_label)``: the first element returned by
        ``PIPELINE_MANAGER.load`` and the class label suggested for the loaded
        pipeline (``None`` when no pipeline is available afterwards).
    """
    collection, variant = parse_model_label(model_label)
    status_message, _unused = PIPELINE_MANAGER.load(collection, variant)
    PIPELINE_MANAGER.move_to_cuda()

    loaded_pipe = PIPELINE_MANAGER.pipe
    profile = get_profile_by_label(model_label)
    if loaded_pipe is None:
        return status_message, None
    return status_message, default_class_label_for_pipe(loaded_pipe, profile)
@spaces.GPU(size="xlarge", duration=120)
def _load_on_gpu(model_label: str) -> tuple[str, str | None]:
    """Run ``_load_model_core`` under the ``spaces.GPU`` decorator.

    Returns:
        The ``(message, suggested_label)`` pair from ``_load_model_core``.
    """
    load_outcome = _load_model_core(model_label)
    return load_outcome
def load_model(model_label: str):
    """Handle the "Load model" button.

    Loads the model via ``_load_on_gpu`` and refreshes every config widget
    from the profile and the current ``PIPELINE_MANAGER.pipe``. When the
    loader suggests a class label, it replaces the profile's default in the
    class-label update (position 1 of the config tuple).

    Returns:
        ``(status_message, *config_updates)``.

    Raises:
        gr.Error: If loading fails for any reason; the original exception is
            chained as the cause.
    """
    try:
        message, suggested_label = _load_on_gpu(model_label)
    except Exception as exc:
        raise gr.Error(f"Failed to load `{model_label}`: {exc}") from exc

    profile = get_profile_by_label(model_label)
    updates = list(_config_from_profile(profile, PIPELINE_MANAGER.pipe))
    if suggested_label:
        updates[1] = gr.update(value=suggested_label)
    return (message, *updates)
@spaces.GPU(size="xlarge", duration=_gpu_duration)
def _generate_on_gpu(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
):
    """Run one generation under the ``spaces.GPU`` decorator.

    Steps: coerce the raw UI values, make sure the requested model is the one
    held by ``PIPELINE_MANAGER`` (loading it if not), move it to CUDA, build
    the collection-specific extra kwargs and call ``run_inference``.

    Returns:
        Whatever ``run_inference`` returns.

    Raises:
        gr.Error: If no pipeline is available after the load attempt.
    """
    coerced = _coerce_inference_inputs(
        class_label,
        seed,
        num_steps,
        guidance_scale,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    )
    (
        class_label,
        seed,
        num_steps,
        guidance_scale,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    ) = coerced

    profile = get_profile_by_label(model_label)
    collection, variant = parse_model_label(model_label)

    # Reload only when a different model (or none at all) is currently held.
    already_loaded = (
        PIPELINE_MANAGER.loaded_label == model_label
        and PIPELINE_MANAGER.pipe is not None
    )
    if not already_loaded:
        PIPELINE_MANAGER.load(collection, variant)
    PIPELINE_MANAGER.move_to_cuda()

    pipe = PIPELINE_MANAGER.pipe
    if pipe is None:
        raise gr.Error(f"Model `{model_label}` is not loaded.")

    call_options = {
        "class_label": class_label,
        "seed": seed,
        "num_steps": num_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "scheduler_name": scheduler,
        "extra_kwargs": _build_extra_kwargs(
            profile,
            guidance_interval_start,
            guidance_interval_end,
            guidance_interval_min,
            guidance_interval_max,
            noise_scale,
        ),
    }
    return run_inference(profile, pipe, **call_options)
def generate(
    model_label: str,
    class_label: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
    scheduler: str,
    height: int,
    width: int,
    guidance_interval_start: float,
    guidance_interval_end: float,
    guidance_interval_min: float,
    guidance_interval_max: float,
    noise_scale: float,
):
    """Handle the "Generate" button.

    Forwards every argument positionally to ``_generate_on_gpu``.

    Returns:
        ``(status_text, image)`` for the status box and the output image.

    Raises:
        gr.Error: On any failure; the message includes the exception text and
            a traceback limited to six frames, and the original exception is
            chained as the cause.
    """
    forwarded = (
        model_label,
        class_label,
        seed,
        num_steps,
        guidance_scale,
        scheduler,
        height,
        width,
        guidance_interval_start,
        guidance_interval_end,
        guidance_interval_min,
        guidance_interval_max,
        noise_scale,
    )
    try:
        image = _generate_on_gpu(*forwarded)
    except Exception as exc:
        detail = traceback.format_exc(limit=6).strip()
        raise gr.Error(f"Generation failed for `{model_label}`: {exc}\n\n{detail}") from exc

    # Prefer the label the manager reports as loaded; fall back to the request.
    shown_label = PIPELINE_MANAGER.loaded_label or model_label
    return f"Generated with `{shown_label}`.", image
def build_demo() -> gr.Blocks:
    """Construct the Gradio Blocks UI and wire up its events.

    Widgets are initialised from ``DEFAULT_PROFILE``. Four events are
    registered: model dropdown change and page load both call
    ``on_model_change``; the two buttons call ``load_model`` and ``generate``.

    Returns:
        The assembled ``gr.Blocks`` application.
    """
    interval_start_default, interval_end_default = _interval_defaults(DEFAULT_PROFILE)
    extras = DEFAULT_PROFILE.extra_call_kwargs
    default_scheduler_choices = ["checkpoint", *scheduler_choices_for_profile(DEFAULT_PROFILE)]
    default_scheduler_value = "checkpoint"

    # Initial values for the pMF-only sliders and accordion visibility.
    pmf_min_default = float(extras.get("guidance_interval_min", 0.2))
    pmf_max_default = float(extras.get("guidance_interval_max", 0.6))
    noise_scale_default = float(extras.get("noise_scale", 4.0))
    show_interval = DEFAULT_PROFILE.collection in INTERVAL_COLLECTIONS
    show_pmf = DEFAULT_PROFILE.collection == PMF_COLLECTION

    with gr.Blocks(title="BiliSakura Visual Generation Models") as demo:
        gr.Markdown(
            """
# BiliSakura Visual Generative Foundation Models
Class-conditional image generation for [`BiliSakura/*-diffusers`](https://huggingface.co/BiliSakura)
on Hugging Face **ZeroGPU**.
"""
        )
        with gr.Row(equal_height=False):
            # Left column: model selection, inference settings, actions.
            with gr.Column(scale=5):
                model = gr.Dropdown(
                    MODEL_LABELS,
                    value=DEFAULT_MODEL,
                    label="Model",
                    info="Select a checkpoint, then configure inference args below",
                )
                model_info = gr.Markdown(_model_info_markdown(DEFAULT_PROFILE))
                with gr.Accordion("Inference config", open=True):
                    class_label = gr.Textbox(
                        label="class_labels",
                        value=DEFAULT_PROFILE.default_class_label,
                        info="ImageNet class id (e.g. 207) or any synonym from id2label (e.g. golden retriever, tabby)",
                    )
                    with gr.Row():
                        seed = gr.Number(
                            label="seed",
                            value=DEFAULT_PROFILE.default_seed,
                            precision=0,
                        )
                        num_steps = gr.Slider(
                            label="num_inference_steps",
                            minimum=1,
                            maximum=DEFAULT_PROFILE.max_steps,
                            step=1,
                            value=DEFAULT_PROFILE.default_steps,
                        )
                    guidance_scale = gr.Slider(
                        label="guidance_scale",
                        minimum=0.0,
                        maximum=20.0,
                        step=0.1,
                        value=DEFAULT_PROFILE.default_guidance,
                    )
                    scheduler = gr.Dropdown(
                        label="scheduler",
                        choices=default_scheduler_choices or ["checkpoint"],
                        value=default_scheduler_value,
                        interactive=not uses_native_scheduler(DEFAULT_PROFILE),
                        info="Defaults to checkpoint scheduler after Load model",
                    )
                    with gr.Row():
                        height = gr.Slider(
                            label="height",
                            minimum=128,
                            maximum=1024,
                            step=16,
                            value=DEFAULT_PROFILE.default_height or 256,
                        )
                        width = gr.Slider(
                            label="width",
                            minimum=128,
                            maximum=1024,
                            step=16,
                            value=DEFAULT_PROFILE.default_width or 256,
                        )
                with gr.Accordion(
                    "Advanced: guidance interval",
                    open=False,
                    visible=show_interval,
                ) as interval_accordion:
                    with gr.Row():
                        guidance_interval_start = gr.Slider(
                            label="guidance_interval_start / [0]",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=interval_start_default,
                        )
                        guidance_interval_end = gr.Slider(
                            label="guidance_interval_end / [1]",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=interval_end_default,
                        )
                with gr.Accordion(
                    "Advanced: pMF args",
                    open=False,
                    visible=show_pmf,
                ) as pmf_accordion:
                    with gr.Row():
                        guidance_interval_min = gr.Slider(
                            label="guidance_interval_min",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=pmf_min_default,
                        )
                        guidance_interval_max = gr.Slider(
                            label="guidance_interval_max",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.05,
                            value=pmf_max_default,
                        )
                    noise_scale = gr.Slider(
                        label="noise_scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=noise_scale_default,
                    )
                with gr.Row():
                    load_btn = gr.Button("Load model", variant="secondary")
                    generate_btn = gr.Button("Generate", variant="primary")
                status = gr.Textbox(label="Status", interactive=False, lines=2)
                gr.Markdown(
                    f"**Catalog:** {len(MODEL_LABELS)} variants · {len(COLLECTIONS)} families"
                )
            # Right column: the generated image.
            with gr.Column(scale=6):
                output = gr.Image(
                    label="Generated image",
                    type="pil",
                    height=720,
                    elem_classes=["output-image"],
                )
        gr.Examples(
            examples=[
                [DEFAULT_MODEL, "golden retriever", 42],
                [DEFAULT_MODEL, "207", 0],
                [DEFAULT_MODEL, "tabby", 123],
                [DEFAULT_MODEL, "tabby, tabby cat", 456],
            ],
            inputs=[model, class_label, seed],
            label="Quick examples",
        )

        # Order must match the positional parameters of ``generate``.
        inference_inputs = [
            model,
            class_label,
            seed,
            num_steps,
            guidance_scale,
            scheduler,
            height,
            width,
            guidance_interval_start,
            guidance_interval_end,
            guidance_interval_min,
            guidance_interval_max,
            noise_scale,
        ]
        # Order must match the tuple returned by ``_config_from_profile``.
        config_outputs = [
            model_info,
            class_label,
            seed,
            num_steps,
            guidance_scale,
            scheduler,
            height,
            width,
            guidance_interval_start,
            guidance_interval_end,
            guidance_interval_min,
            guidance_interval_max,
            noise_scale,
            interval_accordion,
            pmf_accordion,
        ]

        model.change(on_model_change, inputs=model, outputs=config_outputs)
        load_btn.click(load_model, inputs=model, outputs=[status, *config_outputs])
        generate_btn.click(generate, inputs=inference_inputs, outputs=[status, output])
        demo.load(on_model_change, inputs=model, outputs=config_outputs)

    return demo
# Built at import time so the app object exists as a module-level name.
demo = build_demo()

if __name__ == "__main__":
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        print("CUDA is not available locally; ZeroGPU Spaces will provide GPU at inference time.")
    queued_app = demo.queue(max_size=8)
    queued_app.launch()