# Hugging Face Spaces page banner (not part of the program): "Running on Zero"
| import spaces | |
| import subprocess | |
| import os | |
| import torch | |
| import uuid | |
| import gc | |
| import shutil | |
| import argparse | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| from torch.hub import download_url_to_file, get_dir | |
| import shlex | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from einops import rearrange | |
| from torchvision.transforms import Compose, Lambda, Normalize | |
| import torchvision.transforms as T | |
| # --- Project Specific Imports (Assumed to be present in repo) --- | |
| from data.image.transforms.divisible_crop import DivisibleCrop | |
| from data.image.transforms.na_resize import NaResize | |
| # Note: Keeping Rearrange in case it's a specific wrapper, though typically einops suffices | |
| from data.video.transforms.rearrange import Rearrange | |
| if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): | |
| from projects.video_diffusion_sr.color_fix import wavelet_reconstruction | |
| use_colorfix = True | |
| else: | |
| use_colorfix = False | |
| from common.distributed import init_torch | |
| from projects.video_diffusion_sr.infer import VideoDiffusionInfer | |
| from common.config import load_config | |
| from common.distributed.ops import sync_data | |
| from common.seed import set_seed | |
| from common.partition import partition_by_size | |
# --- Environment Setup ---
# Rendezvous variables for a single local process (rank 0 of world size 1).
# NOTE(review): presumably consumed by `init_torch` / torch.distributed — confirm.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

