# Source: Hugging Face Space file view (author: bbqhan)
# Commit c1a0d47 (verified): "testing optimization for h200"
# raw | history | blame — 13.5 kB
import argparse
import gc
import os
import shlex
import shutil
import subprocess
import sys
import uuid
from pathlib import Path
from urllib.parse import urlparse

# Keep `spaces` (Hugging Face ZeroGPU) ahead of torch and the other CUDA-using imports.
import spaces
import torch
from torch.hub import download_url_to_file, get_dir
import gradio as gr
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
from einops import rearrange
from torchvision.transforms import Compose, Lambda, Normalize
import torchvision.transforms as T
# --- H200 Optimization Flags ---
# Let float32 matmuls (cuBLAS) and cuDNN convolutions use TF32 tensor cores:
# a small loss of float32 precision for much higher throughput on Hopper GPUs.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
# Target compute capability 9.0 (Hopper) for any CUDA extension compiled from
# source in this process. Prebuilt wheels and torch.compile kernels are not
# affected by this variable.
os.environ.update(TORCH_CUDA_ARCH_LIST="9.0")
# --- Project Specific Imports ---
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
# Colour correction is optional: it is enabled only when the project ships the
# wavelet colour-fix helper at this path.
use_colorfix = os.path.exists("./projects/video_diffusion_sr/color_fix.py")
if use_colorfix:
    from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
