Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,230 Bytes
6cbcc74 a6ea067 42f2c22 f14231d 42f2c22 f14231d 42f2c22 d7c03c2 c5715dc 42f2c22 d7c03c2 f14231d 42f2c22 f14231d 42f2c22 f14231d 42f2c22 f14231d 367bee6 f14231d e276ced e3b4db8 d7c03c2 de5b2a1 d7c03c2 de5b2a1 f14231d e3b4db8 f14231d e3b4db8 f14231d e3b4db8 f14231d e3b4db8 f14231d e3b4db8 42f2c22 f14231d 42f2c22 f14231d 42f2c22 367bee6 d62797d 42f2c22 f14231d 96af013 f14231d 42f2c22 f14231d 42f2c22 f14231d 42f2c22 f14231d d62797d f14231d 42f2c22 273be34 f14231d 42f2c22 f14231d d7c03c2 f14231d 96af013 42f2c22 f14231d 42f2c22 d7c03c2 42f2c22 273be34 42f2c22 f14231d 42f2c22 512f3c8 42f2c22 f14231d 42f2c22 f14231d 42f2c22 f14231d 42f2c22 f14231d 512f3c8 42f2c22 f14231d 42f2c22 273be34 42f2c22 f14231d 42f2c22 f14231d 42f2c22 f14231d 273be34 f14231d d7c03c2 f14231d d7c03c2 f14231d d7c03c2 f14231d 42f2c22 f14231d 42f2c22 d7c03c2 f14231d d7c03c2 f14231d d7c03c2 f14231d f34d937 f14231d d7c03c2 f14231d f34d937 f14231d f34d937 f14231d d7c03c2 f14231d f34d937 d7c03c2 f34d937 d7c03c2 f34d937 f14231d d7c03c2 f14231d f34d937 f14231d f34d937 d7c03c2 f34d937 d7c03c2 f34d937 d7c03c2 f34d937 d7c03c2 f14231d f34d937 f14231d d7c03c2 f34d937 f14231d f34d937 d7c03c2 f34d937 d7c03c2 f34d937 d7c03c2 f34d937 f14231d 42f2c22 d7c03c2 f34d937 f14231d f34d937 0d8f0d4 f14231d 0d8f0d4 f14231d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 |
import spaces
import subprocess
import os
import torch
import uuid
import gc
import shutil
import argparse
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
import shlex
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from einops import rearrange
from torchvision.transforms import Compose, Lambda, Normalize
import torchvision.transforms as T
# --- Project Specific Imports (Assumed to be present in repo) ---
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
# Note: Keeping Rearrange in case it's a specific wrapper, though typically einops suffices
from data.video.transforms.rearrange import Rearrange
if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
use_colorfix = True
else:
use_colorfix = False
from common.distributed import init_torch
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_size
# --- Environment Setup ---
# Single-process distributed settings (rank 0 of a world of 1); presumably
# consumed by init_torch / torch.distributed env:// initialisation.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)
# Install Flash Attention (pip reports "already satisfied" if it is present).
# The child gets the inherited environment plus the build flag: passing only
# the flag would replace the environment wholesale and drop PATH, HOME,
# virtualenv variables, etc. for the pip process.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# --- Model & Resource Downloading ---
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for *url*, downloading the file only if it is absent.

    Args:
        url: Address of the file to fetch.
        model_dir: Cache directory; defaults to ``<torch hub dir>/checkpoints``.
            It is created if it does not exist.
        progress: Forwarded unchanged to ``download_url_to_file``.
        file_name: Name for the cached file; when falsy, the basename of the
            URL's path component is used (query strings are ignored).

    Returns:
        Absolute path of the cached file.
    """
    target_dir = model_dir if model_dir is not None else os.path.join(get_dir(), 'checkpoints')
    os.makedirs(target_dir, exist_ok=True)

    name = file_name or os.path.basename(urlparse(url).path)
    destination = os.path.abspath(os.path.join(target_dir, name))

    # Already cached: nothing to fetch.
    if os.path.exists(destination):
        return destination

    print(f'Downloading: "{url}" to {destination}\n')
    download_url_to_file(url, destination, hash_prefix=None, progress=progress)
    return destination
ckpt_dir = Path('./ckpts')
ckpt_dir.mkdir(exist_ok=True)

# Hugging Face URLs for the SeedVR2-3B weights, the positive/negative text
# embeddings, and a prebuilt apex wheel (cp310, linux x86_64).
pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
}

# Download Weights: (url key, target directory), fetched in this order. Each
# file is stored under its URL basename and skipped if already on disk.
_DOWNLOAD_PLAN = (
    ('dit', './ckpts/'),
    ('vae', './ckpts/'),
    ('pos_emb', './'),
    ('neg_emb', './'),
    ('apex', './'),
)
for _asset, _target_dir in _DOWNLOAD_PLAN:
    _asset_url = pretrain_model_url[_asset]
    _local_path = os.path.join(_target_dir, os.path.basename(urlparse(_asset_url).path))
    if not os.path.exists(_local_path):
        load_file_from_url(url=_asset_url, model_dir=_target_dir)

# Install the downloaded apex wheel; this runs on every start-up.
subprocess.run(shlex.split("pip install apex-0.1-cp310-cp310-linux_x86_64.whl"))
# --- Core Inference Logic ---
@spaces.GPU(duration=100)
def configure_runner():
    """Build and return a fresh SeedVR2-3B ``VideoDiffusionInfer`` runner.

    Reads ``./configs_3b/main.yaml`` and makes the runner's config writable.
    It then calls ``init_torch`` with cudnn benchmarking off, configures the
    DiT on CUDA from ``./ckpts/seedvr2_ema_3b.pth``, and configures the VAE.
    Nothing is cached: every call constructs a new runner.
    """
    runner = VideoDiffusionInfer(load_config(os.path.join('./configs_3b', 'main.yaml')))
    # Callers overwrite diffusion settings on this config, so unlock it.
    OmegaConf.set_readonly(runner.config, False)

    # Standard init for single GPU
    init_torch(cudnn_benchmark=False)
    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()

    # Optional hook: only apply the configured memory limit when the VAE
    # implementation exposes one.
    vae = runner.vae
    if hasattr(vae, "set_memory_limit"):
        vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
@spaces.GPU(duration=100)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run the diffusion pass conditioned on encoded low-resolution latents.

    Args:
        runner: A configured ``VideoDiffusionInfer`` (see ``configure_runner``).
        text_embeds_dict: Extra keyword arguments forwarded to
            ``runner.inference`` (the caller passes ``texts_pos``/``texts_neg``).
        cond_latents: List of VAE latents, one per input.

    Returns:
        List of output tensors rearranged from ``c t h w`` to ``t c h w``.
        A 3-D output gets a singleton ``t`` axis inserted first.
    """
    cuda = torch.device("cuda")
    cond_noise_scale = 0.1

    # Draw all sampling noise first, then all augmentation noise. Keep this
    # order so seeded runs consume the RNG identically.
    noises = [torch.randn_like(z) for z in cond_latents]
    aug_noises = [torch.randn_like(z) for z in cond_latents]

    # Sync and move
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = [
        [tensor.to(cuda) for tensor in group]
        for group in (noises, aug_noises, cond_latents)
    ]

    def _noised_condition(latent, aug_noise):
        # Push the conditioning latent to a fixed timestep (1000 * scale),
        # after the runner's shape-dependent timestep transform.
        timestep = torch.tensor([1000.0], device=cuda) * cond_noise_scale
        latent_shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        timestep = runner.timestep_transform(timestep, latent_shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    conditions = [
        runner.get_condition(noise, task="sr", latent_blur=_noised_condition(latent, aug))
        for noise, aug, latent in zip(noises, aug_noises, cond_latents)
    ]

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        outputs = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    def _to_tchw(video):
        # Treat a 3-D output as "c h w" and give it a length-1 time axis.
        if video.ndim == 3:
            video = video[:, None]
        return rearrange(video, "c t h w -> t c h w")

    return [_to_tchw(video) for video in outputs]
def get_text_embeds():
    """Load the precomputed prompt embeddings from the working directory.

    Reads ``pos_emb.pt`` then ``neg_emb.pt`` with ``torch.load``. Each object
    is wrapped in a one-element list under ``texts_pos`` / ``texts_neg``.
    A new dict is returned on every call, because the caller mutates it.
    """
    embeds = {}
    for key, path in (("texts_pos", "pos_emb.pt"), ("texts_neg", "neg_emb.pt")):
        embeds[key] = [torch.load(path)]
    return embeds
def _load_condition_video(image_path):
    """Load an image file as a normalized single-frame video tensor on CUDA.

    Returns the transformed tensor in ``c t h w`` layout (t = 1).
    """
    # NaResize target is sqrt(2560 * 1440) (~1920) with downsample_only=False,
    # so small inputs are enlarged too. This is presumably an area-equivalent
    # target for 2560x1440; confirm against NaResize.
    video_transform = Compose([
        NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])
    # Use a context manager so the source file handle is released promptly.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # ToTensor yields (C, H, W) in [0, 1]. unsqueeze(0) adds the frame axis,
    # giving the (T=1, C, H, W) layout the transform's Rearrange expects.
    frames = T.ToTensor()(rgb).unsqueeze(0)
    return video_transform(frames.to(torch.device("cuda")))


def _sample_to_image(sample, input_tensor):
    """Convert a ``t c h w`` sample in [-1, 1] to a PIL image of its first frame.

    When the colour-fix module is available, the sample is first colour-matched
    to *input_tensor* (the normalized ``c t h w`` condition).
    """
    if use_colorfix:
        input_ref = (
            rearrange(input_tensor[:, None], "c t h w -> t c h w")
            if input_tensor.ndim == 3
            else rearrange(input_tensor, "c t h w -> t c h w")
        )
        sample = wavelet_reconstruction(sample.to("cpu"), input_ref[:sample.size(0)].to("cpu"))
    else:
        sample = sample.to("cpu")
    sample = (
        rearrange(sample[:, None], "t c h w -> t h w c")
        if sample.ndim == 3
        else rearrange(sample, "t c h w -> t h w c")
    )
    # Map [-1, 1] back to [0, 255] and quantize.
    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    frames = sample.to(torch.uint8).numpy()
    return Image.fromarray(frames[0])


@spaces.GPU(duration=100)
def upscale_image(image_path, seed=666, cfg_scale=1.0):
    """Restore and upscale a single image with SeedVR2 in one diffusion step.

    Args:
        image_path: Path of the uploaded image (Gradio ``type="filepath"``).
            A falsy value means nothing was uploaded.
        seed: RNG seed, reduced modulo 2**32. ``None`` (e.g. a cleared Number
            field) falls back to the default 666 instead of crashing.
        cfg_scale: Guidance scale written to ``diffusion.cfg.scale``.

    Returns:
        ``(PIL.Image, path)`` for the result saved as a PNG under ``output/``,
        or ``(None, None)`` when no image was given.
    """
    if not image_path:
        return None, None

    runner = cond_latents = samples = input_tensor = None
    try:
        runner = configure_runner()
        # One-step generation with the requested guidance.
        runner.config.diffusion.cfg.scale = cfg_scale
        runner.config.diffusion.cfg.rescale = 0.0
        runner.config.diffusion.timesteps.sampling.steps = 1
        runner.configure_diffusion()

        seed = int(666 if seed is None else seed) % (2**32)
        set_seed(seed, same_across_ranks=True)

        os.makedirs('output/', exist_ok=True)
        output_filename = f'output/{uuid.uuid4()}.png'

        # The un-encoded condition is kept as the colour-fix reference.
        input_tensor = _load_condition_video(image_path)
        cond_latents = runner.vae_encode([input_tensor])

        text_embeds = get_text_embeds()
        for k in ["texts_pos", "texts_neg"]:
            text_embeds[k] = [emb.to(torch.device("cuda")) for emb in text_embeds[k]]

        samples = generation_step(runner, text_embeds, cond_latents=cond_latents)

        result_image = _sample_to_image(samples[0], input_tensor)
        result_image.save(output_filename)
        return result_image, output_filename
    finally:
        # Drop the model and large tensors on success *and* on error, then
        # ask the allocator to return cached GPU memory.
        runner = cond_latents = samples = input_tensor = None
        gc.collect()
        torch.cuda.empty_cache()
# --- Gradio UI ---
# Extra stylesheet passed to gr.Blocks(css=...). `.ui-box` targets the columns
# created with elem_classes="ui-box", and `.footer-link` targets the footer
# anchor. `button.primary` restyles the variant="primary" button.
custom_css = """
/* Font Import handled by Theme, but custom tweaks here */
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif !important;
}
/* Header Styling */
h1 {
text-align: center;
color: #FF7043;
font-weight: 700 !important;
font-size: 2.5rem !important;
margin-bottom: 0.5rem !important;
}
h3 {
text-align: center;
color: #525252;
font-weight: 400 !important;
margin-top: 0 !important;
}
/* Button Styling - Vibrant Orange */
button.primary {
background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
border: none !important;
box-shadow: 0 4px 6px -1px rgba(255, 87, 34, 0.2), 0 2px 4px -1px rgba(255, 87, 34, 0.1) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
transform: translateY(-1px);
box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
}
/* UI Boxes (Groups/Columns) */
.ui-box {
background: white;
border: 1px solid #E5E7EB;
border-radius: 12px;
padding: 20px;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
height: 100%;
display: flex;
flex-direction: column;
}
/* Footer Styling */
.footer-link {
color: #FF7043;
text-decoration: none;
font-weight: 600;
}
.footer-link:hover {
text-decoration: underline;
}
"""
# Refined Theme: the Soft preset (orange/zinc/slate, IBM Plex Sans, large
# radii), with the per-variable overrides below applied through Theme.set().
_THEME_OVERRIDES = dict(
    body_background_fill="#F9FAFB",
    block_background_fill="white",
    block_border_width="0px",  # Clean look
    block_shadow="none",
    # Remove orange background from labels
    block_label_background_fill="transparent",
    block_label_text_color="#4B5563",
    block_label_text_weight="600",
    block_title_text_color="#1F2937",
    # Input/Output styling
    input_background_fill="#F3F4F6",
    # Primary Button (Orange)
    button_primary_background_fill="#FF7043",
    button_primary_background_fill_hover="#F4511E",
    button_primary_text_color="white",
)
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size="lg",
).set(**_THEME_OVERRIDES)
# App layout: a header, then two equal-height "ui-box" cards. The left card
# holds the source image, settings and run button; the right card holds the
# result and download. Below them are the click wiring and a footer.
with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as demo:
    with gr.Column():
        gr.Markdown(
            """
# 🍊 SeedVR2 Image Upscaler
### Professional One-Step Restoration & Upscaling
"""
        )
        with gr.Row(equal_height=True):
            # Left Column: Input
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Source", elem_id="input-header")
                # type="filepath" makes the handler receive a path string.
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=400,
                    sources=["upload", "clipboard"],
                    show_label=False
                )
                gr.Markdown("#### Settings", elem_id="settings-header")
                with gr.Group():
                    with gr.Row():
                        seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
                        cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1, container=True)
                # Spacer
                gr.HTML("<div style='height: 20px;'></div>")
                run_btn = gr.Button("Upscale Image", variant="primary", size="lg")
            # Right Column: Output
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Result", elem_id="output-header")
                output_image = gr.Image(
                    label="Restored Result",
                    interactive=False,
                    height=400,
                    show_label=False
                )
                download_file = gr.File(label="Download High-Res", container=False)
        # Inputs map positionally onto upscale_image(image_path, seed, cfg_scale).
        # Its (image, path) return pair fills the preview and the download slot.
        run_btn.click(
            fn=upscale_image,
            inputs=[input_image, seed_input, cfg_input],
            outputs=[output_image, download_file]
        )
        # Footer
        gr.HTML(
            """
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
<p style="margin: 5px 0;">Powered by <b>SeedVR2</b> | One-Step Diffusion Model</p>
<p style="margin: 5px 0;">UI/UX by <a href="https://huggingface.co/bbqhan" target="_blank" class="footer-link">@bbqhan</a></p>
</div>
"""
        )
# Enable request queueing (Blocks.queue returns the Blocks), then serve the app.
demo.queue().launch()
|