| import spaces |
| import gradio as gr |
| import torch |
| import torchvision as tv |
| import random, os |
| from diffusers import StableVideoDiffusionPipeline |
| from PIL import Image |
| from glob import glob |
| from typing import Optional |
|
|
| from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler |
| from utils import load_lora_weights, save_video |
|
|
| |
# Toggle between loading the SVD base model and the TDD LoRA from local
# filesystem paths (True) or from the Hugging Face Hub (False).
LOCAL = False

if not LOCAL:
    # Hub-hosted base model, plus the repo id and file name of the TDD LoRA.
    svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
    lora_repo_path = 'RED-AIGC/TDD'
    lora_weight_name = 'svd-xt-1-1_tdd_lora_weights.safetensors'
else:
    # Local checkout of the base model and a direct path to the LoRA file.
    svd_path = '/share2/duanyuxuan/diff_playground/diffusers_models/stable-video-diffusion-img2vid-xt-1-1'
    lora_file_path = '/share2/duanyuxuan/diff_playground/SVD-TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
|
|
if torch.cuda.is_available():
    # NOTE(review): `noise_scheduler` and `pipeline` are only defined on this
    # branch; without CUDA, `sample` would fail with a NameError on `pipeline`.
    # The scheduler hyperparameters are hard-coded; presumably they must match
    # the settings the TDD LoRA was distilled with -- confirm before changing.
    noise_scheduler = TDDSVDStochasticIterativeScheduler(
        num_train_timesteps=250,
        sigma_min=0.002,
        sigma_max=700.0,
        sigma_data=1.0,
        s_noise=1.0,
        rho=7,
        clip_denoised=False,
    )

    # Half-precision SVD weights, driven by the TDD scheduler, placed on the GPU.
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        svd_path,
        scheduler=noise_scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to('cuda')

    # Apply the TDD LoRA to the UNet (see utils.load_lora_weights), taken
    # either from the Hub repo or from a local file depending on LOCAL.
    if not LOCAL:
        load_lora_weights(pipeline.unet, lora_repo_path, weight_name=lora_weight_name)
    else:
        load_lora_weights(pipeline.unet, lora_file_path)
