bbqhan's picture
reverting back, error encountered.
d7c03c2 verified
import spaces
import subprocess
import os
import torch
import uuid
import gc
import shutil
import argparse
from pathlib import Path
from urllib.parse import urlparse
from torch.hub import download_url_to_file, get_dir
import shlex
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from einops import rearrange
from torchvision.transforms import Compose, Lambda, Normalize
import torchvision.transforms as T
# --- Project Specific Imports (Assumed to be present in repo) ---
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
# Note: Keeping Rearrange in case it's a specific wrapper, though typically einops suffices
from data.video.transforms.rearrange import Rearrange
# Optional wavelet color correction: enabled only when the module's source
# file is present (path checked relative to the current working directory).
use_colorfix = os.path.exists("./projects/video_diffusion_sr/color_fix.py")
if use_colorfix:
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
from common.distributed import init_torch
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_size
# --- Environment Setup ---
# Single-process "distributed" settings: rank 0 of a world of size 1 on
# localhost (presumably consumed by init_torch / sync_data — confirm).
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12355"
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)

# Install Flash Attention (pip leaves an already-satisfied requirement alone).
# The child must inherit the current environment (PATH, HOME, virtualenv):
# passing a bare dict as ``env`` replaces the whole environment, so ``pip``
# could resolve to a different interpreter's pip, or not be found at all.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is read by flash-attn's build step and
# presumably skips compiling the CUDA extension from source.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# --- Model & Resource Downloading ---
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the absolute local path of ``url``'s file, downloading it if absent.

    Args:
        url: remote file location.
        model_dir: directory to store the file in; defaults to the torch hub
            ``checkpoints`` directory. Created if it does not exist.
        progress: forwarded to ``download_url_to_file`` to show a progress bar.
        file_name: local file name; defaults to the last component of the URL path.

    Returns:
        Absolute path of the local file. An existing file is reused as-is.
    """
    target_dir = model_dir if model_dir is not None else os.path.join(get_dir(), 'checkpoints')
    os.makedirs(target_dir, exist_ok=True)

    name = file_name or os.path.basename(urlparse(url).path)
    destination = os.path.abspath(os.path.join(target_dir, name))
    if os.path.exists(destination):
        return destination

    print(f'Downloading: "{url}" to {destination}\n')
    download_url_to_file(url, destination, hash_prefix=None, progress=progress)
    return destination
# Local directory for the DiT and VAE checkpoints.
ckpt_dir = Path('./ckpts')
ckpt_dir.mkdir(exist_ok=True)

# Remote location of every artifact the app needs.
pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
}

# Download Weights: (url key, target directory, expected local path).
# Each artifact is fetched only when its local copy is missing.
for _key, _target_dir, _local_path in (
    ('dit', './ckpts/', './ckpts/seedvr2_ema_3b.pth'),
    ('vae', './ckpts/', './ckpts/ema_vae.pth'),
    ('pos_emb', './', './pos_emb.pt'),
    ('neg_emb', './', './neg_emb.pt'),
    ('apex', './', './apex-0.1-cp310-cp310-linux_x86_64.whl'),
):
    if not os.path.exists(_local_path):
        load_file_from_url(url=pretrain_model_url[_key], model_dir=_target_dir)
del _key, _target_dir, _local_path

# Install the prebuilt apex wheel (tagged for CPython 3.10, linux x86_64).
subprocess.run(["pip", "install", "apex-0.1-cp310-cp310-linux_x86_64.whl"])
# --- Core Inference Logic ---
@spaces.GPU(duration=100)
def configure_runner():
    """Build a fresh SeedVR2-3B inference runner on CUDA.

    Nothing is cached: every call constructs a new ``VideoDiffusionInfer``
    from ``./configs_3b/main.yaml``, loads the DiT weights from
    ``./ckpts/seedvr2_ema_3b.pth`` and then configures the VAE.

    Returns:
        The configured ``VideoDiffusionInfer``, with its config left writable.
    """
    runner = VideoDiffusionInfer(load_config(os.path.join('./configs_3b', 'main.yaml')))
    # The diffusion settings in runner.config are overwritten per request,
    # so the OmegaConf config must not stay read-only.
    OmegaConf.set_readonly(runner.config, False)

    # Torch initialisation for a single GPU.
    init_torch(cudnn_benchmark=False)

    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()

    # Apply the configured VAE memory limit only if this VAE supports it.
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
@spaces.GPU(duration=100)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run the diffusion sampler on VAE-encoded condition latents.

    Args:
        runner: configured ``VideoDiffusionInfer`` whose diffusion is already set up.
        text_embeds_dict: extra keyword arguments forwarded to ``runner.inference``
            (in this app: ``texts_pos`` / ``texts_neg``).
        cond_latents: list of condition latents, one per input.

    Returns:
        List of output tensors, one per input, laid out as (T, C, H, W);
        a 3-D output is treated as a single frame (T = 1).
    """
    cuda = torch.device("cuda")
    noise_scale = 0.1  # fraction of timestep 1000 used to noise the condition

    # All sampling noises are drawn first, then all augmentation noises;
    # this draw order is part of what a given seed reproduces.
    noises = [torch.randn_like(lat) for lat in cond_latents]
    aug_noises = [torch.randn_like(lat) for lat in cond_latents]

    # sync_data(..., 0): presumably syncs the tensors from rank 0 (confirm).
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = [
        [x.to(cuda) for x in group] for group in (noises, aug_noises, cond_latents)
    ]

    def _augment(latent, aug_noise):
        # Noise the condition with runner.schedule.forward at timestep
        # 1000 * noise_scale, mapped through the runner's timestep transform
        # for this latent's (non-channel) shape.
        timestep = torch.tensor([1000.0], device=cuda) * noise_scale
        shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        timestep = runner.timestep_transform(timestep, shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        conditions.append(
            runner.get_condition(noise, task="sr", latent_blur=_augment(latent, aug_noise))
        )

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    samples = []
    for video in video_tensors:
        if video.ndim == 3:
            # (C, H, W) -> (C, 1, H, W)
            video = video[:, None]
        samples.append(rearrange(video, "c t h w -> t c h w"))
    return samples
def get_text_embeds():
    """Load the precomputed text embeddings from the working directory.

    Returns:
        ``{"texts_pos": [<pos_emb.pt>], "texts_neg": [<neg_emb.pt>]}`` — each
        value is a one-element list holding the object ``torch.load`` returned.
    """
    sources = (("texts_pos", 'pos_emb.pt'), ("texts_neg", 'neg_emb.pt'))
    return {key: [torch.load(path)] for key, path in sources}
def _load_condition_video(image_path):
    """Load ``image_path`` and return it preprocessed as a (C, T=1, H, W) CUDA tensor.

    Preprocessing: ``NaResize`` to resolution sqrt(2560 * 1440) in "area"
    mode (``downsample_only=False``, so small inputs may be enlarged), clamp
    to [0, 1], ``DivisibleCrop((16, 16))``, ``Normalize(0.5, 0.5)`` (maps
    [0, 1] to [-1, 1]), and rearrange (T, C, H, W) -> (C, T, H, W).
    """
    # Note: Model is optimized for 2560x1440 area equivalent
    video_transform = Compose([
        NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])
    # Close the source file promptly; convert() returns an independent image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    # A single image is a one-frame clip: (T=1, C, H, W), the "t c h w"
    # layout the transform's Rearrange expects.
    frames = T.ToTensor()(img).unsqueeze(0)
    return video_transform(frames.to(torch.device("cuda")))


def _sample_to_image(sample, input_tensor):
    """Convert one generated sample into a PIL image of its first frame.

    Args:
        sample: generated tensor laid out (T, C, H, W), values nominally in [-1, 1].
        input_tensor: preprocessed condition tensor (C, T, H, W); used as the
            reference for ``wavelet_reconstruction`` when ``use_colorfix`` is set.
    """
    if use_colorfix:
        input_ref = (
            rearrange(input_tensor[:, None], "c t h w -> t c h w")
            if input_tensor.ndim == 3
            else rearrange(input_tensor, "c t h w -> t c h w")
        )
        sample = wavelet_reconstruction(sample.to("cpu"), input_ref[:sample.size(0)].to("cpu"))
    else:
        sample = sample.to("cpu")

    # A 3-D sample is one (C, H, W) frame: add the leading T axis so it
    # matches the (T, C, H, W) layout before moving channels last for PIL.
    if sample.ndim == 3:
        sample = sample[None]
    sample = rearrange(sample, "t c h w -> t h w c")
    # Map [-1, 1] -> [0, 255] and quantize.
    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    frames = sample.to(torch.uint8).numpy()
    return Image.fromarray(frames[0])


def _run_upscale(runner, image_path, seed, cfg_scale):
    """Configure one-step sampling on ``runner``, upscale ``image_path`` and save it.

    Returns:
        ``(PIL image, path of the PNG written under output/)``.
    """
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = 0.0
    runner.config.diffusion.timesteps.sampling.steps = 1  # One-step generation
    runner.configure_diffusion()

    # Guard against a cleared Seed field (assumed to arrive as None):
    # fall back to the default seed instead of failing in int().
    if seed is None:
        seed = 666
    set_seed(int(seed) % (2**32), same_across_ranks=True)

    os.makedirs('output/', exist_ok=True)
    output_filename = f'output/{uuid.uuid4()}.png'

    input_tensor = _load_condition_video(image_path)  # also the colorfix reference
    cond_latents = runner.vae_encode([input_tensor])

    text_embeds = get_text_embeds()
    for key in ("texts_pos", "texts_neg"):
        text_embeds[key] = [emb.to(torch.device("cuda")) for emb in text_embeds[key]]

    samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
    result_image = _sample_to_image(samples[0], input_tensor)
    result_image.save(output_filename)
    return result_image, output_filename


@spaces.GPU(duration=100)
def upscale_image(image_path, seed=666, cfg_scale=1.0):
    """Gradio handler: upscale a single image with SeedVR2-3B.

    Args:
        image_path: path of the uploaded image; falsy when nothing was uploaded.
        seed: RNG seed, taken modulo 2**32; ``None`` falls back to 666.
        cfg_scale: guidance scale written to ``runner.config.diffusion.cfg.scale``.

    Returns:
        ``(PIL image, saved PNG path)``, or ``(None, None)`` when there is no input.
    """
    if not image_path:
        return None, None

    # A fresh runner is built per request; configure_runner does not cache it.
    runner = configure_runner()
    try:
        # The intermediate tensors live in _run_upscale's frame and are
        # dropped when it returns, before the CUDA cache is emptied below.
        return _run_upscale(runner, image_path, seed, cfg_scale)
    finally:
        # Cleanup runs on both the success and the error path.
        del runner
        gc.collect()
        torch.cuda.empty_cache()
# --- Gradio UI ---
# Custom CSS layered on top of the Gradio theme. It sets the container font,
# centres and colours the h1/h3 header, gives primary buttons an orange
# gradient with a hover lift, and styles the `.ui-box` card class (applied to
# the input and output columns via elem_classes) and `.footer-link` anchors.
# The /* ... */ lines below are part of the CSS string itself.
custom_css = """
/* Font Import handled by Theme, but custom tweaks here */
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif !important;
}
/* Header Styling */
h1 {
text-align: center;
color: #FF7043;
font-weight: 700 !important;
font-size: 2.5rem !important;
margin-bottom: 0.5rem !important;
}
h3 {
text-align: center;
color: #525252;
font-weight: 400 !important;
margin-top: 0 !important;
}
/* Button Styling - Vibrant Orange */
button.primary {
background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
border: none !important;
box-shadow: 0 4px 6px -1px rgba(255, 87, 34, 0.2), 0 2px 4px -1px rgba(255, 87, 34, 0.1) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
transform: translateY(-1px);
box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
}
/* UI Boxes (Groups/Columns) */
.ui-box {
background: white;
border: 1px solid #E5E7EB;
border-radius: 12px;
padding: 20px;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
height: 100%;
display: flex;
flex-direction: column;
}
/* Footer Styling */
.footer-link {
color: #FF7043;
text-decoration: none;
font-weight: 600;
}
.footer-link:hover {
text-decoration: underline;
}
"""
# Refined Theme: the Soft base with an orange primary hue, IBM Plex Sans,
# large radii, borderless/shadowless blocks and transparent component labels.
_theme_overrides = {
    "body_background_fill": "#F9FAFB",
    "block_background_fill": "white",
    # Flat blocks: no border, no shadow.
    "block_border_width": "0px",
    "block_shadow": "none",
    # Labels without the default coloured background.
    "block_label_background_fill": "transparent",
    "block_label_text_color": "#4B5563",
    "block_label_text_weight": "600",
    "block_title_text_color": "#1F2937",
    # Input/output component fill.
    "input_background_fill": "#F3F4F6",
    # Primary (orange) button.
    "button_primary_background_fill": "#FF7043",
    "button_primary_background_fill_hover": "#F4511E",
    "button_primary_text_color": "white",
}
_base_theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size="lg",
)
theme = _base_theme.set(**_theme_overrides)
# Page layout: a header, a two-column row (input image + settings + button on
# the left; result image + download on the right), then a footer. Components
# render in creation order inside their enclosing `with` context.
with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as demo:
    with gr.Column():
        gr.Markdown(
"""
# 🍊 SeedVR2 Image Upscaler
### Professional One-Step Restoration & Upscaling
"""
        )
        with gr.Row(equal_height=True):
            # Left Column: Input
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Source", elem_id="input-header")
                # type="filepath": the handler receives a path, which
                # upscale_image opens with Image.open.
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=400,
                    sources=["upload", "clipboard"],
                    show_label=False
                )
                gr.Markdown("#### Settings", elem_id="settings-header")
                with gr.Group():
                    with gr.Row():
                        # Defaults mirror upscale_image's seed=666, cfg_scale=1.0.
                        seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
                        cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1, container=True)
                # Spacer
                gr.HTML("<div style='height: 20px;'></div>")
                run_btn = gr.Button("Upscale Image", variant="primary", size="lg")
            # Right Column: Output
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Result", elem_id="output-header")
                output_image = gr.Image(
                    label="Restored Result",
                    interactive=False,
                    height=400,
                    show_label=False
                )
                download_file = gr.File(label="Download High-Res", container=False)
        # Inputs map positionally to upscale_image(image_path, seed, cfg_scale);
        # its (image, saved path) return fills the preview and the download.
        run_btn.click(
            fn=upscale_image,
            inputs=[input_image, seed_input, cfg_input],
            outputs=[output_image, download_file]
        )
        # Footer
        gr.HTML(
"""
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
<p style="margin: 5px 0;">Powered by <b>SeedVR2</b> | One-Step Diffusion Model</p>
<p style="margin: 5px 0;">UI/UX by <a href="https://huggingface.co/bbqhan" target="_blank" class="footer-link">@bbqhan</a></p>
</div>
"""
        )
# Enable the request queue, then start the server.
demo.queue()
demo.launch()