Spaces:
Running
on
Zero
Running
on
Zero
Reverting back; an error was encountered.
Browse files
app.py
CHANGED
|
@@ -18,16 +18,10 @@ from einops import rearrange
|
|
| 18 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 19 |
import torchvision.transforms as T
|
| 20 |
|
| 21 |
-
# ---
|
| 22 |
-
# Enable TensorFloat-32 (Crucial for H100/H200 speed)
|
| 23 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 24 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 25 |
-
# optimizing for Hopper architecture
|
| 26 |
-
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
|
| 27 |
-
|
| 28 |
-
# --- Project Specific Imports ---
|
| 29 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
| 30 |
from data.image.transforms.na_resize import NaResize
|
|
|
|
| 31 |
from data.video.transforms.rearrange import Rearrange
|
| 32 |
|
| 33 |
if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
|
|
@@ -49,11 +43,10 @@ os.environ["MASTER_PORT"] = "12355"
|
|
| 49 |
os.environ["RANK"] = str(0)
|
| 50 |
os.environ["WORLD_SIZE"] = str(1)
|
| 51 |
|
| 52 |
-
# Install Flash Attention
|
| 53 |
-
# We skip the build check to force it to look at the H200 environment
|
| 54 |
subprocess.run(
|
| 55 |
"pip install flash-attn --no-build-isolation",
|
| 56 |
-
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
| 57 |
shell=True,
|
| 58 |
)
|
| 59 |
|
|
@@ -108,23 +101,13 @@ def configure_runner():
|
|
| 108 |
OmegaConf.set_readonly(runner.config, False)
|
| 109 |
|
| 110 |
# Standard init for single GPU
|
| 111 |
-
init_torch(cudnn_benchmark=True)
|
| 112 |
|
| 113 |
runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
|
| 114 |
runner.configure_vae_model()
|
| 115 |
|
| 116 |
if hasattr(runner.vae, "set_memory_limit"):
|
| 117 |
-
|
| 118 |
-
# runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
|
| 119 |
-
pass
|
| 120 |
-
|
| 121 |
-
# --- H200 OPTIMIZATION: COMPILE DiT ---
|
| 122 |
-
# We use 'max-autotune' because H200 can handle the compilation search space
|
| 123 |
-
# and results in significantly faster kernels than standard eager mode.
|
| 124 |
-
# We disable fullgraph to handle some dynamic control flow if present.
|
| 125 |
-
print("🚀 Optimizing DiT for H200 (max-autotune)... this may take a minute on first run.")
|
| 126 |
-
runner.dit = torch.compile(runner.dit, mode="max-autotune", fullgraph=False)
|
| 127 |
-
|
| 128 |
return runner
|
| 129 |
|
| 130 |
@spaces.GPU(duration=100)
|
|
@@ -199,6 +182,7 @@ def upscale_image(image_path, seed=666, cfg_scale=1.0):
|
|
| 199 |
output_filename = f'output/{uuid.uuid4()}.png'
|
| 200 |
|
| 201 |
# Prepare Transforms
|
|
|
|
| 202 |
video_transform = Compose([
|
| 203 |
NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
|
| 204 |
Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
|
|
@@ -210,6 +194,7 @@ def upscale_image(image_path, seed=666, cfg_scale=1.0):
|
|
| 210 |
# Load and Preprocess Image
|
| 211 |
img = Image.open(image_path).convert("RGB")
|
| 212 |
img_tensor = T.ToTensor()(img).unsqueeze(0) # (1, C, H, W)
|
|
|
|
| 213 |
video_input = img_tensor.permute(0, 1, 2, 3)
|
| 214 |
|
| 215 |
cond_latents = [video_transform(video_input.to(torch.device("cuda")))]
|
|
@@ -229,6 +214,7 @@ def upscale_image(image_path, seed=666, cfg_scale=1.0):
|
|
| 229 |
# Post-process
|
| 230 |
sample = samples[0]
|
| 231 |
|
|
|
|
| 232 |
input_ref = (
|
| 233 |
rearrange(input_tensor[:, None], "c t h w -> t c h w")
|
| 234 |
if input_tensor.ndim == 3
|
|
@@ -253,20 +239,22 @@ def upscale_image(image_path, seed=666, cfg_scale=1.0):
|
|
| 253 |
result_image = Image.fromarray(sample[0])
|
| 254 |
result_image.save(output_filename)
|
| 255 |
|
| 256 |
-
#
|
| 257 |
-
|
| 258 |
-
|
| 259 |
torch.cuda.empty_cache()
|
| 260 |
|
| 261 |
return result_image, output_filename
|
| 262 |
|
| 263 |
# --- Gradio UI ---
|
| 264 |
|
|
|
|
| 265 |
custom_css = """
|
| 266 |
-
/* Font Import handled by Theme */
|
| 267 |
.gradio-container {
|
| 268 |
font-family: 'IBM Plex Sans', sans-serif !important;
|
| 269 |
}
|
|
|
|
| 270 |
h1 {
|
| 271 |
text-align: center;
|
| 272 |
color: #FF7043;
|
|
@@ -280,6 +268,7 @@ h3 {
|
|
| 280 |
font-weight: 400 !important;
|
| 281 |
margin-top: 0 !important;
|
| 282 |
}
|
|
|
|
| 283 |
button.primary {
|
| 284 |
background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
|
| 285 |
border: none !important;
|
|
@@ -290,6 +279,7 @@ button.primary:hover {
|
|
| 290 |
transform: translateY(-1px);
|
| 291 |
box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
|
| 292 |
}
|
|
|
|
| 293 |
.ui-box {
|
| 294 |
background: white;
|
| 295 |
border: 1px solid #E5E7EB;
|
|
@@ -300,6 +290,7 @@ button.primary:hover {
|
|
| 300 |
display: flex;
|
| 301 |
flex-direction: column;
|
| 302 |
}
|
|
|
|
| 303 |
.footer-link {
|
| 304 |
color: #FF7043;
|
| 305 |
text-decoration: none;
|
|
@@ -310,6 +301,7 @@ button.primary:hover {
|
|
| 310 |
}
|
| 311 |
"""
|
| 312 |
|
|
|
|
| 313 |
theme = gr.themes.Soft(
|
| 314 |
primary_hue="orange",
|
| 315 |
secondary_hue="zinc",
|
|
@@ -319,13 +311,16 @@ theme = gr.themes.Soft(
|
|
| 319 |
).set(
|
| 320 |
body_background_fill="#F9FAFB",
|
| 321 |
block_background_fill="white",
|
| 322 |
-
block_border_width="0px",
|
| 323 |
block_shadow="none",
|
|
|
|
| 324 |
block_label_background_fill="transparent",
|
| 325 |
block_label_text_color="#4B5563",
|
| 326 |
block_label_text_weight="600",
|
| 327 |
block_title_text_color="#1F2937",
|
|
|
|
| 328 |
input_background_fill="#F3F4F6",
|
|
|
|
| 329 |
button_primary_background_fill="#FF7043",
|
| 330 |
button_primary_background_fill_hover="#F4511E",
|
| 331 |
button_primary_text_color="white",
|
|
@@ -342,6 +337,7 @@ with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as d
|
|
| 342 |
)
|
| 343 |
|
| 344 |
with gr.Row(equal_height=True):
|
|
|
|
| 345 |
with gr.Column(scale=1, elem_classes="ui-box"):
|
| 346 |
gr.Markdown("#### Source", elem_id="input-header")
|
| 347 |
input_image = gr.Image(
|
|
@@ -358,9 +354,12 @@ with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as d
|
|
| 358 |
seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
|
| 359 |
cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1, container=True)
|
| 360 |
|
|
|
|
| 361 |
gr.HTML("<div style='height: 20px;'></div>")
|
|
|
|
| 362 |
run_btn = gr.Button("Upscale Image", variant="primary", size="lg")
|
| 363 |
|
|
|
|
| 364 |
with gr.Column(scale=1, elem_classes="ui-box"):
|
| 365 |
gr.Markdown("#### Result", elem_id="output-header")
|
| 366 |
output_image = gr.Image(
|
|
@@ -377,6 +376,7 @@ with gr.Blocks(theme=theme, css=custom_css, title="SeedVR2 Image Upscaler") as d
|
|
| 377 |
outputs=[output_image, download_file]
|
| 378 |
)
|
| 379 |
|
|
|
|
| 380 |
gr.HTML(
|
| 381 |
"""
|
| 382 |
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
|
|
|
|
| 18 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 19 |
import torchvision.transforms as T
|
| 20 |
|
| 21 |
+
# --- Project Specific Imports (Assumed to be present in repo) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
| 23 |
from data.image.transforms.na_resize import NaResize
|
| 24 |
+
# Note: Keeping Rearrange in case it's a specific wrapper, though typically einops suffices
|
| 25 |
from data.video.transforms.rearrange import Rearrange
|
| 26 |
|
| 27 |
if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
|
|
|
|
| 43 |
os.environ["RANK"] = str(0)
|
| 44 |
os.environ["WORLD_SIZE"] = str(1)
|
| 45 |
|
| 46 |
+
# Install Flash Attention if missing
|
|
|
|
| 47 |
subprocess.run(
|
| 48 |
"pip install flash-attn --no-build-isolation",
|
| 49 |
+
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
| 50 |
shell=True,
|
| 51 |
)
|
| 52 |
|
|
|
|
| 101 |
OmegaConf.set_readonly(runner.config, False)
|
| 102 |
|
| 103 |
# Standard init for single GPU
|
| 104 |
+
init_torch(cudnn_benchmark=False)
|
| 105 |
|
| 106 |
runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
|
| 107 |
runner.configure_vae_model()
|
| 108 |
|
| 109 |
if hasattr(runner.vae, "set_memory_limit"):
|
| 110 |
+
runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
return runner
|
| 112 |
|
| 113 |
@spaces.GPU(duration=100)
|
|
|
|
| 182 |
output_filename = f'output/{uuid.uuid4()}.png'
|
| 183 |
|
| 184 |
# Prepare Transforms
|
| 185 |
+
# Note: Model is optimized for 2560x1440 area equivalent
|
| 186 |
video_transform = Compose([
|
| 187 |
NaResize(resolution=(2560 * 1440) ** 0.5, mode="area", downsample_only=False),
|
| 188 |
Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
|
|
|
|
| 194 |
# Load and Preprocess Image
|
| 195 |
img = Image.open(image_path).convert("RGB")
|
| 196 |
img_tensor = T.ToTensor()(img).unsqueeze(0) # (1, C, H, W)
|
| 197 |
+
# Model expects (C, T, H, W), for image T=1
|
| 198 |
video_input = img_tensor.permute(0, 1, 2, 3)
|
| 199 |
|
| 200 |
cond_latents = [video_transform(video_input.to(torch.device("cuda")))]
|
|
|
|
| 214 |
# Post-process
|
| 215 |
sample = samples[0]
|
| 216 |
|
| 217 |
+
# Handle tensor shaping for colorfix
|
| 218 |
input_ref = (
|
| 219 |
rearrange(input_tensor[:, None], "c t h w -> t c h w")
|
| 220 |
if input_tensor.ndim == 3
|
|
|
|
| 239 |
result_image = Image.fromarray(sample[0])
|
| 240 |
result_image.save(output_filename)
|
| 241 |
|
| 242 |
+
# Cleanup
|
| 243 |
+
del runner, cond_latents, samples
|
| 244 |
+
gc.collect()
|
| 245 |
torch.cuda.empty_cache()
|
| 246 |
|
| 247 |
return result_image, output_filename
|
| 248 |
|
| 249 |
# --- Gradio UI ---
|
| 250 |
|
| 251 |
+
# Custom CSS for the "Top Tier" look
|
| 252 |
custom_css = """
|
| 253 |
+
/* Font Import handled by Theme, but custom tweaks here */
|
| 254 |
.gradio-container {
|
| 255 |
font-family: 'IBM Plex Sans', sans-serif !important;
|
| 256 |
}
|
| 257 |
+
/* Header Styling */
|
| 258 |
h1 {
|
| 259 |
text-align: center;
|
| 260 |
color: #FF7043;
|
|
|
|
| 268 |
font-weight: 400 !important;
|
| 269 |
margin-top: 0 !important;
|
| 270 |
}
|
| 271 |
+
/* Button Styling - Vibrant Orange */
|
| 272 |
button.primary {
|
| 273 |
background: linear-gradient(135deg, #FF7043 0%, #FF5722 100%) !important;
|
| 274 |
border: none !important;
|
|
|
|
| 279 |
transform: translateY(-1px);
|
| 280 |
box-shadow: 0 10px 15px -3px rgba(255, 87, 34, 0.3), 0 4px 6px -2px rgba(255, 87, 34, 0.15) !important;
|
| 281 |
}
|
| 282 |
+
/* UI Boxes (Groups/Columns) */
|
| 283 |
.ui-box {
|
| 284 |
background: white;
|
| 285 |
border: 1px solid #E5E7EB;
|
|
|
|
| 290 |
display: flex;
|
| 291 |
flex-direction: column;
|
| 292 |
}
|
| 293 |
+
/* Footer Styling */
|
| 294 |
.footer-link {
|
| 295 |
color: #FF7043;
|
| 296 |
text-decoration: none;
|
|
|
|
| 301 |
}
|
| 302 |
"""
|
| 303 |
|
| 304 |
+
# Refined Theme
|
| 305 |
theme = gr.themes.Soft(
|
| 306 |
primary_hue="orange",
|
| 307 |
secondary_hue="zinc",
|
|
|
|
| 311 |
).set(
|
| 312 |
body_background_fill="#F9FAFB",
|
| 313 |
block_background_fill="white",
|
| 314 |
+
block_border_width="0px", # Clean look
|
| 315 |
block_shadow="none",
|
| 316 |
+
# Remove orange background from labels
|
| 317 |
block_label_background_fill="transparent",
|
| 318 |
block_label_text_color="#4B5563",
|
| 319 |
block_label_text_weight="600",
|
| 320 |
block_title_text_color="#1F2937",
|
| 321 |
+
# Input/Output styling
|
| 322 |
input_background_fill="#F3F4F6",
|
| 323 |
+
# Primary Button (Orange)
|
| 324 |
button_primary_background_fill="#FF7043",
|
| 325 |
button_primary_background_fill_hover="#F4511E",
|
| 326 |
button_primary_text_color="white",
|
|
|
|
| 337 |
)
|
| 338 |
|
| 339 |
with gr.Row(equal_height=True):
|
| 340 |
+
# Left Column: Input
|
| 341 |
with gr.Column(scale=1, elem_classes="ui-box"):
|
| 342 |
gr.Markdown("#### Source", elem_id="input-header")
|
| 343 |
input_image = gr.Image(
|
|
|
|
| 354 |
seed_input = gr.Number(label="Seed", value=666, precision=0, container=True)
|
| 355 |
cfg_input = gr.Slider(label="CFG Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1, container=True)
|
| 356 |
|
| 357 |
+
# Spacer
|
| 358 |
gr.HTML("<div style='height: 20px;'></div>")
|
| 359 |
+
|
| 360 |
run_btn = gr.Button("Upscale Image", variant="primary", size="lg")
|
| 361 |
|
| 362 |
+
# Right Column: Output
|
| 363 |
with gr.Column(scale=1, elem_classes="ui-box"):
|
| 364 |
gr.Markdown("#### Result", elem_id="output-header")
|
| 365 |
output_image = gr.Image(
|
|
|
|
| 376 |
outputs=[output_image, download_file]
|
| 377 |
)
|
| 378 |
|
| 379 |
+
# Footer
|
| 380 |
gr.HTML(
|
| 381 |
"""
|
| 382 |
<div style="text-align: center; margin-top: 40px; margin-bottom: 20px; font-size: 0.9em; color: #6B7280;">
|