from common.distributed import init_torch
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.config import load_config
from common.distributed.ops import sync_data
from common.seed import set_seed
from common.partition import partition_by_size
# --- Environment Setup ---
# Single-process rendezvous settings (rank 0 of a world of size 1 on localhost),
# in the env:// form torch.distributed uses; presumably consumed by init_torch.
os.environ.update(
    MASTER_ADDR="127.0.0.1",
    MASTER_PORT="12355",
    RANK="0",
    WORLD_SIZE="1",
)
# Install flash-attn into the interpreter running this app, targeting Hopper.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=FALSE leaves flash-attn's CUDA extension build
# ENABLED (it does not skip it). If a source build happens, TORCH_CUDA_ARCH_LIST=9.0
# restricts it to sm_90 (H100/H200).
# The child process must inherit the current environment (PATH, CUDA_HOME,
# virtualenv, ...). Passing a bare dict as `env` would replace it entirely.
# Best-effort, as before: a failed install is not fatal here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={
        **os.environ,
        "FLASH_ATTENTION_SKIP_CUDA_BUILD": "FALSE",
        "TORCH_CUDA_ARCH_LIST": "9.0",
    },
)
# --- Model & Resource Downloading ---
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for ``url``, downloading the file only if it is missing.

    Args:
        url: Remote file URL.
        model_dir: Destination directory (created if needed). Defaults to
            ``<torch hub dir>/checkpoints``.
        progress: Forwarded to ``torch.hub.download_url_to_file``.
        file_name: Local file name. Defaults to the last component of the URL path.

    Returns:
        Absolute path (str) of the local file.
    """
    target_dir = os.path.join(get_dir(), 'checkpoints') if model_dir is None else model_dir
    os.makedirs(target_dir, exist_ok=True)

    local_name = file_name or os.path.basename(urlparse(url).path)
    destination = os.path.abspath(os.path.join(target_dir, local_name))
    if os.path.exists(destination):
        # Already cached: nothing to fetch.
        return destination

    print(f'Downloading: "{url}" to {destination}\n')
    download_url_to_file(url, destination, hash_prefix=None, progress=progress)
    return destination
# Local directory for the DiT and VAE checkpoints.
ckpt_dir = Path('./ckpts')
ckpt_dir.mkdir(exist_ok=True)
pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
    'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
}
# Destination directory for each artefact, in download order.
# - The text embeddings stay in the working directory, where get_text_embeds()
#   loads them by relative name.
# - load_file_from_url() skips files that already exist, so restarts do not
#   re-download anything.
_download_dirs = {
    'dit': ckpt_dir,
    'vae': ckpt_dir,
    'pos_emb': Path('.'),
    'neg_emb': Path('.'),
    'apex': Path('.'),
}
# Download Weights: maps each key to the absolute local path of its file.
_local_files = {
    key: load_file_from_url(url=pretrain_model_url[key], model_dir=str(target))
    for key, target in _download_dirs.items()
}
# Install the prebuilt apex wheel (CPython 3.10, linux x86_64) into *this*
# interpreter. An absolute path is used, so the install does not depend on the
# current working directory. Best-effort, as before: failures are not fatal here.
subprocess.run([sys.executable, "-m", "pip", "install", _local_files['apex']])
# --- Core Inference Logic ---
# Process-wide cache for the runner built by configure_runner().
_RUNNER = None


@spaces.GPU(duration=100)
def configure_runner():
    """Return the SeedVR2-3B inference runner, building it on the first call.

    The first call does the following:
    - loads ``./configs_3b/main.yaml`` and makes the runner config writable;
    - runs ``init_torch`` with cuDNN benchmarking;
    - loads the DiT weights from ``./ckpts/seedvr2_ema_3b.pth`` onto CUDA;
    - configures the VAE;
    - wraps the DiT in ``torch.compile(mode="max-autotune")``.

    Later calls in the same process return the same object instead of reloading
    the weights and recompiling on every request.

    Callers may mutate ``runner.config`` (e.g. ``diffusion.cfg.scale``). They are
    responsible for re-applying it with ``runner.configure_diffusion()``.

    NOTE(review): on ZeroGPU each ``@spaces.GPU`` call may run in a separate
    worker. In that case the cache only lasts for that call — confirm on the
    target hardware.
    """
    global _RUNNER
    if _RUNNER is not None:
        return _RUNNER

    config = load_config(os.path.join('./configs_3b', 'main.yaml'))
    runner = VideoDiffusionInfer(config)
    OmegaConf.set_readonly(runner.config, False)

    # Standard single-GPU initialisation.
    init_torch(cudnn_benchmark=True)
    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()
    # runner.vae.set_memory_limit(**runner.config.vae.memory_limit) is deliberately
    # not applied (the author's choice for the H200's large memory).

    # max-autotune spends extra time on the first compiled call searching kernel
    # configurations in exchange for faster kernels. fullgraph=False lets graph
    # breaks fall back to eager execution instead of raising.
    print("🚀 Optimizing DiT for H200 (max-autotune)... this may take a minute on first run.")
    runner.dit = torch.compile(runner.dit, mode="max-autotune", fullgraph=False)

    _RUNNER = runner
    return runner
@spaces.GPU(duration=100)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run the diffusion sampler on encoded conditioning latents.

    Args:
        runner: Configured VideoDiffusionInfer instance.
        text_embeds_dict: Keyword arguments forwarded to ``runner.inference``
            (here ``texts_pos`` / ``texts_neg``).
        cond_latents: List of VAE latents of the low-quality input(s).

    Returns:
        List of generated samples, each rearranged to ``(t, c, h, w)``.
    """
    cuda = torch.device("cuda")
    noise_level = 0.1  # fraction of the 1000-step schedule applied to the conditioning

    # All sampling noises are drawn before all augmentation noises. The RNG call
    # order determines the result for a given seed.
    noises = [torch.randn_like(lat) for lat in cond_latents]
    aug_noises = [torch.randn_like(lat) for lat in cond_latents]

    # Synchronise via sync_data (source rank 0), then place everything on the GPU.
    synced = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = ([x.to(cuda) for x in group] for group in synced)

    def _noised(latent, aug_noise):
        # Forward-diffuse `latent` with `aug_noise` to timestep 1000 * noise_level,
        # remapped by the runner's timestep transform for this latent's shape.
        timestep = torch.tensor([1000.0], device=cuda) * noise_level
        latent_shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        timestep = runner.timestep_transform(timestep, latent_shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        conditions.append(
            runner.get_condition(noise, task="sr", latent_blur=_noised(latent, aug_noise))
        )

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    # Give 3-D outputs a singleton time axis, then reorder (c, t, h, w) -> (t, c, h, w).
    samples = []
    for video in video_tensors:
        if video.ndim == 3:
            video = video[:, None]
        samples.append(rearrange(video, "c t h w -> t c h w"))
    return samples
def get_text_embeds():
    """Load the precomputed text embeddings from the working directory.

    Returns:
        ``{"texts_pos": [pos_emb], "texts_neg": [neg_emb]}``, read from
        ``pos_emb.pt`` and ``neg_emb.pt`` (in that order) with ``torch.load``.
    """
    embeds = {}
    for key, path in (("texts_pos", 'pos_emb.pt'), ("texts_neg", 'neg_emb.pt')):
        embeds[key] = [torch.load(path)]
    return embeds
@spaces.GPU(duration=100)
def upscale_image(image_path, seed=666, cfg_scale=1.0):
    """Restore and upscale a single image with SeedVR2 in one diffusion step.

    Args:
        image_path: Path of the input image (the Gradio Image component uses
            ``type="filepath"``). A falsy value means nothing was uploaded.
        seed: RNG seed, reduced modulo 2**32. ``None`` (e.g. an emptied Number
            field) falls back to the default 666 instead of raising.
        cfg_scale: Classifier-free guidance scale written to the runner config.

    Returns:
        ``(result_image, output_filename)``: the restored ``PIL.Image`` and the
        path of the PNG saved under ``output/``. Returns ``(None, None)`` when
        ``image_path`` is empty.
    """
    if not image_path:
        return None, None

    cuda = torch.device("cuda")
    runner = configure_runner()

    # One-step sampling with the requested CFG scale and no CFG rescale.
    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = 0.0
    runner.config.diffusion.timesteps.sampling.steps = 1
    runner.configure_diffusion()

    if seed is None:
        seed = 666
    seed = int(seed) % (2**32)
    set_seed(seed, same_across_ranks=True)

    os.makedirs('output/', exist_ok=True)
    output_filename = f'output/{uuid.uuid4()}.png'

    # Preprocessing pipeline:
    # - NaResize targets a resolution of sqrt(2560 * 1440) (presumably an
    #   area-equivalent ~2560x1440 output); downsample_only=False allows upscaling.
    # - Values are clamped to [0, 1] and cropped by DivisibleCrop((16, 16)).
    # - Normalize(0.5, 0.5) maps them to [-1, 1].
    # - The clip is reordered from (t, c, h, w) to (c, t, h, w).
    video_transform = Compose([
        NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    # Treat the still image as a one-frame clip of shape (1, C, H, W) == (t, c, h, w).
    # The context manager closes the source file once the RGB copy exists.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    frames = T.ToTensor()(img).unsqueeze(0)
    input_tensor = video_transform(frames.to(cuda))  # also the colour-fix reference

    cond_latents = runner.vae_encode([input_tensor])

    text_embeds = get_text_embeds()
    for k in ["texts_pos", "texts_neg"]:
        text_embeds[k] = [emb.to(cuda) for emb in text_embeds[k]]

    samples = generation_step(runner, text_embeds, cond_latents=cond_latents)

    sample = samples[0]
    input_ref = (
        rearrange(input_tensor[:, None], "c t h w -> t c h w")
        if input_tensor.ndim == 3
        else rearrange(input_tensor, "c t h w -> t c h w")
    )
    if use_colorfix:
        # Colour correction against the (normalised) input frames via wavelet_reconstruction.
        sample = wavelet_reconstruction(sample.to("cpu"), input_ref[:sample.size(0)].to("cpu"))
    else:
        sample = sample.to("cpu")

    # (t, c, h, w) in [-1, 1] -> (t, h, w, c) uint8 in [0, 255].
    sample = (
        rearrange(sample[:, None], "t c h w -> t h w c")
        if sample.ndim == 3
        else rearrange(sample, "t c h w -> t h w c")
    )
    sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
    sample = sample.to(torch.uint8).numpy()

    result_image = Image.fromarray(sample[0])
    result_image.save(output_filename)

    # Release cached allocator blocks; the runner object is not deleted here.
    torch.cuda.empty_cache()
    return result_image, output_filename
# --- Gradio UI ---
# Extra CSS passed to gr.Blocks(css=...). It styles:
# - the page font (IBM Plex Sans) and the centred orange h1 / grey h3 headings;
# - gradient primary buttons with a hover lift;
# - white "card" panels for elements with the `ui-box` class;
# - the `footer-link` anchors.
custom_css = """
/* Font Import handled by Theme */
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif !important;
}
h1 {
text-align: center;
color: #FF7043;
font-weight: 700 !important;
font-size: 2.5rem !important;
margin-bottom: 0.5rem !important;
}
h3 {
text-align: center;
color: #525252;
font-weight: 400 !important;
margin-top: 0 !important;
}
button.primary {
background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
border: none !important;
box-shadow: 0 4px 6px -1px rgba(255, 87, 34, 0.2), 0 2px 4px -1px rgba(255, 87, 34, 0.1) !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
transform: translateY(-1px);
box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
}
.ui-box {
background: white;
border: 1px solid #E5E7EB;
border-radius: 12px;
padding: 20px;
box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06);
height: 100%;
display: flex;
flex-direction: column;
}
.footer-link {
color: #FF7043;
text-decoration: none;
font-weight: 600;
}
.footer-link:hover {
text-decoration: underline;
}
"""
# Base palette and typography: Gradio's Soft theme in orange / zinc / slate,
# IBM Plex Sans (Google Fonts) with system fallbacks, large corner radius.
theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="zinc",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size="lg",
)
# Overrides:
# - light grey page background with flat, borderless white blocks;
# - grey block labels, dark block titles and grey input fields;
# - solid orange primary buttons.
theme = theme.set(
    body_background_fill="#F9FAFB",
    block_background_fill="white",
    block_border_width="0px",
    block_shadow="none",
    block_label_background_fill="transparent",
    block_label_text_color="#4B5563",
    block_label_text_weight="600",
    block_title_text_color="#1F2937",
    input_background_fill="#F3F4F6",
    button_primary_background_fill="#FF7043",
    button_primary_background_fill_hover="#F4511E",
    button_primary_text_color="white",
)
# Page layout: a title, two side-by-side cards (source image + settings on the
# left, result + download on the right), the button wiring and a footer.
# NOTE(review): the nesting was reconstructed from a whitespace-stripped copy.
# Confirm where the spacer HTML, run_btn.click and the footer sit.
with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as demo:
    with gr.Column():
        gr.Markdown(
"""
# 🍊 SeedVR2 Image Upscaler
### Professional One-Step Restoration & Upscaling
"""
        )
        with gr.Row(equal_height=True):
            # Left card: input image and generation settings.
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Source", elem_id="input-header")
                # type="filepath" passes upscale_image a path on disk.
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    height=400,
                    sources=["upload", "clipboard"],
                    show_label=False
                )
                gr.Markdown("#### Settings", elem_id="settings-header")
                with gr.Group():
                    with gr.Row():
                        seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
                        cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1, container=True)
                # Fixed-height spacer above the run button.
                gr.HTML("<div style='height: 20px;'></div>")
                run_btn = gr.Button("Upscale Image", variant="primary", size="lg")
            # Right card: restored image and the saved PNG for download.
            with gr.Column(scale=1, elem_classes="ui-box"):
                gr.Markdown("#### Result", elem_id="output-header")
                output_image = gr.Image(
                    label="Restored Result",
                    interactive=False,
                    height=400,
                    show_label=False
                )
                download_file = gr.File(label="Download High-Res", container=False)
        # upscale_image(image_path, seed, cfg_scale) -> (PIL image, PNG path).
        run_btn.click(
            fn=upscale_image,
            inputs=[input_image, seed_input, cfg_input],
            outputs=[output_image, download_file]
        )
        # Credits footer.
        gr.HTML(
"""
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
<p style="margin: 5px 0;">Powered by <b>SeedVR2</b> | One-Step Diffusion Model</p>
<p style="margin: 5px 0;">UI/UX by <a href="https://huggingface.co/bbqhan" target="_blank" class="footer-link">@bbqhan</a></p>
</div>
"""
        )
# Enable Gradio's request queue and start the server. queue() returns the
# Blocks itself, so the calls chain.
demo.queue().launch()