import spaces
import subprocess
import os
import torch
import uuid
import gc
import shutil
import argparse
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
import shlex

import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from einops import rearrange
from torchvision.transforms import Compose, Lambda, Normalize
import torchvision.transforms as T

# --- Project Specific Imports (Assumed to be present in repo) ---
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
# Note: Keeping Rearrange in case it's a specific wrapper, though typically einops suffices
from data.video.transforms.rearrange import Rearrange

if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
    use_colorfix = True
else:
    use_colorfix = False

from common.distributed import init_torch
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_size

# --- Environment Setup ---
# Single-process "distributed" setup so init_torch works on one GPU.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

# Install Flash Attention if missing
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# --- Model & Resource Downloading ---
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = file_name if file_name else os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file

ckpt_dir = Path('./ckpts')
ckpt_dir.mkdir(exist_ok=True)

pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl',
}

# Download weights (skipped when the cached file already exists)
if not os.path.exists('./ckpts/seedvr2_ema_3b.pth'):
    load_file_from_url(url=pretrain_model_url['dit'], model_dir='./ckpts/')
if not os.path.exists('./ckpts/ema_vae.pth'):
    load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/')
if not os.path.exists('./pos_emb.pt'):
    load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir='./')
if not os.path.exists('./neg_emb.pt'):
    load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir='./')
if not os.path.exists('./apex-0.1-cp310-cp310-linux_x86_64.whl'):
    load_file_from_url(url=pretrain_model_url['apex'], model_dir='./')
subprocess.run(shlex.split("pip install apex-0.1-cp310-cp310-linux_x86_64.whl"))
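# Added sanity check (a minimal sketch, not part of the original pipeline):
# warn early if any asset the runner expects below failed to download.
# Paths mirror the download block above.
for _required in ('./ckpts/seedvr2_ema_3b.pth', './ckpts/ema_vae.pth',
                  './pos_emb.pt', './neg_emb.pt'):
    if not os.path.exists(_required):
        print(f"[warn] expected asset missing: {_required}")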
# --- Core Inference Logic ---
@spaces.GPU(duration=100)
def configure_runner():
    """Initializes the model runner singleton."""
    config_path = os.path.join('./configs_3b', 'main.yaml')
    config = load_config(config_path)
    runner = VideoDiffusionInfer(config)
    OmegaConf.set_readonly(runner.config, False)

    # Standard init for single GPU
    init_torch(cudnn_benchmark=False)
    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner

@spaces.GPU(duration=100)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Executes the diffusion generation step."""
    def _move_to_cuda(x):
        return [i.to(torch.device("cuda")) for i in x]

    # Generate noise
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]

    # Sync and move
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))

    cond_noise_scale = 0.1

    def _add_noise(x, aug_noise):
        t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
        shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
        t = runner.timestep_transform(t, shape)
        x = runner.schedule.forward(x, aug_noise, t)
        return x

    conditions = [
        runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
        for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
    ]

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    # Output formatting: ensure every sample is (t, c, h, w)
    samples = [
        (rearrange(video[:, None], "c t h w -> t c h w")
         if video.ndim == 3
         else rearrange(video, "c t h w -> t c h w"))
        for video in video_tensors
    ]
    return samples

def get_text_embeds():
    """Loads static text embeddings (on CPU; the caller moves them to CUDA)."""
    text_pos = torch.load('pos_emb.pt', map_location="cpu")
    text_neg = torch.load('neg_emb.pt', map_location="cpu")
    return {"texts_pos": [text_pos], "texts_neg": [text_neg]}
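# Conditioning-noise sketch (illustrative only, not executed): with
# cond_noise_scale = 0.1, generation_step perturbs the conditioning latent at an
# effective timestep of 1000 * 0.1 = 100 (before timestep_transform). Assuming a
# linear flow-matching schedule, runner.schedule.forward is roughly:
#   x_t = (1 - t/1000) * latent + (t/1000) * aug_noise
# i.e. a mild ~10% blend of noise that the SR conditioning path expects.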
@spaces.GPU(duration=100)
def upscale_image(image_path, seed=666, cfg_scale=1.0):
    if not image_path:
        return None, None

    # Initialize runner
    runner = configure_runner()

    # Configure diffusion
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = 0.0
    runner.config.diffusion.timesteps.sampling.steps = 1  # One-step generation
    runner.configure_diffusion()

    # Seed
    seed = int(seed) % (2**32)
    set_seed(seed, same_across_ranks=True)

    os.makedirs('output/', exist_ok=True)
    output_filename = f'output/{uuid.uuid4()}.png'

    # Prepare transforms
    # Note: the model is optimized for a 2560x1440 area equivalent
    video_transform = Compose([
        NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    # Load and preprocess the image
    img = Image.open(image_path).convert("RGB")
    img_tensor = T.ToTensor()(img).unsqueeze(0)  # (1, C, H, W)

    # A single image acts as a one-frame video, (T=1, C, H, W); the Rearrange
    # transform above maps it to the (C, T, H, W) layout the model expects.
    video_input = img_tensor

    cond_latents = [video_transform(video_input.to(torch.device("cuda")))]
    input_tensor = cond_latents[0]  # Keep for colorfix reference

    # Encode
    cond_latents = runner.vae_encode(cond_latents)

    # Get embeddings
    text_embeds = get_text_embeds()
    for k in ["texts_pos", "texts_neg"]:
        text_embeds[k] = [emb.to(torch.device("cuda")) for emb in text_embeds[k]]

    # Inference
    samples = generation_step(runner, text_embeds, cond_latents=cond_latents)

    # Post-process
    sample = samples[0]

    # Handle tensor shaping for colorfix
    input_ref = (
        rearrange(input_tensor[:, None], "c t h w -> t c h w")
        if input_tensor.ndim == 3
        else rearrange(input_tensor, "c t h w -> t c h w")
    )

    if use_colorfix:
        sample = wavelet_reconstruction(sample.to("cpu"), input_ref[:sample.size(0)].to("cpu"))
    else:
        sample = sample.to("cpu")

    # Final normalization
    sample = (
        rearrange(sample[:, None], "t c h w -> t h w c")
        if sample.ndim == 3
        else rearrange(sample, "t c h w -> t h w c")
    )
    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    sample = sample.to(torch.uint8).numpy()

    # Save
    result_image = Image.fromarray(sample[0])
    result_image.save(output_filename)

    # Cleanup
    del runner, cond_latents, samples
    gc.collect()
    torch.cuda.empty_cache()

    return result_image, output_filename

# --- Gradio UI ---

# Custom CSS for the "top tier" look
custom_css = """
/* Font import handled by the theme; custom tweaks here */
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif !important;
}

/* Header styling */
h1 {
    text-align: center;
    color: #FF7043;
    font-weight: 700 !important;
    font-size: 2.5rem !important;
    margin-bottom: 0.5rem !important;
}
h3 {
    text-align: center;
    color: #525252;
    font-weight: 400 !important;
    margin-top: 0 !important;
}

/* Button styling - vibrant orange */
button.primary {
    background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
    border: none !important;
    box-shadow: 0 4px 6px -1px rgba(255, 87, 34, 0.2), 0 2px 4px -1px rgba(255, 87, 34, 0.1) !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    transform: translateY(-1px);
    box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
}

/* UI boxes (groups/columns) */
.ui-box {
    background: white;
    border: 1px solid #E5E7EB;
    border-radius: 12px;
    padding: 20px;
    box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
    height: 100%;
    display: flex;
    flex-direction: column;
}

/* Footer styling */
.footer-link {
    color: #FF7043;
    text-decoration: none;
    font-weight: 600;
}
.footer-link:hover {
    text-decoration: underline;
}
"""

# Refined theme
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size="lg",
).set(
    body_background_fill="#F9FAFB",
    block_background_fill="white",
    block_border_width="0px",  # Clean look
    block_shadow="none",
    # Remove orange background from labels
    block_label_background_fill="transparent",
    block_label_text_color="#4B5563",
    block_label_text_weight="600",
    block_title_text_color="#1F2937",
    # Input/output styling
    input_background_fill="#F3F4F6",
    # Primary button (orange)
    button_primary_background_fill="#FF7043",
    button_primary_background_fill_hover="#F4511E",
    button_primary_text_color="white",
)

with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as demo:
    with gr.Column():
        gr.Markdown(
            """
            # 🍊 SeedVR2 Image Upscaler
            ### Professional One-Step Restoration & Upscaling
            """
        )

    with gr.Row(equal_height=True):
        # Left column: input
        with gr.Column(scale=1, elem_classes="ui-box"):
            gr.Markdown("#### Source", elem_id="input-header")
            input_image = gr.Image(
                label="Input Image",
                type="filepath",
                height=400,
                sources=["upload", "clipboard"],
                show_label=False,
            )

            gr.Markdown("#### Settings", elem_id="settings-header")
            with gr.Group():
                with gr.Row():
                    seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
                    cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0,
                                          step=0.1, container=True)
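            # Added note (hedged): in the standard classifier-free-guidance
            # formulation, a scale of 1.0 reduces the output to the conditional
            # prediction alone, so the default runs without extra guidance;
            # larger values push the result harder toward the SR condition.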
            # Spacer (a plain line break is assumed; the original HTML string
            # content was lost)
            gr.HTML("<br>")
") run_btn = gr.Button("Upscale Image", variant="primary", size="lg") # Right Column: Output with gr.Column(scale=1, elem_classes="ui-box"): gr.Markdown("#### Result", elem_id="output-header") output_image = gr.Image( label="Restored Result", interactive=False, height=400, show_label=False ) download_file = gr.File(label="Download High-Res", container=False) run_btn.click( fn=upscale_image, inputs=[input_image, seed_input, cfg_input], outputs=[output_image, download_file] ) # Footer gr.HTML( """Powered by SeedVR2 | One-Step Diffusion Model
    # Footer (markup assumed; the `.footer-link` class is styled in custom_css,
    # and the link target is the model repo used for the downloads above)
    gr.HTML(
        """
        <div style="text-align: center; margin-top: 16px; color: #6B7280;">
            Powered by <a href="https://huggingface.co/ByteDance-Seed/SeedVR2-3B" class="footer-link">SeedVR2</a> | One-Step Diffusion Model<br>
            UI/UX by <span class="footer-link">@bbqhan</span>
        </div>
        """
    )

demo.launch()
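# Local usage sketch (added; hypothetical paths, assumes a CUDA GPU and the
# checkpoints downloaded above) -- bypasses the UI entirely:
#   restored, saved_path = upscale_image("my_photo.png", seed=42, cfg_scale=1.0)
#   restored.save("my_photo_restored.png")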