Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2026 Bytedance Ltd. and/or its affiliate | |
| # Licensed under the Apache License, Version 2.0 | |
| """Bernini Renderer Gradio demo — HuggingFace Spaces edition.""" | |
| import os | |
| import tempfile | |
| from datetime import datetime | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from bernini.pipeline import BerniniRendererPipeline | |
| from bernini.cli import DEFAULT_NEG_PROMPT, GUIDANCE_MODES | |
| from bernini.prompt_enhancer import PromptEnhancer, get_system_prompt_for_task | |
# Hub repository id handed to BerniniRendererPipeline.from_pretrained().
HF_MODEL_ID = "ByteDance/Bernini-R-Diffusers"

# Scratch directory (unique per process) that receives every generated file.
SAVE_BASE = tempfile.mkdtemp(prefix="bernini_gradio_")
os.makedirs(SAVE_BASE, exist_ok=True)

# Prompt Enhancement — configured via HF Secrets, not exposed to users.
# An empty key disables enhancement; empty URL/model are passed on as None.
_PE_API_KEY = os.getenv("BERNINI_PE_API_KEY", "")
_PE_BASE_URL = os.getenv("BERNINI_PE_BASE_URL", "")
_PE_MODEL = os.getenv("BERNINI_PE_MODEL", "")

# Task identifiers offered in the "Task type" dropdown, in display order.
TASK_TYPE_CHOICES = [
    "t2i",
    "t2v",
    "i2i",
    "v2v",
    "mv2v",
    "r2v",
    "rv2v",
    "ads2v",
]

# Guidance mode auto-selected for each task when the user does not pick one.
GUIDANCE_MODE_BY_TASK = dict(
    t2i="t2v_apg",
    t2v="t2v_apg",
    i2i="v2v",
    v2v="v2v_apg",
    mv2v="v2v_apg",
    r2v="r2v_apg",
    rv2v="rv2v",
    ads2v="v2v_apg",
)

# Which uploads each task consumes:
#   video      -> the uploaded video(s) are forwarded
#   image_role -> how the single-image upload is used ("none"/"source"/"reference")
#   images     -> the multi-image gallery is forwarded as references
TASK_INPUTS = dict(
    t2i=dict(video=False, image_role="none", images=False),
    t2v=dict(video=False, image_role="none", images=False),
    i2i=dict(video=False, image_role="source", images=False),
    v2v=dict(video=True, image_role="none", images=False),
    mv2v=dict(video=True, image_role="none", images=False),
    r2v=dict(video=False, image_role="reference", images=True),
    rv2v=dict(video=True, image_role="reference", images=True),
    ads2v=dict(video=True, image_role="reference", images=True),
)

# Tasks that produce a still image: output is a .png and num_frames is forced to 1.
IMAGE_TASKS = {"t2i", "i2i"}

# Lazily created pipeline singleton; populated by get_pipeline().
PIPELINE = None
def get_pipeline():
    """Return the shared ``BerniniRendererPipeline``, building it on first use.

    The instance is cached in the module-level ``PIPELINE`` global, so
    ``from_pretrained`` is invoked at most once per process; later calls
    return the cached object.
    """
    global PIPELINE

    # Fast path: already constructed by an earlier call.
    if PIPELINE is not None:
        return PIPELINE

    print(f"Loading pipeline from {HF_MODEL_ID} ...")
    load_options = {
        "device": torch.device("cuda"),
        "load_ckpt_weights": False,
        "use_unipc": True,
        "use_src_id_rotary_emb": True,
    }
    PIPELINE = BerniniRendererPipeline.from_pretrained(HF_MODEL_ID, **load_options)
    print("Pipeline loaded.")
    return PIPELINE
