# Bernini / app.py
# chilupku: "Upload app.py with huggingface_hub"
# commit 279491f (verified)
# Copyright (c) 2026 Bytedance Ltd. and/or its affiliate
# Licensed under the Apache License, Version 2.0
"""Bernini Renderer Gradio demo — HuggingFace Spaces edition."""
import os
import tempfile
from datetime import datetime
import gradio as gr
import spaces
import torch
from bernini.pipeline import BerniniRendererPipeline
from bernini.cli import DEFAULT_NEG_PROMPT, GUIDANCE_MODES
from bernini.prompt_enhancer import PromptEnhancer, get_system_prompt_for_task
# Hub repository id handed to the pipeline loader.
HF_MODEL_ID = "ByteDance/Bernini-R-Diffusers"

# Scratch directory (fresh for every process) that receives generated files.
SAVE_BASE = tempfile.mkdtemp(prefix="bernini_gradio_")
os.makedirs(SAVE_BASE, exist_ok=True)

# Prompt Enhancement — configured via HF Secrets, not exposed to users.
# Each value falls back to "" when its environment variable is unset.
_PE_API_KEY = os.getenv("BERNINI_PE_API_KEY", "")
_PE_BASE_URL = os.getenv("BERNINI_PE_BASE_URL", "")
_PE_MODEL = os.getenv("BERNINI_PE_MODEL", "")

# Task identifiers offered in the task dropdown, in display order.
TASK_TYPE_CHOICES = [
    "t2i",
    "t2v",
    "i2i",
    "v2v",
    "mv2v",
    "r2v",
    "rv2v",
    "ads2v",
]

# Guidance mode used for a task when the user has not picked one.
GUIDANCE_MODE_BY_TASK = dict(
    t2i="t2v_apg",
    t2v="t2v_apg",
    i2i="v2v",
    v2v="v2v_apg",
    mv2v="v2v_apg",
    r2v="r2v_apg",
    rv2v="rv2v",
    ads2v="v2v_apg",
)

# Which uploads each task consumes:
#   video      -> the video upload is read
#   image_role -> how the single-image upload is used: "none", "source" or "reference"
#   images     -> the multi-image gallery is read
TASK_INPUTS = {
    task: {"video": uses_video, "image_role": image_role, "images": uses_gallery}
    for task, uses_video, image_role, uses_gallery in (
        ("t2i", False, "none", False),
        ("t2v", False, "none", False),
        ("i2i", False, "source", False),
        ("v2v", True, "none", False),
        ("mv2v", True, "none", False),
        ("r2v", False, "reference", True),
        ("rv2v", True, "reference", True),
        ("ads2v", True, "reference", True),
    )
}

# Tasks whose result is a single image (saved as .png, rendered with one frame).
IMAGE_TASKS = {"t2i", "i2i"}
# Process-wide pipeline instance; remains None until get_pipeline() first runs.
PIPELINE = None


def get_pipeline():
    """Return the shared pipeline, constructing it on the first call.

    The instance is stored in the module-level ``PIPELINE`` so later calls
    return the same object without loading again.
    """
    global PIPELINE
    if PIPELINE is not None:
        return PIPELINE

    print(f"Loading pipeline from {HF_MODEL_ID} ...")
    load_options = dict(
        device=torch.device("cuda"),
        load_ckpt_weights=False,
        use_unipc=True,
        use_src_id_rotary_emb=True,
    )
    PIPELINE = BerniniRendererPipeline.from_pretrained(HF_MODEL_ID, **load_options)
    print("Pipeline loaded.")
    return PIPELINE
def _coerce_video_paths(video_input):
    """Normalise the video upload value into a list of file paths.

    Args:
        video_input: ``None``, a single path string, a single upload payload
            (a dict carrying ``"path"`` or an object exposing ``.name``), or a
            list/tuple mixing any of those. ``None`` entries are ignored.

    Returns:
        A non-empty ``list`` of paths, or ``None`` when nothing usable was
        supplied.
    """
    if not video_input:
        return None
    if isinstance(video_input, str):
        return [video_input]
    # A lone payload (dict or file-like) is handled as a one-item batch instead
    # of being silently discarded; tuples are accepted alongside lists.
    if not isinstance(video_input, (list, tuple)):
        video_input = [video_input]

    paths = []
    for entry in video_input:
        if isinstance(entry, str):
            path = entry
        elif isinstance(entry, dict):
            path = entry.get("path")
        else:
            path = getattr(entry, "name", None)
        # Skip None entries and empty/missing paths: they cannot be opened later.
        if path:
            paths.append(path)
    return paths or None
