# SeedVR2-3B / app.py
# (Hugging Face file-viewer header: Aduc-sdr's picture · "Update app.py" ·
# commit 5fad6fa verified · raw · history blame · 13.2 kB)
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import spaces
import subprocess
import os
import sys
# --- Setup: Clone repository and add it to Python Path ---
# This section ensures all necessary code and model files are available.
# The checkout directory is named once so that the existence check, the clone
# target and the sys.path entry always refer to the same location.
repo_dir = "SeedVR2-3B"

# 1. Clone the repository with all its files.
# Commands are passed as argument lists, so no shell is spawned to parse them.
subprocess.run(["git", "lfs", "install"], check=True)
if not os.path.exists(repo_dir):
    print("Cloning SeedVR2-3B repository...")
    subprocess.run(
        [
            "git",
            "clone",
            "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B",
            repo_dir,
        ],
        check=True,
    )

# 2. Add the cloned repository's directory to Python's module search path.
# This allows us to import modules like 'data', 'common', etc., from the cloned repo.
sys.path.insert(0, os.path.abspath(repo_dir))
print(f"Repository directory '{os.path.abspath(repo_dir)}' added to Python path.")
# --- Main Application Code ---
# All file paths will now be relative to the cloned repository directory.
import torch
import mediapy
from einops import rearrange
from omegaconf import OmegaConf
import datetime
from tqdm import tqdm
import gc
from PIL import Image
import gradio as gr
from pathlib import Path
import shlex
import uuid
import mimetypes
import torchvision.transforms as T
from torchvision.transforms import Compose, Lambda, Normalize
from torchvision.io.video import read_video
# Imports from the cloned repository
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from common.config import load_config
from common.distributed import init_torch
from common.distributed.advanced import init_sequence_parallel
from common.seed import set_seed
from common.partition import partition_by_size
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.distributed.ops import sync_data
# Check for color_fix utility
# The color-fix helper is imported only when its source file exists inside the
# cloned repository; `use_colorfix` records whether it is available.
color_fix_path = os.path.join(repo_dir, "projects/video_diffusion_sr/color_fix.py")
use_colorfix = os.path.exists(color_fix_path)
if use_colorfix:
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
else:
    print('Note!!!!!! Color fix is not available!')
# --- Environment and Dependencies Setup ---
# Describe a one-process group on the local host (rank 0 of world size 1).
# NOTE(review): presumably consumed by the distributed initialisation in
# `init_torch` — confirm against the cloned repository.
os.environ.update(
    {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "12355",
        "RANK": str(0),
        "WORLD_SIZE": str(1),
    }
)

# pip is invoked through the running interpreter (sys.executable) so that the
# packages land in the environment this script is executing in.
python_executable = sys.executable
pip_install = [python_executable, "-m", "pip", "install"]