| def _coerce_video_paths(video_input): | |
| if not video_input: | |
| return None | |
| if isinstance(video_input, str): | |
| return [video_input] | |
| if isinstance(video_input, list): | |
| out = [] | |
| for v in video_input: | |
| if v is None: | |
| continue | |
| if isinstance(v, str): | |
| out.append(v) | |
| elif hasattr(v, "name"): | |
| out.append(v.name) | |
| elif isinstance(v, dict) and v.get("path"): | |
| out.append(v["path"]) | |
| return out or None | |
| return None | |
| def _coerce_gallery_paths(gallery_input): | |
| if not gallery_input: | |
| return None | |
| out = [] | |
| for item in gallery_input: | |
| if isinstance(item, (list, tuple)) and item: | |
| item = item[0] | |
| if isinstance(item, str): | |
| out.append(item) | |
| elif isinstance(item, dict) and item.get("path"): | |
| out.append(item["path"]) | |
| elif hasattr(item, "name"): | |
| out.append(item.name) | |
| return out or None | |
def _output_path(task_type):
    """Build a timestamped output file path inside ``SAVE_BASE``.

    The file name is ``<task_type>_<timestamp>.<ext>`` where the timestamp
    includes microseconds and the extension is ``png`` for tasks listed in
    ``IMAGE_TASKS`` and ``mp4`` for every other task.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    if task_type in IMAGE_TASKS:
        suffix = "png"
    else:
        suffix = "mp4"
    filename = f"{task_type}_{stamp}.{suffix}"
    return os.path.join(SAVE_BASE, filename)
def _build_kwargs(
    prompt, task_type, video_input, image_input, gallery_input, guidance_mode,
    max_image_size, num_inference_steps, num_frames, flow_shift, seed, fps,
    height, width, omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
):
    """Translate raw UI values into the keyword arguments for the pipeline.

    Uploads are kept only when ``TASK_INPUTS[task_type]`` says the task uses
    them. Numeric values are cast to ``int``/``float``, and image tasks are
    pinned to a single frame.

    Returns:
        A dict of keyword arguments. ``output_path`` is not included.

    Raises:
        KeyError: if ``task_type`` is not a key of ``TASK_INPUTS``.
        TypeError, ValueError: if a numeric value cannot be cast.
    """
    spec = TASK_INPUTS[task_type]
    role = spec["image_role"]

    # Drop uploads the selected task does not consume.
    video = None
    if spec["video"]:
        video = _coerce_video_paths(video_input)

    images = None
    if spec["images"]:
        images = _coerce_gallery_paths(gallery_input)

    # The single-image upload is either the edit source or is placed in
    # front of the gallery references, depending on the task.
    image = None
    if role == "source":
        image = image_input or None
    elif role == "reference" and image_input:
        images = [image_input, *(images or [])]

    if task_type in IMAGE_TASKS:
        num_frames = 1

    return {
        "prompt": prompt or "",
        "neg_prompt": DEFAULT_NEG_PROMPT,
        "video": video,
        "image": image,
        "images": images,
        "max_image_size": int(max_image_size),
        "num_inference_steps": int(num_inference_steps),
        "num_frames": int(num_frames),
        "flow_shift": float(flow_shift),
        "seed": int(seed),
        "fps": int(fps),
        "height": int(height),
        "width": int(width),
        # Fall back to the per-task default when no mode was chosen.
        "guidance_mode": guidance_mode or GUIDANCE_MODE_BY_TASK[task_type],
        "omega_V": float(omega_V),
        "omega_I": float(omega_I),
        "omega_TI": float(omega_TI),
        "omega_scale": float(omega_scale),
        "eta": float(eta),
        "momentum": float(momentum),
        "system_prompt": get_system_prompt_for_task(task_type),
    }
# NOTE(review): `spaces` is imported by this file, but no `@spaces.GPU`
# decorator is visible on this handler — confirm that GPU allocation on
# ZeroGPU hardware is handled elsewhere (e.g. inside the pipeline).
def generate_handler(
    prompt, task_type, video_input, image_input, gallery_input,
    guidance_mode, max_image_size, num_inference_steps, num_frames,
    flow_shift, seed, fps, height, width,
    omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
    progress=gr.Progress(),
):
    """Run one generation request coming from the "Generate" button.

    Steps: validate the task and prompt, build the pipeline kwargs,
    optionally rewrite the prompt with ``PromptEnhancer`` (only when
    ``_PE_API_KEY`` is set), run the pipeline, and route the result to the
    video or image output.

    Returns:
        A 4-tuple ``(video_path, image_path, prompt_used, status_message)``.
        Exactly one of the first two is set on success; both are ``None``
        on any failure, in which case the status message explains why.
    """
    import traceback  # local import: only used on the failure path

    if not task_type:
        gr.Warning("Please select a task type first!")
        return None, None, "", "Please select a task type first!"
    if not (prompt or "").strip():
        gr.Warning("Please enter a prompt!")
        return None, None, "", "Please enter a prompt!"

    try:
        kwargs = _build_kwargs(
            prompt, task_type, video_input, image_input, gallery_input,
            guidance_mode, max_image_size, num_inference_steps, num_frames,
            flow_shift, seed, fps, height, width,
            omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
        )
    except (TypeError, ValueError) as e:
        # An emptied numeric field reaches us as None (or a non-numeric
        # value) and fails the int()/float() casts; report it instead of
        # surfacing an unhandled error.
        gr.Warning("Please fill in all numeric parameters!")
        return None, None, "", f"Invalid parameters: {e}"

    # Prompt enhancement via server-side key (not exposed to users).
    # Best-effort: any failure falls back to the user's original prompt.
    if _PE_API_KEY:
        try:
            rewriter = PromptEnhancer(
                api_key=_PE_API_KEY,
                base_url=_PE_BASE_URL or None,
                model=_PE_MODEL or None,
            )
            enhanced = rewriter(
                task_type,
                kwargs["prompt"],
                video=kwargs.get("video"),
                image=kwargs.get("image"),
                images=kwargs.get("images"),
            )
            if enhanced:
                kwargs["prompt"] = enhanced
        except Exception as e:
            gr.Warning(f"Prompt enhancement failed: {e}. Using original prompt.")

    kwargs["output_path"] = _output_path(task_type)

    try:
        # Loading the model can fail too (download, CUDA, memory), so it is
        # inside the same guard as the generation call.
        pipeline = get_pipeline()
        output_path = pipeline(write_output=True, **kwargs)
    except Exception as e:
        # Keep the full traceback in the server log; the UI only gets the
        # short message.
        traceback.print_exc()
        return None, None, kwargs["prompt"], f"Generation failed: {e}"

    if not output_path:
        # Avoid reporting "Done: None" when the pipeline returned nothing.
        return None, None, kwargs["prompt"], "Generation finished but produced no output file."

    out_video = out_image = None
    # str() so a path-like return value is also accepted.
    if task_type in IMAGE_TASKS or str(output_path).endswith(".png"):
        out_image = output_path
    else:
        out_video = output_path
    return out_video, out_image, kwargs["prompt"], f"Done: {output_path}"
def _on_task_change(task_type):
    """React to a new task selection in the dropdown.

    Returns:
        A 2-tuple: a ``gr.update`` setting the guidance-mode dropdown to the
        task's default from ``GUIDANCE_MODE_BY_TASK`` (``None`` when no task
        is selected), and a short hint string listing the uploads the task
        uses.
    """
    if task_type:
        default_mode = GUIDANCE_MODE_BY_TASK.get(task_type)
    else:
        default_mode = None

    spec = TASK_INPUTS.get(task_type, {})
    role = spec.get("image_role")

    # Collect a human-readable label for every upload the task consumes.
    labels = []
    if spec.get("video"):
        labels.append("source video")
    if role == "source":
        labels.append("single source image")
    if role == "reference" or spec.get("images"):
        labels.append("reference image(s)")

    if labels:
        hint = "inputs: " + ", ".join(labels)
    else:
        hint = "text-only"
    if task_type in IMAGE_TASKS:
        hint += " | forced num_frames=1"

    return gr.update(value=default_mode), hint
# UI definition: inputs and parameters in the left column, results in the
# right column, event wiring at the bottom.
with gr.Blocks(title="Bernini Renderer Demo") as demo:
    gr.Markdown("# 🎬 Bernini Renderer Demo")
    gr.Markdown(
        "Unified video generation & editing — text-to-image, text-to-video, "
        "image editing, video editing, reference-to-video, and more.\n\n"
        "**Paper**: [arXiv 2605.22344](https://arxiv.org/abs/2605.22344) | "
        "**Model**: [ByteDance/Bernini-R](https://huggingface.co/ByteDance/Bernini-R)"
    )

    with gr.Row():
        # ---- Left column: prompt, uploads, task and parameters ----
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Input")
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Describe the scene or the editing instruction...",
                )
                # One tab per upload kind; TASK_INPUTS decides which of
                # them the selected task actually uses.
                with gr.Tabs():
                    with gr.TabItem("Video"):
                        video_input = gr.File(
                            label="Upload video(s)",
                            file_count="multiple",
                            file_types=["video"],
                            type="filepath",
                        )
                    with gr.TabItem("Single image"):
                        image_input = gr.Image(
                            label="Upload an image (source for i2i, or a single reference)",
                            type="filepath",
                        )
                    with gr.TabItem("Multiple images"):
                        gallery_input = gr.Gallery(
                            label="Upload reference images (r2v / rv2v)",
                            columns=4,
                            height="auto",
                            interactive=True,
                        )

            with gr.Group():
                gr.Markdown("### Task")
                task_type = gr.Dropdown(
                    choices=TASK_TYPE_CHOICES,
                    value=None,
                    label="Task type (required)",
                    info="Auto-fills guidance_mode below",
                )
                guidance_mode = gr.Dropdown(
                    choices=GUIDANCE_MODES,
                    value=None,
                    label="Guidance mode",
                )
                # Filled by _on_task_change with the uploads the task uses.
                input_hint = gr.Markdown("")

            with gr.Group():
                gr.Markdown("### Basic parameters")
                with gr.Row():
                    max_image_size = gr.Slider(
                        256, 1280, value=848, step=16, label="Max image size"
                    )
                    num_frames = gr.Slider(
                        1, 121, value=49, step=4, label="Num frames"
                    )
                with gr.Row():
                    num_inference_steps = gr.Slider(
                        10, 50, value=40, step=5, label="Inference steps"
                    )
                    flow_shift = gr.Slider(
                        0.0, 12.0, value=5.0, step=0.5, label="Flow shift"
                    )
                with gr.Row():
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    fps = gr.Slider(1, 30, value=16, step=1, label="FPS")
                with gr.Row():
                    height = gr.Number(value=480, precision=0, label="Height")
                    width = gr.Number(value=848, precision=0, label="Width")

            with gr.Accordion("Guidance (advanced)", open=False):
                with gr.Row():
                    omega_V = gr.Slider(
                        0.0, 10.0, value=1.25, step=0.05, label="omega_V"
                    )
                    omega_I = gr.Slider(
                        0.0, 10.0, value=4.5, step=0.05, label="omega_I"
                    )
                    omega_TI = gr.Slider(
                        0.0, 10.0, value=4.0, step=0.05, label="omega_TI"
                    )
                with gr.Row():
                    omega_scale = gr.Slider(
                        0.0, 2.0, value=0.8, step=0.05, label="omega_scale"
                    )
                    eta = gr.Slider(0.0, 2.0, value=0.5, step=0.05, label="eta")
                    momentum = gr.Slider(
                        -2.0, 2.0, value=0.0, step=0.05, label="momentum"
                    )

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        # ---- Right column: results ----
        with gr.Column(scale=1):
            gr.Markdown("### Output")
            output_video = gr.Video(label="Generated video")
            output_image = gr.Image(label="Generated image")
            final_prompt = gr.Textbox(label="Prompt used", interactive=False, lines=3)
            output_status = gr.Textbox(label="Status", interactive=False, lines=2)

    # ---- Event wiring ----
    # Selecting a task pre-fills the guidance mode and refreshes the hint.
    task_type.change(
        fn=_on_task_change,
        inputs=task_type,
        outputs=[guidance_mode, input_hint],
    )

    # The input order must match generate_handler's positional parameters.
    generate_btn.click(
        fn=generate_handler,
        inputs=[
            prompt, task_type, video_input, image_input, gallery_input,
            guidance_mode, max_image_size, num_inference_steps, num_frames,
            flow_shift, seed, fps, height, width,
            omega_V, omega_I, omega_TI, omega_scale, eta, momentum,
        ],
        outputs=[output_video, output_image, final_prompt, output_status],
    )
if __name__ == "__main__":
    # Enable the request queue (max_size=5, default_concurrency_limit=1)
    # and then start the app.
    queued_demo = demo.queue(max_size=5, default_concurrency_limit=1)
    queued_demo.launch()