# Install Flash Attention at startup (pip is a no-op when it is already satisfied).
# `env=` REPLACES the child's environment rather than extending it, so the
# parent's variables (PATH, HOME, virtualenv settings, ...) are merged in
# explicitly; otherwise `pip` may not resolve to the interpreter's own pip.
subprocess.run(
    shlex.split("pip install flash-attn --no-build-isolation"),
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
| # --- Model & Resource Downloading --- | |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for ``url``, downloading the file on a cache miss.

    Args:
        url: Remote location of the file.
        model_dir: Directory the file is kept in. When ``None``, the
            ``checkpoints`` folder under ``torch.hub.get_dir()`` is used.
            The directory is created if it does not exist.
        progress: Forwarded to ``download_url_to_file``.
        file_name: Local file name. When falsy, the last component of the
            URL path is used.

    Returns:
        Absolute path of the (possibly freshly downloaded) file.
    """
    url_path = urlparse(url).path

    target_dir = model_dir
    if target_dir is None:
        target_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(target_dir, exist_ok=True)

    local_name = file_name or os.path.basename(url_path)
    cached_file = os.path.abspath(os.path.join(target_dir, local_name))

    # Cache hit: nothing to fetch.
    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
ckpt_dir = Path('./ckpts')
ckpt_dir.mkdir(exist_ok=True)

pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl',
}

# Download weights: (resource key, destination directory), fetched in this
# order and only when the destination file is not already on disk.
for _resource, _dest_dir in (
    ('dit', './ckpts/'),
    ('vae', './ckpts/'),
    ('pos_emb', './'),
    ('neg_emb', './'),
    ('apex', './'),
):
    _resource_url = pretrain_model_url[_resource]
    _local_name = os.path.basename(urlparse(_resource_url).path)
    if not os.path.exists(os.path.join(_dest_dir, _local_name)):
        load_file_from_url(url=_resource_url, model_dir=_dest_dir)

# Install the apex wheel fetched above into the running environment.
subprocess.run(shlex.split("pip install apex-0.1-cp310-cp310-linux_x86_64.whl"))
| # --- Core Inference Logic --- | |
def configure_runner():
    """Build and return a ready-to-use ``VideoDiffusionInfer`` runner.

    Loads ``./configs_3b/main.yaml``, makes the runner's config writable,
    initialises torch via ``init_torch``, then configures the DiT model from
    ``./ckpts/seedvr2_ema_3b.pth`` on CUDA followed by the VAE. A new runner
    is constructed on every call.
    """
    runner = VideoDiffusionInfer(load_config(os.path.join('./configs_3b', 'main.yaml')))
    # Callers adjust diffusion settings afterwards, so the config must be mutable.
    OmegaConf.set_readonly(runner.config, False)

    # Standard init for a single GPU.
    init_torch(cudnn_benchmark=False)

    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()

    # Apply the configured VAE memory limit only when the VAE exposes the hook.
    if not hasattr(runner.vae, "set_memory_limit"):
        return runner
    runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one diffusion inference pass conditioned on ``cond_latents``.

    Args:
        runner: Inference runner providing ``timestep_transform``,
            ``schedule.forward``, ``get_condition`` and ``inference``.
        text_embeds_dict: Keyword arguments forwarded verbatim to
            ``runner.inference``.
        cond_latents: Sequence of latent tensors to condition on.

    Returns:
        A list with one tensor per output of ``runner.inference``, rearranged
        from ``c t h w`` to ``t c h w`` (3-D outputs first get a singleton
        axis inserted at position 1).
    """
    cuda = torch.device("cuda")
    cond_noise_scale = 0.1

    # Draw the sampling noise first, then the augmentation noise.
    noises = [torch.randn_like(z) for z in cond_latents]
    aug_noises = [torch.randn_like(z) for z in cond_latents]

    # Sync (source rank 0) and move everything to the GPU.
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = (
        [tensor.to(cuda) for tensor in group]
        for group in (noises, aug_noises, cond_latents)
    )

    def _perturb(latent, aug_noise):
        # Timestep 1000 * cond_noise_scale, passed through the runner's
        # timestep transform before the schedule's forward call.
        timestep = torch.tensor([1000.0], device=cuda) * cond_noise_scale
        latent_shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        timestep = runner.timestep_transform(timestep, latent_shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        blurred = _perturb(latent, aug_noise)
        conditions.append(runner.get_condition(noise, task="sr", latent_blur=blurred))

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    # Output formatting.
    samples = []
    for video in video_tensors:
        if video.ndim == 3:
            video = video[:, None]
        samples.append(rearrange(video, "c t h w -> t c h w"))
    return samples
def get_text_embeds():
    """Load the positive and negative text embeddings from the working directory.

    Returns:
        A dict with keys ``"texts_pos"`` and ``"texts_neg"``, each mapping to
        a one-element list holding the object loaded from ``pos_emb.pt`` /
        ``neg_emb.pt`` respectively.
    """
    sources = (("texts_pos", 'pos_emb.pt'), ("texts_neg", 'neg_emb.pt'))
    return {key: [torch.load(path)] for key, path in sources}
def _run_restoration(image_path, seed, cfg_scale):
    """Run the model on one image and return the result as uint8 frames.

    Kept separate from ``upscale_image`` so that the runner and every tensor
    created here are out of scope once this function returns, before the
    caller empties the CUDA cache.

    Args:
        image_path: Path of the image file to restore.
        seed: Seed value; reduced modulo 2**32 before use.
        cfg_scale: Value written to ``diffusion.cfg.scale``.

    Returns:
        A ``numpy`` uint8 array laid out as ``t h w c``.
    """
    runner = configure_runner()

    # Configure diffusion.
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = 0.0
    runner.config.diffusion.timesteps.sampling.steps = 1  # One-step generation
    runner.configure_diffusion()

    set_seed(int(seed) % (2**32), same_across_ranks=True)

    # Resize target has the same pixel area as 2560x1440.
    video_transform = Compose([
        NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    # Close the file handle as soon as the pixels are copied by convert().
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    # (1, C, H, W): the leading axis is the "t" axis the transform rearranges.
    frames = T.ToTensor()(img).unsqueeze(0)

    input_tensor = video_transform(frames.to(torch.device("cuda")))  # colorfix reference
    cond_latents = runner.vae_encode([input_tensor])

    text_embeds = get_text_embeds()
    for k in ["texts_pos", "texts_neg"]:
        text_embeds[k] = [emb.to(torch.device("cuda")) for emb in text_embeds[k]]

    sample = generation_step(runner, text_embeds, cond_latents=cond_latents)[0]

    # Bring the reference into the same "t c h w" layout as the sample.
    input_ref = (
        rearrange(input_tensor[:, None], "c t h w -> t c h w")
        if input_tensor.ndim == 3
        else rearrange(input_tensor, "c t h w -> t c h w")
    )
    if use_colorfix:
        sample = wavelet_reconstruction(sample.to("cpu"), input_ref[:sample.size(0)].to("cpu"))
    else:
        sample = sample.to("cpu")

    sample = (
        rearrange(sample[:, None], "t c h w -> t h w c")
        if sample.ndim == 3
        else rearrange(sample, "t c h w -> t h w c")
    )
    # Map [-1, 1] to [0, 255].
    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    return sample.to(torch.uint8).numpy()


def upscale_image(image_path, seed=666, cfg_scale=1.0):
    """Restore/upscale one image and save the result as a PNG under ``output/``.

    Args:
        image_path: Path of the input image. A falsy value short-circuits.
        seed: Random seed; ``None`` (e.g. an emptied number field) falls back
            to 666. Reduced modulo 2**32 before use.
        cfg_scale: Classifier-free-guidance scale written into the config.

    Returns:
        ``(result_image, output_filename)`` — the PIL image and the path it
        was saved to — or ``(None, None)`` when no image path was given.
    """
    if not image_path:
        return None, None

    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is applied here — confirm whether the deployment needs one.
    if seed is None:
        seed = 666

    try:
        frames = _run_restoration(image_path, seed, cfg_scale)
    finally:
        # Runs on failure as well as success; on success the helper's locals
        # are already gone, so cached GPU blocks can actually be released.
        gc.collect()
        torch.cuda.empty_cache()

    os.makedirs('output/', exist_ok=True)
    output_filename = f'output/{uuid.uuid4()}.png'
    result_image = Image.fromarray(frames[0])
    result_image.save(output_filename)
    return result_image, output_filename
| # --- Gradio UI --- | |
# Stylesheet handed to gr.Blocks(css=...). Sections: container font, h1/h3
# headers, primary button (orange gradient + hover lift), the `.ui-box`
# card used for the two columns, and footer links.
custom_css = """
/* Font Import handled by Theme, but custom tweaks here */
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif !important;
}
/* Header Styling */
h1 {
    text-align: center;
    color: #FF7043;
    font-weight: 700 !important;
    font-size: 2.5rem !important;
    margin-bottom: 0.5rem !important;
}
h3 {
    text-align: center;
    color: #525252;
    font-weight: 400 !important;
    margin-top: 0 !important;
}
/* Button Styling - Vibrant Orange */
button.primary {
    background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
    border: none !important;
    box-shadow: 0 4px 6px -1px rgba(255, 87, 34, 0.2), 0 2px 4px -1px rgba(255, 87, 34, 0.1) !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    transform: translateY(-1px);
    box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
}
/* UI Boxes (Groups/Columns) */
.ui-box {
    background: white;
    border: 1px solid #E5E7EB;
    border-radius: 12px;
    padding: 20px;
    box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
    height: 100%;
    display: flex;
    flex-direction: column;
}
/* Footer Styling */
.footer-link {
    color: #FF7043;
    text-decoration: none;
    font-weight: 600;
}
.footer-link:hover {
    text-decoration: underline;
}
"""
# Refined theme: Soft base with orange/zinc/slate hues, then per-variable
# overrides applied through `.set(...)`.
_theme_overrides = {
    "body_background_fill": "#F9FAFB",
    "block_background_fill": "white",
    "block_border_width": "0px",  # Clean look
    "block_shadow": "none",
    # Remove orange background from labels
    "block_label_background_fill": "transparent",
    "block_label_text_color": "#4B5563",
    "block_label_text_weight": "600",
    "block_title_text_color": "#1F2937",
    # Input/Output styling
    "input_background_fill": "#F3F4F6",
    # Primary Button (Orange)
    "button_primary_background_fill": "#FF7043",
    "button_primary_background_fill_hover": "#F4511E",
    "button_primary_text_color": "white",
}

theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size="lg",
).set(**_theme_overrides)
# Static page fragments, kept out of the layout code below.
_HEADER_MARKDOWN = """
# 🍊 SeedVR2 Image Upscaler
### Professional One-Step Restoration & Upscaling
"""

_FOOTER_HTML = """
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
<p style="margin: 5px 0;">Powered by <b>SeedVR2</b> | One-Step Diffusion Model</p>
<p style="margin: 5px 0;">UI/UX by <a href="https://huggingface.co/bbqhan" target="_blank" class="footer-link">@bbqhan</a></p>
</div>
"""

# Two-column page: source image + settings on the left, result + download on
# the right; the button runs `upscale_image` on the three inputs.
with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as demo:
    with gr.Column():
        gr.Markdown(_HEADER_MARKDOWN)

        with gr.Row(equal_height=True):
            # Left column: input image and generation settings.
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Source", elem_id="input-header")
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=400,
                    sources=["upload", "clipboard"],
                    show_label=False,
                )

                gr.Markdown("#### Settings", elem_id="settings-header")
                with gr.Group():
                    with gr.Row():
                        seed_input = gr.Number(
                            label="Seed", value=666, precision=0, container=True
                        )
                        cfg_input = gr.Slider(
                            label="CFG Scale",
                            minimum=0.0,
                            maximum=10.0,
                            value=1.0,
                            step=0.1,
                            container=True,
                        )

                # Spacer between the settings and the action button.
                gr.HTML("<div style='height: 20px;'></div>")
                run_btn = gr.Button("Upscale Image", variant="primary", size="lg")

            # Right column: restored image and downloadable file.
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Result", elem_id="output-header")
                output_image = gr.Image(
                    label="Restored Result",
                    interactive=False,
                    height=400,
                    show_label=False,
                )
                download_file = gr.File(label="Download High-Res", container=False)

        run_btn.click(
            fn=upscale_image,
            inputs=[input_image, seed_input, cfg_input],
            outputs=[output_image, download_file],
        )

        # Footer
        gr.HTML(_FOOTER_HTML)

demo.queue().launch()