flash_attn_env = dict(os.environ)
flash_attn_env["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
subprocess.run(
    pip_install + ["flash-attn", "--no-build-isolation"],
    env=flash_attn_env,
    check=True,
)

# The Apex wheel is optional: it is installed only when the cloned repository
# actually contains it.
apex_wheel_path = os.path.join(repo_dir, "apex-0.1-cp310-cp310-linux_x86_64.whl")
if os.path.exists(apex_wheel_path):
    subprocess.run(pip_install + [apex_wheel_path], check=True)
    print("✅ Apex setup completed.")
# --- Core Functions ---
def configure_sequence_parallel(sp_size):
    """Call `init_sequence_parallel` only when `sp_size` is greater than 1.

    Args:
        sp_size: Requested sequence-parallel size. Values that are not
            greater than 1 make this function a no-op.
    """
    if not sp_size > 1:
        return
    init_sequence_parallel(sp_size)
def configure_runner(sp_size):
    """Create a `VideoDiffusionInfer` runner and configure its DiT and VAE models.

    The config file and the checkpoint are both read from inside the cloned
    repository directory (`repo_dir`).

    Args:
        sp_size: Sequence-parallel size forwarded to
            `configure_sequence_parallel`.

    Returns:
        The configured `VideoDiffusionInfer` instance.
    """
    runner = VideoDiffusionInfer(
        load_config(os.path.join(repo_dir, 'configs_3b', 'main.yaml'))
    )
    # Make the config writable so it can be modified after construction.
    OmegaConf.set_readonly(runner.config, False)

    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    configure_sequence_parallel(sp_size)

    runner.configure_dit_model(
        device="cuda",
        checkpoint=os.path.join(repo_dir, 'ckpts', 'seedvr2_ema_3b.pth'),
    )
    runner.configure_vae_model()

    # Not every VAE exposes a memory limit; apply it only when supported.
    vae = runner.vae
    if hasattr(vae, "set_memory_limit"):
        vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one inference pass of `runner` over the given condition latents.

    Args:
        runner: Object providing `timestep_transform`, `schedule.forward`,
            `get_condition` and `inference`.
        text_embeds_dict: Mapping expanded as keyword arguments into
            `runner.inference`.
        cond_latents: Sequence of condition latent tensors, one per sample.

    Returns:
        A list with one tensor per sample, each rearranged from
        "c t h w" to "t c h w".
    """
    cuda = torch.device("cuda")
    cond_noise_scale = 0.1

    # Sampling noise is drawn for every latent before any augmentation noise;
    # this ordering determines how the random stream is consumed.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    print(f"Generating with noise shape: {noises[0].size()}.")

    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises = [tensor.to(cuda) for tensor in noises]
    aug_noises = [tensor.to(cuda) for tensor in aug_noises]
    cond_latents = [tensor.to(cuda) for tensor in cond_latents]

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        # Perturb the condition latent with augmentation noise at a timestep
        # derived from 1000 * cond_noise_scale and the latent's trailing dims.
        base_t = torch.tensor([1000.0], device=cuda) * cond_noise_scale
        latent_shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        shifted_t = runner.timestep_transform(base_t, latent_shape)
        print(f"Timestep shifting from {1000.0 * cond_noise_scale} to {shifted_t}.")
        noised_latent = runner.schedule.forward(latent, aug_noise, shifted_t)
        conditions.append(
            runner.get_condition(noise, task="sr", latent_blur=noised_latent)
        )

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
    del video_tensors
    return samples
# Longest clip, in frames, that the demo processes; longer inputs are truncated.
_MAX_FRAMES = 121


def _load_text_embeds():
    """Load the positive/negative prompt embeddings stored in the cloned repo."""
    return {
        "texts_pos": [torch.load(os.path.join(repo_dir, 'pos_emb.pt'))],
        "texts_neg": [torch.load(os.path.join(repo_dir, 'neg_emb.pt'))],
    }


def _read_media(path, is_image):
    """Read an image or a video from `path` as a float tensor laid out "t c h w".

    Images become a single-frame tensor. Videos are truncated to
    `_MAX_FRAMES` frames and divided by 255.
    """
    if is_image:
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        frames = T.ToTensor()(rgb).unsqueeze(0)
        print(f"Read Image size: {frames.size()}")
        return frames

    frames, _, _ = read_video(path, output_format="TCHW")
    # Truncate before scaling so the float copy is never larger than needed.
    frames = frames[:_MAX_FRAMES] / 255.0
    print(f"Read video size: {frames.size()}")
    return frames


def _pad_frames_for_sp(videos, sp_size):
    """Truncate and pad a "c t h w" clip along its frame axis.

    The clip is cut to at most `_MAX_FRAMES` frames, then the last frame is
    repeated until the frame count is at least ``4 * sp_size + 1`` and
    ``(frames - 1)`` is a multiple of ``4 * sp_size``.
    """
    videos = videos[:, :_MAX_FRAMES]
    frames = videos.size(1)
    block = 4 * sp_size
    if frames <= block:
        pad_count = block + 1 - frames
    else:
        # Distance from (frames - 1) up to the next multiple of `block`.
        pad_count = -(frames - 1) % block
    if pad_count:
        last_frame = videos[:, -1:]
        videos = torch.cat([videos] + [last_frame] * pad_count, dim=1)
    return videos


@spaces.GPU
def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
    """Restore one uploaded image or video and write the result under ``output/``.

    Args:
        video_path: Path of the uploaded file, or None when nothing was uploaded.
        seed: Random seed; reduced modulo 2**32 before use.
        fps_out: Frame rate passed to the video writer (unused for images).
        batch_size: Chunk size handed to `partition_by_size`.
        cfg_scale: Value written to `diffusion.cfg.scale` in the runner config.
        cfg_rescale: Value written to `diffusion.cfg.rescale`.
        sample_steps: Value written to `diffusion.timesteps.sampling.steps`.
        res_h: Together with `res_w` defines the resize resolution,
            ``sqrt(res_h * res_w)``.
        res_w: See `res_h`.
        sp_size: Used only for frame padding of videos; the runner itself is
            always configured with a sequence-parallel size of 1.

    Returns:
        ``(image_path, video_path, download_path)``. For an image input the
        video slot is None, for a video input the image slot is None.
        ``(None, None, None)`` when `video_path` is None.

    Raises:
        ValueError: If the file's guessed MIME type is neither image nor video.
    """
    if video_path is None:
        return None, None, None

    # Classify the upload before the expensive model setup so that an
    # unsupported file is rejected immediately.
    media_type = mimetypes.guess_type(video_path)[0] or ""
    is_image = media_type.startswith("image")
    is_video = media_type.startswith("video")
    if not (is_image or is_video):
        raise ValueError("Unsupported file type")

    runner = configure_runner(1)
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = cfg_rescale
    runner.config.diffusion.timesteps.sampling.steps = sample_steps
    runner.configure_diffusion()

    seed = int(seed) % (2**32)
    set_seed(seed, same_across_ranks=True)

    output_base_dir = "output"
    os.makedirs(output_base_dir, exist_ok=True)
    extension = ".png" if is_image else ".mp4"
    output_path = os.path.join(output_base_dir, f"{uuid.uuid4()}{extension}")

    original_videos = [os.path.basename(video_path)]
    original_videos_local = partition_by_size(original_videos, batch_size)

    # One embeddings dict per batch.
    positive_prompts_embeds = [_load_text_embeds() for _ in original_videos_local]
    gc.collect()
    torch.cuda.empty_cache()

    video_transform = Compose([
        NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
        # Every entry of the batch refers to the single uploaded file.
        cond_latents = [
            video_transform(_read_media(video_path, is_image).to(torch.device("cuda")))
            for _ in videos
        ]

        # Remember the unpadded clips: they give the output length and serve
        # as the reference for the color fix.
        ori_lengths = [v.size(1) for v in cond_latents]
        input_videos = cond_latents
        if is_video:
            cond_latents = [_pad_frames_for_sp(v, sp_size) for v in cond_latents]

        print(f"Encoding videos: {[v.size() for v in cond_latents]}")
        cond_latents = runner.vae_encode(cond_latents)

        for key in ("texts_pos", "texts_neg"):
            text_embeds[key] = [emb.to(torch.device("cuda")) for emb in text_embeds[key]]

        samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
        del cond_latents

        for input_tensor, sample, ori_length in zip(input_videos, samples, ori_lengths):
            # Drop frames that exist only because of padding.
            if ori_length < sample.shape[0]:
                sample = sample[:ori_length]
            input_tensor = rearrange(input_tensor, "c t h w -> t c h w")
            if use_colorfix:
                sample = wavelet_reconstruction(
                    sample.to("cpu"), input_tensor[:sample.size(0)].to("cpu")
                )
            else:
                sample = sample.to("cpu")
            sample = rearrange(sample, "t c h w -> t h w c")
            # Map [-1, 1] to [0, 255] and convert to uint8 for the writers.
            sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
            sample = sample.to(torch.uint8).numpy()
            if is_image:
                mediapy.write_image(output_path, sample[0])
            else:
                mediapy.write_video(output_path, sample, fps=fps_out)

        gc.collect()
        torch.cuda.empty_cache()
        if is_image:
            return output_path, None, output_path
        return None, output_path, output_path

    # No batch was produced; keep the three-output shape the UI expects.
    return None, None, None
# --- Gradio UI ---
# NOTE(review): the widget nesting below (which components sit inside the Row
# and Column containers) was reconstructed from a source without indentation —
# confirm the intended layout.
logo_path = os.path.join(repo_dir, "assets/seedvr_logo.png")

with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
    # Header: logo and short description.
    gr.HTML(f"""
<div style='text-align:center; margin-bottom: 10px;'>
<img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
</div>
<p><b>Official Gradio demo</b> for <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.</p>
""")

    # Inputs: the uploaded file plus the two numeric options.
    with gr.Row():
        input_file = gr.File(label="Upload image or video", type="filepath")
        with gr.Column():
            seed = gr.Number(label="Seed", value=666)
            fps = gr.Number(label="Output FPS (for video)", value=24)

    run_button = gr.Button("Run")

    # Outputs: image and video previews plus a download slot.
    with gr.Row():
        output_image = gr.Image(label="Output Image")
        output_video = gr.Video(label="Output Video")
    download_link = gr.File(label="Download the output")

    run_button.click(
        fn=generation_loop,
        inputs=[input_file, seed, fps],
        outputs=[output_image, output_video, download_link],
    )

    # Footer: repository link, usage notice and known limitations.
    gr.HTML("""
<hr>
<p>If you find SeedVR helpful, please ⭐ the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repository</a>:
<a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
<h4>Notice</h4>
<p>This demo supports up to <b>720p and 121 frames for videos or 2k images</b>. For other use cases, check the <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>GitHub repo</a>.</p>
<h4>Limitations</h4>
<p>May fail on heavy degradations or small-motion AIGC clips, causing oversharpening or poor restoration.</p>
""")

# Enable request queuing, then start the server with a public share link.
demo.queue()
demo.launch(share=True)