def _coerce_gallery_paths(gallery_input):
    """Normalise the gallery value into a list of image file paths.

    Args:
        gallery_input: ``None`` or an iterable of gallery entries. An entry
            may be a path string, a dict carrying ``"path"``, an object
            exposing ``.name``, or a non-empty list/tuple whose first element
            is one of those (the remaining elements are ignored). A bare
            string or dict is treated as a single entry.

    Returns:
        A non-empty ``list`` of paths, or ``None`` when nothing usable was
        supplied.
    """
    if not gallery_input:
        return None
    # Iterating a bare str yields characters and a bare dict yields its keys,
    # which would produce bogus "paths" — wrap them as a single entry instead.
    if isinstance(gallery_input, (str, dict)):
        gallery_input = [gallery_input]

    paths = []
    for entry in gallery_input:
        # Unwrap sequence entries: only the first element carries the image.
        if isinstance(entry, (list, tuple)) and entry:
            entry = entry[0]
        if isinstance(entry, str):
            path = entry
        elif isinstance(entry, dict):
            path = entry.get("path")
        else:
            path = getattr(entry, "name", None)
        # Drop entries without a usable path rather than forwarding ""/None.
        if path:
            paths.append(path)
    return paths or None
def _output_path(task_type):
    """Build a timestamped output file path inside ``SAVE_BASE``.

    The file name is ``<task_type>_<timestamp>`` with a ``.png`` suffix for
    tasks listed in ``IMAGE_TASKS`` and ``.mp4`` for every other task. The
    timestamp includes microseconds.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S_%f}"
    if task_type in IMAGE_TASKS:
        suffix = "png"
    else:
        suffix = "mp4"
    filename = f"{task_type}_{stamp}.{suffix}"
    return os.path.join(SAVE_BASE, filename)
def _build_kwargs(
    prompt, task_type, video_input, image_input, gallery_input, guidance_mode,
    max_image_size, num_inference_steps, num_frames, flow_shift, seed, fps,
    height, width, omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
):
    """Translate raw UI values into the keyword arguments for the pipeline.

    Only the uploads that ``TASK_INPUTS[task_type]`` marks as used are
    forwarded; the rest are passed as ``None``. Numeric values are cast to
    ``int``/``float``. For tasks in ``IMAGE_TASKS`` the frame count is
    forced to 1. An empty ``guidance_mode`` falls back to the task's entry in
    ``GUIDANCE_MODE_BY_TASK``.

    Returns:
        dict: keyword arguments, without ``output_path``.
    """
    spec = TASK_INPUTS[task_type]

    source_video = None
    if spec["video"]:
        source_video = _coerce_video_paths(video_input)

    reference_images = None
    if spec["images"]:
        reference_images = _coerce_gallery_paths(gallery_input)

    source_image = None
    role = spec["image_role"]
    if role == "source":
        source_image = image_input or None
    elif role == "reference" and image_input:
        # The single-image upload goes first, ahead of any gallery images.
        reference_images = [image_input, *(reference_images or [])]

    frame_count = 1 if task_type in IMAGE_TASKS else num_frames

    return {
        "prompt": prompt or "",
        "neg_prompt": DEFAULT_NEG_PROMPT,
        "video": source_video,
        "image": source_image,
        "images": reference_images,
        "max_image_size": int(max_image_size),
        "num_inference_steps": int(num_inference_steps),
        "num_frames": int(frame_count),
        "flow_shift": float(flow_shift),
        "seed": int(seed),
        "fps": int(fps),
        "height": int(height),
        "width": int(width),
        "guidance_mode": guidance_mode or GUIDANCE_MODE_BY_TASK[task_type],
        "omega_V": float(omega_V),
        "omega_I": float(omega_I),
        "omega_TI": float(omega_TI),
        "omega_scale": float(omega_scale),
        "eta": float(eta),
        "momentum": float(momentum),
        "system_prompt": get_system_prompt_for_task(task_type),
    }
@spaces.GPU(duration=1200)
def generate_handler(
    prompt, task_type, video_input, image_input, gallery_input,
    guidance_mode, max_image_size, num_inference_steps, num_frames,
    flow_shift, seed, fps, height, width,
    omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
    progress=gr.Progress(),
):
    """Handle a click on "Generate": validate, optionally enhance, render.

    Steps: check that a task and a non-blank prompt were supplied, build the
    pipeline keyword arguments, rewrite the prompt through ``PromptEnhancer``
    when ``_PE_API_KEY`` is set (any failure there falls back to the original
    prompt), then run the pipeline with ``write_output=True``.

    Returns:
        tuple: ``(video_path, image_path, prompt_used, status_message)``.
        At most one of the two paths is set; both are ``None`` on failure.
        ``progress`` is not used inside this function body.
    """
    # NOTE(review): validation and prompt enhancement run inside the
    # @spaces.GPU-decorated call; presumably the GPU slot is already reserved
    # while they execute — confirm whether they should move outside it.
    import traceback  # local: only needed on the failure path

    if not task_type:
        gr.Warning("Please select a task type first!")
        return None, None, "", "Please select a task type first!"
    if not (prompt or "").strip():
        gr.Warning("Please enter a prompt!")
        return None, None, "", "Please enter a prompt!"

    try:
        kwargs = _build_kwargs(
            prompt, task_type, video_input, image_input, gallery_input,
            guidance_mode, max_image_size, num_inference_steps, num_frames,
            flow_shift, seed, fps, height, width,
            omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
        )
    except (KeyError, TypeError, ValueError) as e:
        # int()/float() reject None or non-numeric values and an unknown task
        # has no TASK_INPUTS entry; report that instead of crashing the handler.
        message = f"Invalid input: {e}"
        gr.Warning(message)
        return None, None, "", message

    # Prompt enhancement via server-side key (not exposed to users)
    if _PE_API_KEY:
        try:
            rewriter = PromptEnhancer(
                api_key=_PE_API_KEY,
                base_url=_PE_BASE_URL or None,
                model=_PE_MODEL or None,
            )
            enhanced = rewriter(
                task_type,
                kwargs["prompt"],
                video=kwargs.get("video"),
                image=kwargs.get("image"),
                images=kwargs.get("images"),
            )
            if enhanced:
                kwargs["prompt"] = enhanced
        except Exception as e:
            # Best effort by design: generation continues with the user's prompt.
            gr.Warning(f"Prompt enhancement failed: {e}. Using original prompt.")

    kwargs["output_path"] = _output_path(task_type)
    try:
        # Loading is inside the try so a failed model load is reported through
        # the status box the same way as a failed generation.
        pipeline = get_pipeline()
        output_path = pipeline(write_output=True, **kwargs)
    except Exception as e:
        # Keep the full traceback in the server log; the UI only gets a summary.
        traceback.print_exc()
        return None, None, kwargs["prompt"], f"Generation failed: {e}"

    if not output_path:
        # Avoid reporting "Done: None" when the pipeline returned no file.
        return (
            None,
            None,
            kwargs["prompt"],
            "Generation finished but no output file was produced.",
        )

    out_video = out_image = None
    if output_path.endswith(".png") or task_type in IMAGE_TASKS:
        out_image = output_path
    else:
        out_video = output_path
    return out_video, out_image, kwargs["prompt"], f"Done: {output_path}"
def _on_task_change(task_type):
    """React to a new task selection.

    Returns:
        tuple: an update that sets the guidance-mode value to the task's
        entry in ``GUIDANCE_MODE_BY_TASK`` (``None`` when no task is given),
        and a hint string naming the uploads the task reads. The hint gets a
        note appended for tasks in ``IMAGE_TASKS``.
    """
    suggested_mode = GUIDANCE_MODE_BY_TASK.get(task_type) if task_type else None
    spec = TASK_INPUTS.get(task_type, {})
    role = spec.get("image_role")

    wanted = [
        label
        for label, required in (
            ("source video", spec.get("video")),
            ("single source image", role == "source"),
            ("reference image(s)", role == "reference" or spec.get("images")),
        )
        if required
    ]

    if wanted:
        hint = "inputs: " + ", ".join(wanted)
    else:
        hint = "text-only"
    if task_type in IMAGE_TASKS:
        hint += " | forced num_frames=1"
    return gr.update(value=suggested_mode), hint
# Gradio UI: a two-column page with inputs and parameters on the left and
# results on the right, followed by the event wiring.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped — confirm group/tab/accordion boundaries
# against the deployed layout.
with gr.Blocks(title="Bernini Renderer Demo") as demo:
    gr.Markdown("# 🎬 Bernini Renderer Demo")
    gr.Markdown(
        "Unified video generation & editing — text-to-image, text-to-video, "
        "image editing, video editing, reference-to-video, and more.\n\n"
        "**Paper**: [arXiv 2605.22344](https://arxiv.org/abs/2605.22344) | "
        "**Model**: [ByteDance/Bernini-R](https://huggingface.co/ByteDance/Bernini-R)"
    )
    with gr.Row():
        # Left column: prompt, uploads, task selection and all parameters.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Input")
                prompt = gr.Textbox(label="Prompt", lines=3,
                    placeholder="Describe the scene or the editing instruction...")
                # One tab per upload kind; _build_kwargs reads only the
                # uploads that TASK_INPUTS marks as used by the chosen task.
                with gr.Tabs():
                    with gr.TabItem("Video"):
                        video_input = gr.File(label="Upload video(s)",
                            file_count="multiple", file_types=["video"], type="filepath")
                    with gr.TabItem("Single image"):
                        image_input = gr.Image(
                            label="Upload an image (source for i2i, or a single reference)",
                            type="filepath")
                    with gr.TabItem("Multiple images"):
                        gallery_input = gr.Gallery(label="Upload reference images (r2v / rv2v)",
                            columns=4, height="auto", interactive=True)
            with gr.Group():
                gr.Markdown("### Task")
                # Starts unselected; generate_handler rejects an empty task.
                task_type = gr.Dropdown(choices=TASK_TYPE_CHOICES, value=None,
                    label="Task type (required)", info="Auto-fills guidance_mode below")
                guidance_mode = gr.Dropdown(choices=GUIDANCE_MODES, value=None, label="Guidance mode")
                # Filled by _on_task_change with the inputs the task uses.
                input_hint = gr.Markdown("")
            with gr.Group():
                gr.Markdown("### Basic parameters")
                with gr.Row():
                    max_image_size = gr.Slider(256, 1280, value=848, step=16, label="Max image size")
                    # _build_kwargs overrides this with 1 for tasks in IMAGE_TASKS.
                    num_frames = gr.Slider(1, 121, value=49, step=4, label="Num frames")
                with gr.Row():
                    num_inference_steps = gr.Slider(10, 50, value=40, step=5, label="Inference steps")
                    flow_shift = gr.Slider(0.0, 12.0, value=5.0, step=0.5, label="Flow shift")
                with gr.Row():
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    fps = gr.Slider(1, 30, value=16, step=1, label="FPS")
                with gr.Row():
                    height = gr.Number(value=480, precision=0, label="Height")
                    width = gr.Number(value=848, precision=0, label="Width")
            # Guidance weights, collapsed by default; forwarded to the
            # pipeline as floats by _build_kwargs.
            with gr.Accordion("Guidance (advanced)", open=False):
                with gr.Row():
                    omega_V = gr.Slider(0.0, 10.0, value=1.25, step=0.05, label="omega_V")
                    omega_I = gr.Slider(0.0, 10.0, value=4.5, step=0.05, label="omega_I")
                    omega_TI = gr.Slider(0.0, 10.0, value=4.0, step=0.05, label="omega_TI")
                with gr.Row():
                    omega_scale = gr.Slider(0.0, 2.0, value=0.8, step=0.05, label="omega_scale")
                    eta = gr.Slider(0.0, 2.0, value=0.5, step=0.05, label="eta")
                    momentum = gr.Slider(-2.0, 2.0, value=0.0, step=0.05, label="momentum")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
        # Right column: results. generate_handler returns at most one of the
        # video/image paths, plus the prompt actually used and a status line.
        with gr.Column(scale=1):
            gr.Markdown("### Output")
            output_video = gr.Video(label="Generated video")
            output_image = gr.Image(label="Generated image")
            final_prompt = gr.Textbox(label="Prompt used", interactive=False, lines=3)
            output_status = gr.Textbox(label="Status", interactive=False, lines=2)
    # Selecting a task updates the guidance-mode dropdown and the input hint.
    task_type.change(fn=_on_task_change, inputs=task_type, outputs=[guidance_mode, input_hint])
    # The inputs list is in the same order as generate_handler's parameters,
    # and outputs in the same order as its 4-tuple return value; keep them in
    # sync when either side changes.
    generate_btn.click(
        fn=generate_handler,
        inputs=[
            prompt, task_type, video_input, image_input, gallery_input,
            guidance_mode, max_image_size, num_inference_steps, num_frames,
            flow_shift, seed, fps, height, width,
            omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
        ],
        outputs=[output_video, output_image, final_prompt, output_status],
    )
if __name__ == "__main__":
    # Script entry point: configure the request queue (max_size=5,
    # default_concurrency_limit=1), then launch the object queue() returns.
    queued_demo = demo.queue(max_size=5, default_concurrency_limit=1)
    queued_demo.launch()