# Hugging Face Space: image → video with Stable Video Diffusion (img2vid).
# NOTE(review): removed page-scrape artifacts ("Spaces:" / "Runtime error"
# banners) that were not part of the source file.
| import os | |
| import tempfile | |
| from typing import List | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from diffusers import StableVideoDiffusionPipeline | |
| from diffusers.utils import export_to_video | |
# Default checkpoint to load; override with the MODEL_ID environment variable.
MODEL_ID_DEFAULT = os.getenv("MODEL_ID", "stabilityai/stable-video-diffusion-img2vid")
# fp16 on CUDA for speed/VRAM; fp32 on CPU, where half precision is poorly supported.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# Lazily-initialized module-level pipeline cache (populated by load_pipeline).
pipe = None
def load_pipeline(model_id: str = MODEL_ID_DEFAULT):
    """Load — and cache — the Stable Video Diffusion img2vid pipeline.

    The cache is keyed by ``model_id``: the original implementation returned
    whatever pipeline was loaded first, so switching the "Model repo id"
    textbox in the UI silently kept serving the old checkpoint.

    Args:
        model_id: Hub repo id of a compatible Img2Vid pipeline.

    Returns:
        The ready-to-use ``StableVideoDiffusionPipeline`` (shared, cached).
    """
    global pipe
    # Reuse the cached pipeline only when it was built for the same model id.
    # The loaded id is stashed as a function attribute so no extra module
    # global needs to exist before the first call.
    if pipe is not None and getattr(load_pipeline, "_model_id", None) == model_id:
        return pipe
    kwargs = {"torch_dtype": DTYPE}
    # fp16 variant weights download/run faster on GPU Spaces.
    if DTYPE == torch.float16:
        kwargs["variant"] = "fp16"
    pipe_local = StableVideoDiffusionPipeline.from_pretrained(
        model_id,
        **kwargs,
    )
    # Memory & speed tweaks.
    if torch.cuda.is_available():
        pipe_local.enable_model_cpu_offload()  # good default for Spaces GPUs
    # NOTE(review): the original called enable_sequential_cpu_offload() on
    # CPU-only hosts; that API expects an accelerator device to offload *to*
    # and raises on pure-CPU Spaces (a likely source of the "Runtime error").
    # On CPU we skip offloading entirely and rely on slicing below.
    pipe_local.enable_vae_slicing()
    pipe_local.enable_attention_slicing()
    load_pipeline._model_id = model_id
    pipe = pipe_local
    return pipe
def _ensure_rgb(img: Image.Image) -> Image.Image:
    """Return *img* in RGB mode, converting only when necessary."""
    return img if img.mode == "RGB" else img.convert("RGB")
def generate(
    image: Image.Image,
    num_frames: int = 14,
    fps: int = 8,
    motion_bucket_id: int = 127,
    noise_aug_strength: float = 0.02,
    seed: int = 0,
    decode_chunk_size: int = 8,
    model_id: str = MODEL_ID_DEFAULT,
):
    """Animate *image* with Stable Video Diffusion and return an .mp4 path.

    Args:
        image: Input PIL image (converted to RGB if needed).
        num_frames: Number of frames to synthesize.
        fps: Frames per second for conditioning and for the exported video.
        motion_bucket_id: SVD motion conditioning strength (1-255).
        noise_aug_strength: Conditioning noise augmentation.
        seed: RNG seed; per the UI label, 0 (or negative/None) means random.
        decode_chunk_size: Frames decoded per VAE pass (memory knob).
        model_id: Hub repo id forwarded to load_pipeline.

    Returns:
        Filesystem path of the rendered mp4 in a fresh temp directory.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    pipe = load_pipeline(model_id)
    # Gradio sliders can deliver floats; normalize once and reuse everywhere.
    fps = int(fps)
    # Determinism. BUG FIX: the UI advertises "Seed (0 for random)" but the
    # original only randomized for seed < 0, making 0 silently deterministic.
    if seed is None or int(seed) <= 0:
        seed = torch.seed() % (2**31)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=device).manual_seed(int(seed))
    image = _ensure_rgb(image)
    with torch.inference_mode():
        result = pipe(
            image=image,
            num_frames=int(num_frames),
            fps=fps,
            motion_bucket_id=int(motion_bucket_id),
            noise_aug_strength=float(noise_aug_strength),
            decode_chunk_size=int(decode_chunk_size),
            generator=generator,
        )
    # SVD returns a batch of frame lists; we generate a single video.
    frames: List[Image.Image] = result.frames[0]
    # Save to a temp .mp4 (mkdtemp so the file survives after return for
    # Gradio to stream it back to the client).
    out_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
    export_to_video(frames, out_path, fps=fps)
    return out_path
def build_demo():
    """Assemble the Gradio Blocks UI and wire it to ``generate``."""
    with gr.Blocks(theme=gr.themes.Soft(), fill_width=True) as ui:
        gr.Markdown(
            """
# Image → Video (Stable Video Diffusion)
Pretrained **Stable Video Diffusion (Img2Vid)** from the Hugging Face Hub.
- Default model: `stabilityai/stable-video-diffusion-img2vid`
- Try alternative ids like `stabilityai/stable-video-diffusion-img2vid-xt`
"""
        )
        with gr.Row():
            # Left column: inputs and knobs.
            with gr.Column(scale=1):
                source_image = gr.Image(type="pil", label="Input image", width=512)
                model_box = gr.Textbox(
                    value=MODEL_ID_DEFAULT,
                    label="Model repo id",
                    info="Any compatible Img2Vid pipeline on the Hub",
                )
                with gr.Accordion("Advanced", open=False):
                    frames_slider = gr.Slider(8, 25, value=14, step=1, label="Frames")
                    fps_slider = gr.Slider(4, 30, value=8, step=1, label="FPS")
                    motion_slider = gr.Slider(1, 255, value=127, step=1, label="Motion bucket id")
                    noise_slider = gr.Slider(0.0, 0.5, value=0.02, step=0.01, label="Noise aug strength")
                    chunk_slider = gr.Slider(1, 32, value=8, step=1, label="Decode chunk size")
                    seed_box = gr.Number(value=0, precision=0, label="Seed (0 for random)")
                generate_btn = gr.Button("Generate", variant="primary")
            # Right column: rendered result.
            with gr.Column(scale=1):
                video_out = gr.Video(label="Output video (.mp4)")
        generate_btn.click(
            fn=generate,
            inputs=[
                source_image,
                frames_slider,
                fps_slider,
                motion_slider,
                noise_slider,
                seed_box,
                chunk_slider,
                model_box,
            ],
            outputs=[video_out],
            queue=True,
            api_name="predict",
        )
        gr.Examples(
            examples=[
                ["https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/sketch-mountains-input.jpg", 14, 8, 127, 0.02, 0, 8, MODEL_ID_DEFAULT],
            ],
            inputs=[source_image, frames_slider, fps_slider, motion_slider, noise_slider, seed_box, chunk_slider, model_box],
            label="Try an example (downloads on-click)",
        )
    return ui
# Build the UI at import time so Spaces / `gradio app.py` can discover `demo`.
demo = build_demo()
if __name__ == "__main__":
    # Queueing serializes GPU work so concurrent users don't OOM the Space.
    demo.queue().launch()