|
|
# Largest signed 64-bit integer; upper bound for random seeds and the Seed slider.
max_64_bit_int = (1 << 63) - 1
|
|
@spaces.GPU
def sample(
    image: Image.Image,
    seed: Optional[int] = 1,
    randomize_seed: bool = False,
    num_inference_steps: int = 4,
    eta: float = 0.3,
    min_guidance_scale: float = 1.0,
    max_guidance_scale: float = 1.0,
    fps: int = 7,
    width: int = 512,
    height: int = 512,
    num_frames: int = 25,
    motion_bucket_id: int = 127,
    output_folder: str = "outputs_gradio",
):
    """Generate a video from ``image`` with the TDD-distilled SVD pipeline.

    Args:
        image: Conditioning image (PIL). Must not be None.
        seed: Seed for torch's RNG. Ignored when ``randomize_seed`` is set;
            ``None`` also falls back to a freshly drawn random seed.
        randomize_seed: Draw a new random seed instead of using ``seed``.
        num_inference_steps: Number of denoising steps.
        eta: Value passed to the scheduler's ``set_eta``.
        min_guidance_scale: Lower bound of the classifier-free guidance scale.
        max_guidance_scale: Upper bound of the classifier-free guidance scale.
        fps: Frame rate passed to the pipeline and used for the saved video.
        width: Output frame width in pixels.
        height: Output frame height in pixels.
        num_frames: Number of frames to generate.
        motion_bucket_id: Motion conditioning id passed to the pipeline.
        output_folder: Directory in which the ``.mp4`` is written (created if
            missing).

    Returns:
        ``(video_path, seed)``: the path of the saved video and the seed that
        was actually used.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail early with a readable UI message instead of a crash inside the pipeline.
    if image is None:
        raise gr.Error("Please upload an image before generating a video.")

    # The scheduler is shared module state, so eta must be set on every call.
    pipeline.scheduler.set_eta(eta)

    # A missing seed (Optional[int]) would make torch.manual_seed fail; treat it
    # like a randomize request.
    if randomize_seed or seed is None:
        seed = random.randint(0, max_64_bit_int)
    seed = int(seed)
    generator = torch.manual_seed(seed)

    os.makedirs(output_folder, exist_ok=True)
    # Start numbering at the count of existing videos, but skip past any name
    # that is already taken (e.g. after files were deleted) so nothing is
    # overwritten.
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
    while os.path.exists(video_path):
        base_count += 1
        video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

    with torch.autocast("cuda"):
        frames = pipeline(
            image,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_frames=num_frames,
            fps=fps,
            motion_bucket_id=motion_bucket_id,
            decode_chunk_size=8,
            noise_aug_strength=0.02,
            generator=generator,
        ).frames[0]
    save_video(frames, video_path, fps=fps, quality=5.0)
    # NOTE(review): re-seeds torch's global RNG after generation; kept from the
    # original, its purpose is not evident from this file.
    torch.manual_seed(seed)

    return video_path, seed
|
|
|
|
def preprocess_image(image, height=512, width=512):
    """Center-crop ``image`` to the target aspect ratio, then resize it.

    The crop keeps the largest centered window whose aspect ratio equals
    ``width:height``, so the final resize never distorts the content. For a
    square target this is the same centered square crop as
    ``torchvision.transforms.functional.center_crop(img, min(H, W))``.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``).
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        An RGB PIL image of size ``(width, height)``.
    """
    image = image.convert('RGB')
    src_w, src_h = image.size

    # Compare aspect ratios by cross-multiplying to stay in exact integers.
    if src_w * height != src_h * width:
        if src_w * height > src_h * width:
            # Source is too wide: keep the full height, trim the sides.
            crop_w, crop_h = max(1, src_h * width // height), src_h
        else:
            # Source is too tall: keep the full width, trim top and bottom.
            crop_w, crop_h = src_w, max(1, src_w * height // width)
        # Same offset rounding as torchvision's center_crop.
        left = int(round((src_w - crop_w) / 2.0))
        top = int(round((src_h - crop_h) / 2.0))
        image = image.crop((left, top, left + crop_w, top + crop_h))

    return image.resize((width, height))
|
|
# Custom stylesheet passed to gr.Blocks below: centers <h1> headings and caps
# the width of the Gradio container at 70.5rem.
css = """
h1 {
text-align: center;
display:block;
}
.gradio-container {
max-width: 70.5rem !important;
}
"""
|
|
# Gradio UI: description header, image input + Generate button beside the
# output video, a panel of sampling options, and the event wiring.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Stable Video Diffusion distilled by ✨Target-Driven Distillation✨

Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation model that greatly accelerates the inference process of diffusion models. Using its delicate strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images with only a few steps.

Besides, TDD is also available for distilling video generation models. This space presents TDD-distilled [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1).

[**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)

The code of this space is built on [AnimateLCM-SVD](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD) and we acknowledge their contribution.
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Accordion("Options", open=True):
        seed = gr.Slider(
            label="Seed",
            value=1,
            randomize=False,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        min_guidance_scale = gr.Slider(
            label="Min guidance scale",
            info="min strength of classifier-free guidance",
            value=1.0,
            minimum=1.0,
            maximum=1.5,
        )
        max_guidance_scale = gr.Slider(
            label="Max guidance scale",
            info="max strength of classifier-free guidance, it should not be less than Min guidance scale",
            value=1.0,
            minimum=1.0,
            maximum=3.0,
        )
        num_inference_steps = gr.Slider(
            label="Num inference steps",
            info="steps for inference",
            value=4,
            minimum=4,
            maximum=8,
            step=1,
        )
        eta = gr.Slider(
            label="Eta",
            info="the value of gamma in gamma-sampling",
            value=0.3,
            minimum=0.0,
            maximum=1.0,
            step=0.1,
        )

    # Normalize each upload (RGB, center crop, resize) in place, outside the queue.
    image.upload(fn=preprocess_image, inputs=image, outputs=image, queue=False)
    # Inputs map positionally onto sample()'s leading parameters; the remaining
    # parameters keep their defaults. The used seed is written back to the slider.
    generate_btn.click(
        fn=sample,
        inputs=[
            image,
            seed,
            randomize_seed,
            num_inference_steps,
            eta,
            min_guidance_scale,
            max_guidance_scale,
        ],
        outputs=[video, seed],
        api_name="video",
    )
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # Non-local mode keeps the API closed and hidden; local mode binds to all
    # network interfaces and requests a Gradio share link.
    if not LOCAL:
        demo.queue(api_open=False).launch(show_api=False)
    else:
        demo.queue().launch(share=True, server_name='0.0.0.0')