Egnalkram's picture
Upload folder using huggingface_hub
4689c2b verified
"""
FlashVSR Video Upscaling Plugin for Wan2GP
This plugin provides 4x video upscaling using FlashVSR models.
Based on the FlashVSR_plus implementation by lihaoyun6.
Copyright 2025 Wan2GP Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Features:
- 4x video upscaling with AI models
- Support for 8GB GPUs with tile_dit optimization
- Three pipeline variants (Tiny/Tiny-Long/Full)
- Sparse SageAttention for efficient processing
"""
from shared.utils.plugins import WAN2GPPlugin
import gradio as gr
import torch
import torch.nn.functional as F_torch
import numpy as np
import math
def create_feather_mask(size, overlap):
    """
    Create a feather mask for blending overlapping tiles.
    Matches the upstream FlashVSR_plus implementation for tiles that are at
    least ``overlap`` pixels tall and wide.
    Args:
        size: Tuple of (height, width) of the tile
        overlap: Overlap in pixels (already scaled to output resolution)
    Returns:
        Tensor of shape (1, 1, H, W). Each edge ramps linearly from 0 at the
        border to 1 over ``overlap`` pixels; where ramps meet (corners, or
        opposite edges of a narrow tile) the smaller weight wins.
    Note:
        The ramp length is clamped to the tile height/width per axis, so a tile
        smaller than the overlap gets a valid mask instead of raising a
        shape-mismatch error.
    """
    H, W = size
    mask = torch.ones(1, 1, H, W)
    if overlap <= 0 or H <= 0 or W <= 0:
        return mask
    ramp_len_w = min(overlap, W)
    ramp_len_h = min(overlap, H)
    ramp_w = torch.linspace(0, 1, ramp_len_w).view(1, 1, 1, -1)
    ramp_h = torch.linspace(0, 1, ramp_len_h).view(1, 1, -1, 1)
    # Left / right edges
    mask[:, :, :, :ramp_len_w] = torch.minimum(mask[:, :, :, :ramp_len_w], ramp_w)
    mask[:, :, :, -ramp_len_w:] = torch.minimum(mask[:, :, :, -ramp_len_w:], ramp_w.flip(-1))
    # Top / bottom edges
    mask[:, :, :ramp_len_h, :] = torch.minimum(mask[:, :, :ramp_len_h, :], ramp_h)
    mask[:, :, -ramp_len_h:, :] = torch.minimum(mask[:, :, -ramp_len_h:, :], ramp_h.flip(-2))
    return mask
def calculate_tile_coords(height, width, tile_size, overlap):
    """
    Calculate tile coordinates for spatial tiling with overlap.
    Matches the upstream FlashVSR_plus implementation.
    Note: These are coordinates at the ORIGINAL (source) resolution.
    The pipeline will upscale each tile, and results are stitched
    at the scaled resolution.
    Args:
        height: Total height of the source image/video
        width: Total width of the source image/video
        tile_size: Size of each tile (at source resolution)
        overlap: Overlap between adjacent tiles in pixels (at source resolution)
    Returns:
        List of (x1, y1, x2, y2) tuples (note: x1, y1 order matches upstream).
        Tiles touching the bottom/right border are shifted back so they keep
        ``tile_size`` pixels when the image is large enough. An image with a
        non-positive side yields an empty list; any other image yields at
        least one tile.
    Raises:
        ValueError: if ``overlap >= tile_size`` (the stride would not advance).
    """
    if height <= 0 or width <= 0:
        return []
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(
            f"tile_size ({tile_size}) must be larger than overlap ({overlap})"
        )
    # At least one tile per axis: an image no larger than the overlap would
    # otherwise get ceil(negative) == 0 rows/columns and be silently dropped.
    num_rows = max(1, math.ceil((height - overlap) / stride))
    num_cols = max(1, math.ceil((width - overlap) / stride))
    coords = []
    for r in range(num_rows):
        for c in range(num_cols):
            y1 = r * stride
            x1 = c * stride
            y2 = min(y1 + tile_size, height)
            x2 = min(x1 + tile_size, width)
            # Adjust start if tile is smaller than tile_size at boundary
            if y2 - y1 < tile_size:
                y1 = max(0, y2 - tile_size)
            if x2 - x1 < tile_size:
                x1 = max(0, x2 - tile_size)
            coords.append((x1, y1, x2, y2))
    return coords
def largest_8n1_leq(n):
    """Return the largest integer of the form 8k+1 not exceeding n (0 when n < 1)."""
    if n < 1:
        return 0
    # Subtract how far n sits past the previous 8k+1 value.
    return n - (n - 1) % 8
def next_8n5(n):
    """Return the smallest integer of the form 8k+5 that is >= n, never less than 21."""
    minimum = 21
    if n < minimum:
        return minimum
    # (n + 2) // 8 == ceil((n - 5) / 8) for integer n.
    blocks = (n + 2) // 8
    return blocks * 8 + 5
def get_input_params(image_tensor, scale):
    """
    Calculate input parameters for FlashVSR pipeline.
    Matches upstream FlashVSR_plus implementation.
    Args:
        image_tensor: Input video tensor of shape (N, H, W, C)
        scale: Upscale factor (2 or 4)
    Returns:
        Tuple of (target_height, target_width, num_frames). Each target side is
        the scaled side rounded down to a multiple of 128 (minimum 128), and
        num_frames is the largest 8k+1 value <= N + 4.
    Raises:
        RuntimeError: if the video contains no frames.
    """
    N0, h0, w0, _ = image_tensor.shape
    # N0 + 4 is always >= 1 for N0 >= 0, so an empty clip would otherwise get
    # num_frames == 1 and callers would index into an empty tensor.
    if N0 < 1:
        raise RuntimeError(f"Not enough frames. Got {N0}.")
    multiple = 128
    sW, sH = w0 * scale, h0 * scale
    tW = max(multiple, (sW // multiple) * multiple)
    tH = max(multiple, (sH // multiple) * multiple)
    num_frames = largest_8n1_leq(N0 + 4)
    return tH, tW, num_frames
def prepare_input_tensor(image_tensor, device, scale=4, dtype=torch.bfloat16):
    """
    Prepare input tensor for FlashVSR pipeline.
    Matches upstream FlashVSR_plus implementation - prepares LQ_video
    with bicubic upscaling to target resolution.
    Args:
        image_tensor: Input video tensor of shape (N, H, W, C) in [0, 1] range
        device: Device on which each frame is upscaled
        scale: Upscale factor
        dtype: Target dtype of the returned (CPU) tensor
    Returns:
        Tuple of (LQ_video, target_height, target_width, num_frames)
        LQ_video shape: (1, C, F, H, W), affinely mapped from [0, 1] to [-1, 1]
        (bicubic overshoot is not clamped). Frames beyond the end of the clip
        repeat the last frame up to num_frames.
    """
    N0 = image_tensor.shape[0]
    tH, tW, Fs = get_input_params(image_tensor, scale)
    frames = []
    for i in range(Fs):
        # Padding frames past the end of the clip reuse the last real frame.
        frame_slice = image_tensor[min(i, N0 - 1)].to(device)
        # Same bicubic-upscale + center-crop used by input_tensor_generator.
        tensor_chw = tensor_upscale_then_center_crop(frame_slice, scale=scale, tW=tW, tH=tH)
        # Normalize to [-1, 1] and accumulate on the CPU.
        frames.append((tensor_chw * 2.0 - 1.0).to('cpu').to(dtype))
    vid_final = torch.stack(frames, 0).permute(1, 0, 2, 3).unsqueeze(0)  # (1, C, F, H, W)
    clean_vram()
    return vid_final, tH, tW, Fs
def tensor2video(frames_tensor):
    """
    Convert output tensor to video frames.
    Matches upstream FlashVSR_plus implementation.
    Args:
        frames_tensor: Tensor of shape (C, F, H, W) or (1, C, F, H, W) in [-1, 1] range
    Returns:
        float32 tensor of shape (F, H, W, C), affinely mapped from [-1, 1] to
        [0, 1]. Values are not clamped, so out-of-range inputs stay out of range.
    Raises:
        ValueError: if the tensor is not 4-D after dropping a leading batch dim.
    """
    video = frames_tensor.squeeze(0) if frames_tensor.dim() == 5 else frames_tensor
    if video.dim() != 4:
        raise ValueError(
            f"Expected a (C, F, H, W) or (1, C, F, H, W) tensor, got shape {tuple(frames_tensor.shape)}"
        )
    # (C, F, H, W) -> (F, H, W, C); a plain permute avoids importing einops per call.
    video = video.permute(1, 2, 3, 0)
    return (video.float() + 1.0) / 2.0
def clean_vram():
    """Release cached CUDA allocator blocks; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# =============================================================================
# PRE-FLIGHT RESOURCE CHECK FUNCTIONS
# =============================================================================
# Safety factor to account for intermediate activations, VAE overhead, and CUDA workspace
# Based on ComfyUI-FlashVSR_Stable reference implementation
VRAM_SAFETY_FACTOR = 4.0
# OOM threshold, as a fraction (not a percentage): check_resources reports a
# shortage unless the estimate stays below this fraction of free VRAM / available RAM
OOM_THRESHOLD = 0.95


def estimate_vram_usage(width, height, frames, scale, tiled_vae, tiled_dit, mode="tiny", tile_size=256):
    """
    Estimate VRAM usage for FlashVSR upscaling operation.
    Based on ComfyUI-FlashVSR_Stable reference which uses SAFETY_FACTOR = 4.0
    to account for intermediate activations, VAE overhead, and CUDA workspace.
    Args:
        width: Input video width
        height: Input video height
        frames: Number of frames (only affects the output RAM estimate)
        scale: Scale factor (2 or 4)
        tiled_vae: Whether tiled VAE is enabled
        tiled_dit: Whether tiled DiT is enabled
        mode: Pipeline variant ("tiny", "tiny-long", "full"); unknown values
            fall back to the "tiny" model size
        tile_size: Source-resolution tile side used for the tiled-DiT estimate
            (default 256, the previous hard-coded value)
    Returns:
        dict with:
        - model_vram_gb: Base model VRAM requirements
        - inference_vram_gb: Estimated inference VRAM with safety factor
        - output_ram_gb: Estimated system RAM for output phase
        - total_vram_gb: Total estimated VRAM needed (same as inference_vram_gb)
        - details: Human-readable breakdown string
    """
    gib = 1024 ** 3
    # Output dimensions
    out_width = width * scale
    out_height = height * scale
    # Base model sizes (approximate, from HuggingFace model cards)
    model_sizes = {
        "tiny": 2.5,       # TCDecoder + DiT in bf16
        "tiny-long": 2.5,  # Same as tiny but with streaming
        "full": 4.0,       # Full VAE decoder
    }
    base_model_gb = model_sizes.get(mode, 2.5)
    # Per-frame tensor size at output resolution, shape (C=3, H, W)
    bytes_per_frame_bf16 = 3 * out_height * out_width * 2  # 2 bytes per value
    bytes_per_frame_fp32 = 3 * out_height * out_width * 4  # fp32 intermediates
    # DiT processing memory - varies based on tiling
    if tiled_dit:
        # Tiled: only one upscaled tile is resident at a time
        effective_pixels = tile_size * tile_size * scale * scale
        dit_memory_gb = (effective_pixels * 3 * 2 * VRAM_SAFETY_FACTOR) / gib
    else:
        # Full frame: the whole output frame is resident
        dit_memory_gb = (out_width * out_height * 3 * 2 * VRAM_SAFETY_FACTOR) / gib
    # VAE decoding memory
    if tiled_vae:
        # Tiled VAE: flat, lower peak estimate
        vae_memory_gb = 1.0
    else:
        # Full VAE: scales with output size (fp32, x2 overhead)
        vae_memory_gb = (out_width * out_height * 3 * 4) / gib * 2
    inference_vram_gb = base_model_gb + dit_memory_gb + vae_memory_gb
    # Output phase RAM estimation
    if mode == "tiny-long":
        # Streaming mode: roughly a two-frame fp32 buffer
        output_ram_gb = bytes_per_frame_fp32 * 2 / gib
    else:
        # Batch mode: every frame as fp32 plus a second copy budgeted at
        # 2 bytes per value (over-covers the uint8 frames written to disk)
        output_ram_gb = (frames * bytes_per_frame_fp32 + frames * bytes_per_frame_bf16) / gib
    total_vram_gb = inference_vram_gb
    details = (
        f"Model: {base_model_gb:.1f}GB, "
        f"DiT: {dit_memory_gb:.1f}GB{'(tiled)' if tiled_dit else ''}, "
        f"VAE: {vae_memory_gb:.1f}GB{'(tiled)' if tiled_vae else ''}, "
        f"Output RAM: {output_ram_gb:.1f}GB"
    )
    return {
        "model_vram_gb": base_model_gb,
        "inference_vram_gb": inference_vram_gb,
        "output_ram_gb": output_ram_gb,
        "total_vram_gb": total_vram_gb,
        "details": details,
    }
def check_resources(estimated_vram_gb, estimated_ram_gb):
    """
    Compare estimated resource requirements against available system resources.
    Args:
        estimated_vram_gb: Estimated VRAM needed in GB
        estimated_ram_gb: Estimated system RAM needed in GB
    Returns:
        dict with:
        - vram_ok: True if the estimate is below OOM_THRESHOLD x free VRAM
        - ram_ok: True if the estimate is below OOM_THRESHOLD x available RAM
        - available_vram_gb: Free VRAM on the current CUDA device in GB (0.0 without CUDA)
        - available_ram_gb: Available system RAM in GB
        - total_vram_gb: Total VRAM of the current CUDA device in GB (0.0 without CUDA)
        - total_ram_gb: Total system RAM in GB
        - vram_headroom_gb: VRAM headroom (negative = shortage)
        - ram_headroom_gb: RAM headroom (negative = shortage)
        - warnings: List of warning messages
    """
    import psutil

    gib = 1024 ** 3
    warning_messages = []
    # Check VRAM
    available_vram_gb = 0.0
    total_vram_gb = 0.0
    if torch.cuda.is_available():
        try:
            free_vram, total_vram = torch.cuda.mem_get_info()
            available_vram_gb = free_vram / gib
            total_vram_gb = total_vram / gib
        except Exception:
            # Fallback: total minus what this process has allocated. Query the
            # current device for both values so they describe the same GPU
            # (previously the total always came from device 0).
            device = torch.cuda.current_device()
            total_vram_gb = torch.cuda.get_device_properties(device).total_memory / gib
            available_vram_gb = total_vram_gb - torch.cuda.memory_allocated(device) / gib
    # Check system RAM
    mem = psutil.virtual_memory()
    available_ram_gb = mem.available / gib
    total_ram_gb = mem.total / gib
    # Headroom (negative = shortage)
    vram_headroom_gb = available_vram_gb - estimated_vram_gb
    ram_headroom_gb = available_ram_gb - estimated_ram_gb
    # Keep a margin below the hard limit
    vram_ok = estimated_vram_gb < (available_vram_gb * OOM_THRESHOLD)
    ram_ok = estimated_ram_gb < (available_ram_gb * OOM_THRESHOLD)
    if not vram_ok:
        warning_messages.append(
            f"⚠️ VRAM may be insufficient: need ~{estimated_vram_gb:.1f}GB, "
            f"available {available_vram_gb:.1f}GB/{total_vram_gb:.1f}GB"
        )
    if not ram_ok:
        warning_messages.append(
            f"⚠️ System RAM may be insufficient for output phase: need ~{estimated_ram_gb:.1f}GB, "
            f"available {available_ram_gb:.1f}GB/{total_ram_gb:.1f}GB"
        )
    return {
        "vram_ok": vram_ok,
        "ram_ok": ram_ok,
        "available_vram_gb": available_vram_gb,
        "available_ram_gb": available_ram_gb,
        "total_vram_gb": total_vram_gb,
        "total_ram_gb": total_ram_gb,
        "vram_headroom_gb": vram_headroom_gb,
        "ram_headroom_gb": ram_headroom_gb,
        "warnings": warning_messages,
    }
def get_optimal_settings(width, height, frames, scale, available_vram_gb, available_ram_gb, current_mode="tiny"):
    """
    Recommend tiling and pipeline settings based on available resources.
    Free-VRAM tiers set the baseline:
        < 8 GB   : tiled VAE + tiled DiT, tile size 128, Tiny mode
        8-10 GB  : tiled VAE + tiled DiT, Tiny mode
        10-12 GB : tiled VAE, Tiny (< 120 frames) or Tiny-Long
        12-16 GB : tiled VAE, current mode kept
        16+ GB   : no tiling, current mode kept
    Later checks may then force Tiny-Long (large output RAM estimate or more
    than 200 frames) and tiled DiT (more output pixels than 3840x2160, or more
    than 100 frames).
    Args:
        width: Input video width
        height: Input video height
        frames: Number of frames
        scale: Scale factor
        available_vram_gb: Available VRAM in GB
        available_ram_gb: Available system RAM in GB
        current_mode: Current pipeline variant (kept unless a rule overrides it)
    Returns:
        dict with:
        - recommended_tiled_vae: Boolean
        - recommended_tiled_dit: Boolean
        - recommended_mode: String pipeline variant
        - recommended_tile_size: Int tile size if tiling recommended
        - recommendations: List of recommendation strings
    """
    recommendations = []
    recommended_tiled_vae = False
    recommended_tiled_dit = False
    recommended_mode = current_mode
    recommended_tile_size = 256
    out_width = width * scale
    out_height = height * scale
    out_pixels = out_width * out_height
    # VRAM-based recommendations
    if available_vram_gb < 8:
        recommendations.append("🔴 Less than 8GB VRAM - FlashVSR may not run. Consider using a smaller scale factor.")
        recommended_tiled_vae = True
        recommended_tiled_dit = True
        recommended_tile_size = 128
        # The roomier 8-10GB tier already forces Tiny; keeping Full/Tiny-Long
        # here would recommend a heavier pipeline on a smaller GPU.
        recommended_mode = "tiny"
    elif available_vram_gb < 10:
        recommendations.append("🟡 8-10GB VRAM detected. Recommend: Enable Tiled DiT, use Tiny mode.")
        recommended_tiled_vae = True
        recommended_tiled_dit = True
        recommended_mode = "tiny"
    elif available_vram_gb < 12:
        recommendations.append("🟢 10-12GB VRAM detected. Recommend: Enable Tiled VAE, use Tiny or Tiny-Long mode.")
        recommended_tiled_vae = True
        recommended_mode = "tiny" if frames < 120 else "tiny-long"
    elif available_vram_gb < 16:
        recommendations.append("🟢 12-16GB VRAM detected. Should handle most videos with Tiled VAE enabled.")
        recommended_tiled_vae = True
    elif available_vram_gb < 24:
        recommendations.append("🟢 16-24GB VRAM detected. Can handle large videos, Full mode available.")
    else:
        recommendations.append("🟢 24GB+ VRAM detected. Full mode with all optimizations available.")
    # Output RAM: ~6 bytes per channel value (fp32 + a 2-byte copy), i.e. 18 bytes per pixel
    estimated_output_ram = (frames * out_height * out_width * 3 * 6) / (1024**3)
    if estimated_output_ram > available_ram_gb * 0.8:
        recommendations.append(
            f"🟡 Output phase may use ~{estimated_output_ram:.1f}GB RAM. "
            "Recommend: Use Tiny-Long mode for streaming output."
        )
        recommended_mode = "tiny-long"
    # Large resolution: more output pixels than 4K UHD
    if out_pixels > 3840 * 2160:
        recommendations.append("🟡 Output resolution exceeds 4K. Recommend: Enable Tiled DiT for stability.")
        recommended_tiled_dit = True
    # High frame count without tiling - computational complexity warning
    # DiT attention is O(n²) across frames, so >100 frames without tiling is very slow
    if frames > 100 and not recommended_tiled_dit:
        recommendations.append(
            f"🟡 High frame count ({frames} frames) without Tiled DiT will be very slow. "
            "DiT attention is O(n²) across frames. Recommend: Enable Tiled DiT."
        )
        recommended_tiled_dit = True
    # Long video recommendations
    if frames > 200:
        recommendations.append(f"🟡 Long video ({frames} frames). Recommend: Use Tiny-Long mode for memory efficiency.")
        recommended_mode = "tiny-long"
    return {
        "recommended_tiled_vae": recommended_tiled_vae,
        "recommended_tiled_dit": recommended_tiled_dit,
        "recommended_mode": recommended_mode,
        "recommended_tile_size": recommended_tile_size,
        "recommendations": recommendations,
    }
def log_processing_summary(start_time, width, height, frames, scale, mode, tiled_vae, tiled_dit):
    """
    Print an end-of-run summary and return the timing/memory metrics.
    Args:
        start_time: Processing start time (from time.time())
        width: Input width
        height: Input height
        frames: Number of frames
        scale: Scale factor used
        mode: Pipeline variant used (printed upper-cased)
        tiled_vae: Whether tiled VAE was used
        tiled_dit: Whether tiled DiT was used
    Returns:
        dict with "elapsed_seconds", "peak_vram_gb" (torch.cuda.max_memory_reserved
        in GB, 0.0 without CUDA or on error) and "fps_throughput" (frames per
        second of wall time, 0 when no time elapsed).
    """
    import time

    elapsed = time.time() - start_time
    if elapsed > 0:
        fps_throughput = frames / elapsed
    else:
        fps_throughput = 0

    peak_vram_gb = 0.0
    if torch.cuda.is_available():
        try:
            peak_vram_gb = torch.cuda.max_memory_reserved() / (1024**3)
        except Exception:
            pass

    out_width, out_height = width * scale, height * scale
    summary = (
        f"\n{'='*60}\n"
        f"[FlashVSR] Processing Summary\n"
        f"{'='*60}\n"
        f"Input: {width}x{height} @ {frames} frames\n"
        f"Output: {out_width}x{out_height} @ {frames} frames\n"
        f"Mode: {mode.upper()} | Scale: {scale}x\n"
        f"Tiling: VAE={'Yes' if tiled_vae else 'No'}, DiT={'Yes' if tiled_dit else 'No'}\n"
        f"{'='*60}\n"
        f"Time: {elapsed:.1f}s ({fps_throughput:.2f} frames/sec)\n"
        f"Peak VRAM: {peak_vram_gb:.2f} GB\n"
        f"{'='*60}\n"
    )
    print(summary)
    return dict(
        elapsed_seconds=elapsed,
        peak_vram_gb=peak_vram_gb,
        fps_throughput=fps_throughput,
    )
def reset_peak_vram_stats():
    """Zero CUDA's peak-memory counters so later peak readings cover only new work.

    No-op without CUDA; any error raised by the reset is ignored.
    """
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.reset_peak_memory_stats()
    except Exception:
        pass
def tensor_upscale_then_center_crop(frame_slice, scale, tW, tH):
    """
    Bicubically upscale one HWC frame by ``scale`` and center-crop it to (tH, tW).
    Args:
        frame_slice: Frame tensor of shape (H, W, C) in [0, 1] range
        scale: Integer upscale factor
        tW: Target width after cropping
        tH: Target height after cropping
    Returns:
        Tensor of shape (C, <=tH, <=tW). The crop is a view of the upscaled
        frame; when the upscaled side is smaller than the target, the whole
        side is kept.
    """
    src_h, src_w, _ = frame_slice.shape
    up_h, up_w = src_h * scale, src_w * scale
    batch = frame_slice.permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
    upscaled = F_torch.interpolate(batch, size=(up_h, up_w), mode='bicubic', align_corners=False)
    # Offsets that center the target window; clamped at 0 for undersized frames.
    left = max(0, (up_w - tW) // 2)
    top = max(0, (up_h - tH) // 2)
    return upscaled[0, :, top:top + tH, left:left + tW]
def input_tensor_generator(image_tensor, device, scale=4, dtype=torch.bfloat16):
    """
    Lazily yield pipeline-ready frames one at a time (used by the Tiny-Long
    pipeline so the whole upscaled clip never has to be materialized).
    Args:
        image_tensor: Input video tensor of shape (N, H, W, C) in [0, 1] range
        device: Device on which each frame is upscaled
        scale: Upscale factor
        dtype: dtype of the yielded CPU tensors
    Yields:
        (C, H, W) CPU tensors mapped from [0, 1] to [-1, 1]; after the last real
        frame, the final frame is repeated up to the padded frame count.
    """
    total_src, _, _, _ = image_tensor.shape
    target_h, target_w, padded_count = get_input_params(image_tensor, scale)
    last_idx = total_src - 1
    for idx in range(padded_count):
        src_frame = image_tensor[min(idx, last_idx)].to(device)
        chw = tensor_upscale_then_center_crop(src_frame, scale=scale, tW=target_w, tH=target_h)
        normalized = chw * 2.0 - 1.0
        del chw
        yield normalized.to('cpu').to(dtype)
# =============================================================================
# MEMORY-EFFICIENT VIDEO OUTPUT FUNCTIONS
# =============================================================================
# Module-wide switch for verbose memory logging. Off by default; toggle it with
# set_memory_debug() (the plugin does so when its "debug.memory_logging"
# config flag is set).
MEMORY_DEBUG = False


def set_memory_debug(enabled: bool) -> None:
    """Turn the module-level memory debug logging on or off."""
    global MEMORY_DEBUG
    MEMORY_DEBUG = enabled


def log_memory_usage(context: str = "") -> None:
    """
    Print current RAM and VRAM usage when MEMORY_DEBUG is enabled.
    RAM "used" is total minus available system memory. VRAM "used" is the
    memory allocated by this process on the current CUDA device, reported
    against that same device's total (0.0/0.0 without CUDA or on error).
    Args:
        context: Description of what's happening when this is called
    """
    if not MEMORY_DEBUG:
        return
    import psutil

    gib = 1024 ** 3
    mem = psutil.virtual_memory()
    ram_used_gb = (mem.total - mem.available) / gib
    ram_total_gb = mem.total / gib
    vram_used_gb = 0.0
    vram_total_gb = 0.0
    if torch.cuda.is_available():
        try:
            # Use one device for both numbers (the total was previously always
            # read from device 0 while the allocation came from the current device).
            device = torch.cuda.current_device()
            vram_used_gb = torch.cuda.memory_allocated(device) / gib
            vram_total_gb = torch.cuda.get_device_properties(device).total_memory / gib
        except Exception:
            pass
    print(f"[FlashVSR Memory] {context}: RAM={ram_used_gb:.2f}/{ram_total_gb:.2f}GB, VRAM={vram_used_gb:.2f}/{vram_total_gb:.2f}GB")
def calculate_safe_batch_size(frame_height: int, frame_width: int, available_ram_bytes: int, safety_margin_gb: float = 2.0, max_batch_size: int = 40) -> int:
    """
    Calculate safe batch size for video output based on available system RAM.
    Estimates memory needed per frame (tensor + numpy copy) and calculates
    how many frames can be safely processed at once.
    Args:
        frame_height: Height of each frame in pixels
        frame_width: Width of each frame in pixels
        available_ram_bytes: Available system RAM in bytes
        safety_margin_gb: RAM to keep free (default 2GB)
        max_batch_size: Upper bound on the returned batch size (default 40;
            expected to be >= 1)
    Returns:
        Safe batch size in [1, max_batch_size]; ``max_batch_size`` itself for a
        degenerate (zero-area) frame size.
    """
    # Per-frame memory: float32 tensor (4 B) + uint8 numpy copy (1 B) per
    # channel value, i.e. H * W * 3 * 5 bytes, doubled for safety.
    bytes_per_frame = frame_height * frame_width * 3 * 5 * 2
    if bytes_per_frame <= 0:
        return max_batch_size  # Nothing meaningful to budget against
    safety_bytes = int(safety_margin_gb * 1024**3)
    usable_ram = max(0, available_ram_bytes - safety_bytes)
    batch_size = max(1, min(max_batch_size, usable_ram // bytes_per_frame))
    log_memory_usage(f"Calculated batch size: {batch_size} frames (frame size: {frame_height}x{frame_width})")
    return batch_size
def streaming_frame_writer(output_path: str, fps: int, quality: int):
    """
    Open an imageio writer that accepts frames one at a time, so callers never
    need the whole video in memory.
    The UI quality (1-10, clamped; unparsable values become 5) is mapped linearly
    onto an H.264 CRF from 35 (quality 1) down to 17 (quality 10). Encoding uses
    libx264 + yuv420p + faststart. If creating that writer raises, a plain
    imageio writer with ``quality=`` is returned instead.
    Args:
        output_path: Path to output video file
        fps: Frames per second
        quality: Video quality (1-10)
    Returns:
        imageio writer object (usable as a context manager; otherwise call close())
    """
    import imageio

    def _clamp_quality(q: int) -> int:
        try:
            value = int(q)
        except Exception:
            value = 5
        return min(10, max(1, value))

    def _quality_to_crf(q: int) -> int:
        # Lower CRF = higher quality; each quality step is 2 CRF steps.
        crf_low_quality, crf_high_quality = 35, 17
        step = (crf_low_quality - crf_high_quality) / 9
        return int(round(crf_low_quality - (_clamp_quality(q) - 1) * step))

    # Explicit encoder settings avoid ffmpeg builds that mishandle imageio's
    # `quality=` mapping (observed: quality=10 producing audio-only output).
    base_params = ["-pix_fmt", "yuv420p", "-movflags", "+faststart"]
    try:
        crf = _quality_to_crf(quality)
        return imageio.get_writer(
            output_path,
            fps=fps,
            codec="libx264",
            ffmpeg_params=base_params + ["-crf", str(crf)],
        )
    except Exception:
        return imageio.get_writer(output_path, fps=fps, quality=_clamp_quality(quality))
def write_frames_streaming(output_frames_tensor, output_path: str, fps: int, quality: int, progress_callback=None):
    """
    Write video frames to file using streaming to minimize memory usage.
    Instead of converting all frames to numpy at once (which causes massive
    RAM spikes), this function converts and writes frames in small batches
    sized by calculate_safe_batch_size().
    Args:
        output_frames_tensor: Tensor of shape (F, H, W, C) in [0, 1] range
        output_path: Path to output video file
        fps: Frames per second
        quality: Video quality (1-10)
        progress_callback: Optional callback(current, total) called after each batch
    Returns:
        Tuple of (output_height, output_width) from written frames
    """
    import psutil
    from tqdm import tqdm as tqdm_save

    log_memory_usage("Starting streaming write")
    num_frames = output_frames_tensor.shape[0]
    frame_height = output_frames_tensor.shape[1]
    frame_width = output_frames_tensor.shape[2]
    available_ram = psutil.virtual_memory().available
    batch_size = calculate_safe_batch_size(frame_height, frame_width, available_ram)
    print(f"[FlashVSR] Writing {num_frames} frames with batch size {batch_size}")
    writer = streaming_frame_writer(output_path=output_path, fps=fps, quality=quality)
    try:
        for batch_start in tqdm_save(range(0, num_frames, batch_size), desc="[FlashVSR] Saving video"):
            batch_end = min(batch_start + batch_size, num_frames)
            log_memory_usage(f"Processing batch {batch_start}-{batch_end}")
            batch_tensor = output_frames_tensor[batch_start:batch_end]
            # Only this batch is converted to float32/uint8 on the CPU, which
            # bounds the conversion spike regardless of video length.
            # Note: astype truncates (no rounding), matching the other writers.
            batch_np = (batch_tensor.cpu().float() * 255.0).clip(0, 255).numpy().astype(np.uint8)
            for frame_np in batch_np:
                writer.append_data(frame_np)
            del batch_tensor, batch_np
            if progress_callback:
                progress_callback(batch_end, num_frames)
            log_memory_usage(f"Completed batch {batch_start}-{batch_end}")
    finally:
        writer.close()
    log_memory_usage("Streaming write complete")
    return frame_height, frame_width
def write_canvas_streaming(canvas_tensor, weight_tensor, output_path: str, fps: int, quality: int, frame_count: int, progress_callback=None):
    """
    Normalize canvas by weights and write frames using streaming output.
    For tiled processing, the canvas and weights are accumulated during tile
    processing. This function normalizes and writes frames in batches to
    avoid memory spikes from holding all normalized frames at once.
    Args:
        canvas_tensor: Accumulated weighted canvas (F, H, W, C) torch tensor
        weight_tensor: Accumulated weights (F, H, W, C) torch tensor
        output_path: Path to output video file
        fps: Frames per second
        quality: Video quality (1-10)
        frame_count: Number of original frames (before padding); at most this
            many frames are written
        progress_callback: Optional callback(current, total) called after each
            batch (same contract as write_frames_streaming)
    Returns:
        Tuple of (output_height, output_width) from written frames
    """
    import psutil
    from tqdm import tqdm as tqdm_save

    log_memory_usage("Starting canvas streaming write")
    num_frames = min(canvas_tensor.shape[0], frame_count)
    frame_height = canvas_tensor.shape[1]
    frame_width = canvas_tensor.shape[2]
    available_ram = psutil.virtual_memory().available
    batch_size = calculate_safe_batch_size(frame_height, frame_width, available_ram)
    print(f"[FlashVSR] Writing {num_frames} frames from canvas with batch size {batch_size}")
    writer = streaming_frame_writer(output_path=output_path, fps=fps, quality=quality)
    try:
        for batch_start in tqdm_save(range(0, num_frames, batch_size), desc="[FlashVSR] Saving video"):
            batch_end = min(batch_start + batch_size, num_frames)
            log_memory_usage(f"Processing canvas batch {batch_start}-{batch_end}")
            canvas_batch = canvas_tensor[batch_start:batch_end]
            weight_batch = weight_tensor[batch_start:batch_end]
            # Pixels no tile contributed to have zero weight; divide them by 1
            # instead of 0 (masked_fill is out-of-place, so weight_tensor is untouched).
            weight_batch = weight_batch.masked_fill(weight_batch == 0, 1.0)
            normalized_batch = canvas_batch / weight_batch
            batch_np = (normalized_batch.cpu().float() * 255.0).clip(0, 255).numpy().astype(np.uint8)
            for frame_np in batch_np:
                writer.append_data(frame_np)
            del canvas_batch, weight_batch, normalized_batch, batch_np
            if progress_callback:
                progress_callback(batch_end, num_frames)
            log_memory_usage(f"Completed canvas batch {batch_start}-{batch_end}")
    finally:
        writer.close()
    log_memory_usage("Canvas streaming write complete")
    return frame_height, frame_width
def _stitch_feather_mask(tile_h, tile_w, scaled_overlap):
    """Return an (H, W, 1) float32 mask whose edges ramp 0->1 over the overlap (ramps multiply at corners)."""
    mask = np.ones((tile_h, tile_w, 1), dtype=np.float32)
    if scaled_overlap <= 0 or tile_h <= 0 or tile_w <= 0:
        return mask
    # Clamp per axis so a tile narrower than the overlap does not raise.
    ramp_len_w = min(scaled_overlap, tile_w)
    ramp_len_h = min(scaled_overlap, tile_h)
    ramp_w = np.linspace(0, 1, ramp_len_w, dtype=np.float32)
    ramp_h = np.linspace(0, 1, ramp_len_h, dtype=np.float32)
    mask[:, :ramp_len_w, :] *= ramp_w[np.newaxis, :, np.newaxis]
    mask[:, -ramp_len_w:, :] *= np.flip(ramp_w)[np.newaxis, :, np.newaxis]
    mask[:ramp_len_h, :, :] *= ramp_h[:, np.newaxis, np.newaxis]
    mask[-ramp_len_h:, :, :] *= np.flip(ramp_h)[:, np.newaxis, np.newaxis]
    return mask


def _read_tile_frames(reader, start_frame, end_frame):
    """Read frames [start_frame, end_frame) as float32 in [0, 1]; stop at end of stream; None if nothing was read."""
    frames = []
    for frame_idx in range(start_frame, end_frame):
        try:
            frame = reader.get_data(frame_idx)
        except IndexError:
            # Reached end of video
            break
        frames.append(frame.astype(np.float32) / 255.0)
    if not frames:
        return None
    return np.stack(frames, axis=0)


def stitch_video_tiles(
    tile_paths,
    tile_coords,
    final_dims,
    scale,
    overlap,
    output_path,
    fps,
    quality,
    cleanup=True,
    chunk_size=40
):
    """
    Stitch multiple tile videos into a single output video.
    Used by Tiny-Long pipeline for tiled processing. Tiles are feather-blended
    into a float32 canvas one chunk of frames at a time, normalized by the
    accumulated weights, and streamed to the output writer.
    Args:
        tile_paths: List of paths to tile video files
        tile_coords: List of (x1, y1, x2, y2) source-resolution coordinates for each tile
        final_dims: Tuple of (width, height) for final output
        scale: Upscale factor used
        overlap: Tile overlap in pixels (at source resolution)
        output_path: Path to write the stitched video
        fps: Output video FPS
        quality: Output video quality (1-10)
        cleanup: Whether to remove temp tile files after a successful stitch
        chunk_size: Number of frames to process at once (for memory efficiency)
    Returns:
        None. Returns early (after printing a message) when ``tile_paths`` is empty.
    """
    if not tile_paths:
        print("[FlashVSR] No tile videos found to stitch.")
        return

    import os
    import imageio
    from tqdm import tqdm

    final_W, final_H = final_dims
    scaled_overlap = overlap * scale
    readers = []
    try:
        # Open incrementally so readers opened before a failure still get closed.
        for path in tile_paths:
            readers.append(imageio.get_reader(path))

        num_frames = readers[0].count_frames()
        if num_frames is None or num_frames <= 0:
            # Length unknown: count by decoding, then reopen every reader so
            # reading starts from frame 0 again.
            num_frames = sum(1 for _ in readers[0])
            for r in readers:
                r.close()
            readers = []
            for path in tile_paths:
                readers.append(imageio.get_reader(path))

        with streaming_frame_writer(output_path=output_path, fps=fps, quality=quality) as writer:
            for start_frame in tqdm(range(0, num_frames, chunk_size), desc="[FlashVSR] Stitching Chunks"):
                end_frame = min(start_frame + chunk_size, num_frames)
                current_chunk_size = end_frame - start_frame
                chunk_canvas = np.zeros((current_chunk_size, final_H, final_W, 3), dtype=np.float32)
                weight_canvas = np.zeros_like(chunk_canvas)

                for i, reader in enumerate(readers):
                    try:
                        tile_chunk_np = _read_tile_frames(reader, start_frame, end_frame)
                    except Exception as e:
                        print(f"[FlashVSR] Warning: Could not read chunk from tile {i}: {e}")
                        continue
                    if tile_chunk_np is None:
                        print(f"[FlashVSR] Warning: No frames read from tile {i} for range {start_frame}-{end_frame}")
                        continue

                    actual_chunk_size = tile_chunk_np.shape[0]
                    if actual_chunk_size != current_chunk_size:
                        print(f"[FlashVSR] Warning: Tile {i} chunk has {actual_chunk_size} frames, expected {current_chunk_size}. Adjusting...")

                    tile_H, tile_W = tile_chunk_np.shape[1:3]
                    mask_4d = _stitch_feather_mask(tile_H, tile_W, scaled_overlap)[np.newaxis]

                    x1_orig, y1_orig, _, _ = tile_coords[i]
                    out_y1, out_x1 = y1_orig * scale, x1_orig * scale
                    out_y2, out_x2 = out_y1 + tile_H, out_x1 + tile_W
                    # Blend only into the frames this tile actually provided, so a
                    # short tile neither fails to broadcast nor adds weight (and
                    # thus darkening) to frames it has no pixels for.
                    chunk_canvas[:actual_chunk_size, out_y1:out_y2, out_x1:out_x2, :] += tile_chunk_np * mask_4d
                    weight_canvas[:actual_chunk_size, out_y1:out_y2, out_x1:out_x2, :] += mask_4d

                # Normalize and write frames; uncovered pixels keep weight 1.
                weight_canvas[weight_canvas == 0] = 1.0
                stitched_chunk = chunk_canvas / weight_canvas
                for frame in stitched_chunk:
                    writer.append_data((np.clip(frame, 0, 1) * 255).astype(np.uint8))
    finally:
        print("[FlashVSR] Closing all tile reader instances...")
        for reader in readers:
            reader.close()

    if cleanup:
        print("[FlashVSR] Cleaning up temporary tile files...")
        for path in tile_paths:
            try:
                os.remove(path)
            except OSError as e:
                print(f"[FlashVSR] Could not remove temporary file '{path}': {e}")
class FlashVSRPlugin(WAN2GPPlugin):
"""
FlashVSR video upscaling plugin for Wan2GP.
This plugin provides AI-powered 4x video upscaling using FlashVSR models,
based on the FlashVSR_plus implementation by lihaoyun6. It supports multiple
pipeline variants optimized for different VRAM configurations:
- Tiny (8-10GB VRAM): Fastest, uses TCDecoder for efficient decoding
- Tiny-Long (10-12GB VRAM): Optimized for long videos (>120 frames)
- Full (18-24GB VRAM): Highest quality, uses full VAE decoder
Key Features:
- Sparse SageAttention for efficient memory usage
- Tile-based processing for low-VRAM GPUs (8GB minimum)
- Automatic model downloading from HuggingFace
- VAE sharing with Wan2GP installation
- Dedicated upscaling tab in Wan2GP interface
Attributes:
name (str): Plugin display name
version (str): Plugin version (semantic versioning)
description (str): Short plugin description
current_pipeline: Currently loaded FlashVSR pipeline instance
models_loaded (bool): Whether models have been downloaded/initialized
Example:
The plugin is automatically discovered and loaded by Wan2GP's plugin
system. Users access it via the "FlashVSR Upscaling" tab.
"""
def __init__(self):
    """
    Initialize the FlashVSR plugin.

    Sets plugin metadata and runtime state, then loads the persisted
    configuration. Model loading is deferred until first use to minimize
    startup time (``current_pipeline`` starts as ``None``).
    """
    super().__init__()
    self.name = "FlashVSR Upscaling"
    self.version = "1.0.3"
    self.description = "AI-powered 4x video upscaling with FlashVSR models (8GB+ VRAM)"
    # Plugin state
    self.current_pipeline = None
    self.models_loaded = False
    # Cancellation flag for stopping long-running operations (Task 4.1)
    self._cancel_flag = False
    # Load config
    self.config = self.load_config()
    # Enable memory debug logging if configured (Task 2.5).
    # A hand-edited config.json may set "debug" to a non-object (e.g. null);
    # treat that as "no debug options" instead of failing plugin init.
    debug_config = self.config.get("debug")
    if isinstance(debug_config, dict) and debug_config.get("memory_logging", False):
        set_memory_debug(True)
        print("[FlashVSR] Memory debug logging enabled")
def load_config(self):
    """
    Load configuration from config.json file.

    Reads ``config.json`` from this module's directory and merges it over
    the built-in defaults, which mirror the FlashVSR_plus defaults. The file
    may be a plain settings object or a schema-style document whose
    settings live under a ``"default"`` key.

    Merge rules:
      * Nested sections (keys whose default is a dict, such as
        ``"vram_optimization"``) are merged key by key.
      * Other known keys are replaced by the file's value. Keys unknown to
        the defaults are kept as-is.
      * A value whose dict-ness disagrees with the default (for example
        ``"processing": null`` or ``"scale_factor": {...}``) is ignored with
        a warning. One bad entry therefore neither breaks later
        ``config[section][key]`` lookups nor discards the rest of the file.

    A missing, unreadable, unparsable or non-object file yields the
    defaults.

    Returns:
        dict: Configuration dictionary with all settings.
    """
    import json
    from pathlib import Path

    # Default configuration - matches FlashVSR_plus defaults
    default_config = {
        "model_variant": "tiny",
        "model_version": "FlashVSR-v1.1",
        "scale_factor": 4,
        "vram_optimization": {
            "tiled_vae": True,
            "tiled_dit": False,
            "tile_size": 256,
            "overlap": 24
        },
        "quality_settings": {
            "color_fix": True,
            "output_quality": 6,
            "output_fps": 30
        },
        "sparse_attention": {
            "sparse_ratio": 2.0,
            "kv_ratio": 3,
            "local_range": 11
        },
        "processing": {
            "dtype": "bf16",
            "unload_dit": False
        },
        "debug": {
            "memory_logging": False  # Enable verbose memory logging for debugging
        }
    }

    # The config file lives next to this module
    config_path = Path(__file__).parent / "config.json"

    try:
        if not config_path.exists():
            print("[FlashVSR] Config file not found, using defaults")
            return default_config
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"[FlashVSR] Warning: Failed to parse config.json: {e}")
        print("[FlashVSR] Using default configuration")
        return default_config
    except Exception as e:
        print(f"[FlashVSR] Warning: Error loading config: {e}")
        print("[FlashVSR] Using default configuration")
        return default_config

    # Schema-style files keep the actual settings under "default".
    if isinstance(config_data, dict) and "default" in config_data:
        user_config = config_data["default"]
    else:
        user_config = config_data
    if not isinstance(user_config, dict):
        print("[FlashVSR] Warning: config.json does not contain a settings object")
        print("[FlashVSR] Using default configuration")
        return default_config

    # Build new section dicts rather than calling .update() on the defaults'
    # nested dicts, so default_config is never mutated by a partial merge.
    merged_config = dict(default_config)
    for key, value in user_config.items():
        if key not in merged_config:
            merged_config[key] = value
            continue
        default_value = merged_config[key]
        default_is_section = isinstance(default_value, dict)
        if default_is_section != isinstance(value, dict):
            print(f"[FlashVSR] Warning: Ignoring config entry '{key}' "
                  f"(unexpected type {type(value).__name__})")
        elif default_is_section:
            # Deep merge for nested dicts (user values take precedence)
            merged_config[key] = {**default_value, **value}
        else:
            merged_config[key] = value

    print(f"[FlashVSR] Loaded configuration from {config_path}")
    return merged_config
def save_config(self, config=None):
    """
    Save configuration to config.json file.

    Writes the plugin configuration next to this module so it persists
    across sessions, creating the file if needed. If the existing file is a
    schema document (it has a ``"$schema"`` key), only its ``"default"``
    entry is replaced so the schema metadata is preserved. Otherwise the
    file is overwritten with ``config``.

    The data is first written to a uniquely named temporary sibling and then
    swapped in with ``os.replace``. A serialization error, a crash, or two
    overlapping saves (this runs on every settings change) therefore never
    leave a truncated, unparsable ``config.json`` behind.

    Args:
        config: Configuration dictionary to save. If None, uses self.config.

    Returns:
        bool: True if save succeeded, False otherwise
    """
    import json
    import os
    import uuid
    from pathlib import Path

    if config is None:
        config = self.config

    # Get plugin directory
    config_path = Path(__file__).parent / "config.json"
    tmp_path = None
    try:
        # Read existing file to preserve schema if present
        existing_data = {}
        if config_path.exists():
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    existing_data = json.load(f)
            except (OSError, ValueError):
                # Unreadable or corrupt file: it is simply rewritten below.
                existing_data = {}

        if isinstance(existing_data, dict) and "$schema" in existing_data:
            # Update the default field instead of replacing entire file
            existing_data["default"] = config
            data_to_write = existing_data
        else:
            # Just write the config directly
            data_to_write = config

        # Unique temp name in the same directory, so os.replace stays a
        # same-filesystem rename and concurrent saves don't share a temp file.
        tmp_path = config_path.with_name(f".{config_path.name}.{uuid.uuid4().hex}.tmp")
        with open(tmp_path, 'x', encoding='utf-8') as f:
            json.dump(data_to_write, f, indent=2)
        os.replace(tmp_path, config_path)
        tmp_path = None  # Ownership moved to config_path; nothing to clean up.

        print(f"[FlashVSR] Configuration saved to {config_path}")
        return True
    except Exception as e:
        print(f"[FlashVSR] Warning: Failed to save config: {e}")
        return False
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # Temp file may never have been created.
def update_config_from_ui(self, **kwargs):
    """
    Update configuration from UI component values.

    Each recognised keyword is converted to its stored type and written to
    the matching config section. Unknown keywords are ignored.

    Values that are ``None`` (e.g. a cleared Number or Dropdown) are
    skipped, so they neither crash the int/float conversions nor wipe a
    stored setting. A missing or non-dict section is (re)created before it
    is written to. The updated config is saved to disk via
    ``self.save_config()``.

    Args:
        **kwargs: Setting names mapped to UI values. ``model_variant`` takes
            the dropdown label (unknown labels map to ``"tiny"``).
            ``scale_factor`` takes a string such as ``"4x"``; a bare number
            is accepted too.

    Returns:
        dict: Updated configuration dictionary (``self.config``).
    """
    variant_map = {
        "Tiny (8-10GB VRAM)": "tiny",
        "Tiny-Long (10-12GB VRAM)": "tiny-long",
        "Full (18-24GB VRAM)": "full"
    }

    def keep(value):
        return value

    def parse_scale(value):
        # "4x" -> 4; also tolerates 4, "4" and "4X".
        return int(str(value).lower().replace("x", ""))

    # (keyword, config section or None for top level, converter), applied in
    # the same order the settings have always been written.
    field_specs = (
        ("model_variant", None, lambda v: variant_map.get(v, "tiny")),
        ("scale_factor", None, parse_scale),
        ("tiled_vae", "vram_optimization", keep),
        ("tiled_dit", "vram_optimization", keep),
        ("tile_size", "vram_optimization", int),
        ("overlap", "vram_optimization", int),
        ("color_fix", "quality_settings", keep),
        ("output_quality", "quality_settings", int),
        ("output_fps", "quality_settings", int),
        ("sparse_ratio", "sparse_attention", float),
        ("kv_ratio", "sparse_attention", int),
        ("local_range", "sparse_attention", int),
        ("dtype", "processing", keep),
        ("unload_dit", "processing", keep),
        ("model_version", None, keep),
    )

    for name, section, convert in field_specs:
        value = kwargs.get(name)
        if value is None:
            continue
        if section is None:
            target = self.config
        else:
            target = self.config.get(section)
            if not isinstance(target, dict):
                target = {}
                self.config[section] = target
        target[name] = convert(value)

    # Save updated config
    self.save_config()
    return self.config
def get_config_defaults(self):
    """
    Get default values from config for UI initialization.

    Every setting is read defensively. A missing key, or a missing or
    non-dict section, falls back to the same default that ``load_config``
    uses, so an old, partial or hand-edited config cannot break UI
    construction.

    Returns:
        dict: Mapping of UI component name to its initial value. The model
        variant is returned as its dropdown label and the scale factor as a
        string such as ``"4x"``.
    """
    variant_map = {
        "tiny": "Tiny (8-10GB VRAM)",
        "tiny-long": "Tiny-Long (10-12GB VRAM)",
        "full": "Full (18-24GB VRAM)"
    }

    def section(name):
        # Handle backwards compatibility: old configs may lack a section,
        # and hand edits may have replaced it with a non-dict value.
        value = self.config.get(name)
        return value if isinstance(value, dict) else {}

    vram = section("vram_optimization")
    quality = section("quality_settings")
    sparse_attn = section("sparse_attention")
    processing = section("processing")

    return {
        "model_variant": variant_map.get(self.config.get("model_variant", "tiny"), "Tiny (8-10GB VRAM)"),
        "model_version": self.config.get("model_version", "FlashVSR-v1.1"),
        "scale_factor": f"{self.config.get('scale_factor', 4)}x",
        "tiled_vae": vram.get("tiled_vae", True),
        "tiled_dit": vram.get("tiled_dit", False),
        "tile_size": vram.get("tile_size", 256),
        "overlap": vram.get("overlap", 24),
        "color_fix": quality.get("color_fix", True),
        "output_quality": quality.get("output_quality", 6),
        "output_fps": quality.get("output_fps", 30),
        "sparse_ratio": sparse_attn.get("sparse_ratio", 2.0),
        "kv_ratio": sparse_attn.get("kv_ratio", 3),
        "local_range": sparse_attn.get("local_range", 11),
        "dtype": processing.get("dtype", "bf16"),
        "unload_dit": processing.get("unload_dit", False)
    }
def setup_ui(self):
    """
    Register this plugin's UI with Wan2GP.

    Called during plugin initialization, before the main Wan2GP interface is
    built. Registers a dedicated "FlashVSR Upscaling" tab at position 5;
    ``create_flashvsr_ui`` is handed over as the tab's component constructor.
    """
    tab_spec = dict(
        tab_id="flashvsr_upscaling",
        label="FlashVSR Upscaling",
        component_constructor=self.create_flashvsr_ui,
        position=5,  # After main generation tabs
    )
    self.add_tab(**tab_spec)
def create_flashvsr_ui(self):
"""
Create the FlashVSR upscaling tab user interface.
Builds the Gradio UI components for the FlashVSR upscaling functionality.
Includes all controls for video upscaling with FlashVSR models.
Features:
- Video file upload with validation
- Model variant selection (Tiny/Tiny-Long/Full)
- Scale factor selection (2x/4x)
- Advanced settings:
- Tiled VAE/DiT for VRAM optimization
- Tile size and overlap controls
- Color correction toggle
- Sparse attention parameters
- Progress bar for upscaling operation
- Output video display with download
Returns:
gr.Blocks: Gradio Blocks component containing the FlashVSR UI
"""
# Get default values from config
defaults = self.get_config_defaults()
with gr.Blocks() as demo:
gr.Markdown("""
## FlashVSR Video Upscaling
Upload a video and upscale it using AI-powered FlashVSR models.
**Features:**
- 4x upscaling (2x also supported)
- Support for 8GB+ VRAM GPUs
- Automatic model downloading from HuggingFace
- Tile-based processing for low VRAM scenarios
- Automatically remuxes audio from original video (if present)
""")
with gr.Row():
# Left column - Input controls
with gr.Column(scale=1):
gr.Markdown("### Input Settings")
video_input = gr.File(
label="Input Video",
file_types=["video"],
elem_id="flashvsr_input_video"
)
with gr.Row():
model_variant = gr.Dropdown(
choices=[
"Tiny (8-10GB VRAM)",
"Tiny-Long (10-12GB VRAM)",
"Full (18-24GB VRAM)"
],
value=defaults["model_variant"],
label="Model Variant",
info="Tiny recommended for most users",
elem_id="flashvsr_model_variant"
)
with gr.Row():
scale_factor = gr.Dropdown(
choices=["2x", "4x"],
value=defaults["scale_factor"],
label="Scale Factor",
info="4x recommended (native FlashVSR)",
elem_id="flashvsr_scale_factor"
)
# Advanced Settings Accordion
with gr.Accordion("Advanced Settings", open=False):
gr.Markdown("#### VRAM Optimization")
tiled_vae = gr.Checkbox(
label="Tiled VAE",
value=defaults["tiled_vae"],
info="Enable for high resolution (>1080p)",
elem_id="flashvsr_tiled_vae"
)
tiled_dit = gr.Checkbox(
label="Tiled DiT",
value=defaults["tiled_dit"],
info="⚡ Recommended for 50+ frames (faster processing)",
elem_id="flashvsr_tiled_dit"
)
with gr.Row():
tile_size = gr.Slider(
minimum=128,
maximum=512,
value=defaults["tile_size"],
step=64,
label="Tile Size",
info="Smaller = less VRAM, slower",
elem_id="flashvsr_tile_size"
)
overlap = gr.Slider(
minimum=8,
maximum=64,
value=defaults["overlap"],
step=8,
label="Tile Overlap (px)",
info="Reduces seam artifacts",
elem_id="flashvsr_overlap"
)
gr.Markdown("#### Quality Settings")
color_fix = gr.Checkbox(
label="Enable Color Fix",
value=defaults["color_fix"],
info="Wavelet-based color correction",
elem_id="flashvsr_color_fix"
)
with gr.Row():
output_quality = gr.Slider(
minimum=1,
maximum=10,
value=defaults["output_quality"],
step=1,
label="Output Video Quality",
info="Higher = better quality, larger file",
elem_id="flashvsr_output_quality"
)
output_fps = gr.Number(
value=defaults["output_fps"],
label="Output FPS",
info="Fallback when video metadata unavailable",
precision=0,
elem_id="flashvsr_output_fps"
)
gr.Markdown("#### Processing Settings")
with gr.Row():
dtype = gr.Radio(
choices=["fp16", "bf16"],
value=defaults["dtype"],
label="Data Type",
info="bf16 recommended for most GPUs",
elem_id="flashvsr_dtype"
)
unload_dit = gr.Checkbox(
label="Unload DiT before Decoding",
value=defaults["unload_dit"],
info="Saves VRAM during decode",
elem_id="flashvsr_unload_dit"
)
gr.Markdown("#### Sparse Attention Parameters")
with gr.Row():
sparse_ratio = gr.Slider(
minimum=0.5,
maximum=5.0,
value=defaults["sparse_ratio"],
step=0.1,
label="Sparse Ratio",
info="Controls attention sparsity; smaller = more sparse",
elem_id="flashvsr_sparse_ratio"
)
kv_ratio = gr.Slider(
minimum=1,
maximum=8,
value=defaults["kv_ratio"],
step=1,
label="KV Cache Ratio",
info="Controls the length of the KV cache",
elem_id="flashvsr_kv_ratio"
)
local_range = gr.Slider(
minimum=3,
maximum=15,
value=defaults["local_range"],
step=2,
label="Local Range",
info="Size of the local attention window",
elem_id="flashvsr_local_range"
)
gr.Markdown("#### Model Version")
model_version = gr.Radio(
choices=["FlashVSR", "FlashVSR-v1.1"],
value=defaults.get("model_version", "FlashVSR-v1.1"),
label="Model Version",
info="FlashVSR-v1.1 uses causal attention for better temporal consistency",
elem_id="flashvsr_model_version"
)
# Upscale button with progress (Task 5.1: two-button pattern)
upscale_btn = gr.Button(
"🚀 Upscale Video",
variant="primary",
size="lg",
elem_id="flashvsr_upscale_btn",
visible=True
)
# Stop button (Task 4.2: visible=False by default)
stop_btn = gr.Button(
"⬛ Stop Processing",
variant="stop",
size="lg",
elem_id="flashvsr_stop_btn",
visible=False
)
progress_bar = gr.Progress()
# Right column - Output (Task 5.4: status_text removed)
with gr.Column(scale=1):
gr.Markdown("### Output")
video_output = gr.Video(
label="Upscaled Video",
elem_id="flashvsr_output_video"
)
# Info box with VRAM estimates
vram_info = gr.Markdown(
"""
**VRAM Estimates:**
- Tiny: 8-10GB for 1080p
- Tiny-Long: 10-12GB for long videos
- Full: 18-24GB for highest quality
**⚡ Tiled DiT:** Recommended for videos with 50+ frames.
Without it, processing is O(n²) slow regardless of VRAM.
""",
elem_id="flashvsr_vram_info"
)
# Event handler - implements full upscaling functionality
def upscale_video(
video, variant, scale, t_vae, t_dit,
t_size, t_overlap, c_fix, out_quality, out_fps,
data_type, do_unload_dit, sparse_r, kv, local_r,
model_ver,
progress=gr.Progress()
):
"""
Upscale a video using FlashVSR models.
Args:
video: Gradio File object containing the input video
variant: Model variant selection string
scale: Scale factor ("2x" or "4x")
t_vae: Enable tiled VAE
t_dit: Enable tiled DiT
t_size: Tile size for tiled processing
t_overlap: Tile overlap in pixels
c_fix: Enable color correction
out_quality: Output video quality (1-10)
out_fps: Fallback FPS when video metadata unavailable
data_type: Data type ("fp16" or "bf16")
do_unload_dit: Unload DiT before decoding
sparse_r: Sparse ratio for attention (0.5-5.0)
kv: KV cache ratio (1-8)
local_r: Local attention range (3-15)
progress: Gradio progress tracker
Returns:
str: Output video path, or None on error (Task 5.5)
"""
import os
import torch
import imageio
import numpy as np
import ffmpeg
from pathlib import Path
# Reset cancellation flag at start (Task 4.3)
self._cancel_flag = False
# Validation
if video is None:
gr.Warning("Please upload a video first.")
return None
try:
progress(0, desc="Initializing...")
# Parse variant
variant_map = {
"Tiny (8-10GB VRAM)": "tiny",
"Tiny-Long (10-12GB VRAM)": "tiny-long",
"Full (18-24GB VRAM)": "full"
}
selected_variant = variant_map.get(variant, "tiny")
is_tiny_long = (selected_variant == "tiny-long")
# Parse scale factor
scale_factor = int(scale.replace("x", ""))
# Determine dtype based on user selection
if data_type == "bf16" and torch.cuda.is_bf16_supported():
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
gr.Error("CUDA GPU required for FlashVSR upscaling.")
return None
# Import helper functions from download_manager
from .src.models.download_manager import load_pipeline
progress(0.05, desc="Loading input video...")
# Load input video
video_path = video.name if hasattr(video, 'name') else str(video)
try:
reader = imageio.get_reader(video_path)
meta = reader.get_meta_data()
fps = int(round(meta.get('fps', out_fps)))
# Load all frames
frames = []
for frame_data in reader:
frame_np = frame_data.astype(np.float32) / 255.0
frames.append(torch.from_numpy(frame_np).to(torch_dtype))
reader.close()
if len(frames) < 21:
gr.Warning(f"Video must have at least 21 frames. Got {len(frames)} frames.")
return None
video_tensor = torch.stack(frames, 0) # Shape: (N, H, W, C)
except Exception as e:
gr.Error(f"Error loading video: {str(e)}")
return None
# =================================================================
# PRE-FLIGHT RESOURCE CHECK
# =================================================================
progress(0.10, desc="Checking system resources...")
frame_count = int(video_tensor.shape[0])
N0, h0, w0, _ = video_tensor.shape
# Estimate resource requirements
estimates = estimate_vram_usage(
width=w0,
height=h0,
frames=frame_count,
scale=scale_factor,
tiled_vae=t_vae,
tiled_dit=t_dit,
mode=selected_variant
)
# Check against available resources
resources = check_resources(
estimated_vram_gb=estimates["total_vram_gb"],
estimated_ram_gb=estimates["output_ram_gb"]
)
# Log resource check results
print(f"[FlashVSR] Pre-flight check: {estimates['details']}")
print(f"[FlashVSR] Available: VRAM={resources['available_vram_gb']:.1f}GB, RAM={resources['available_ram_gb']:.1f}GB")
# Display warnings if resources are insufficient
if resources["warnings"]:
# Get optimal settings recommendations
optimal = get_optimal_settings(
width=w0,
height=h0,
frames=frame_count,
scale=scale_factor,
available_vram_gb=resources["available_vram_gb"],
available_ram_gb=resources["available_ram_gb"],
current_mode=selected_variant
)
# Build warning message
warning_parts = resources["warnings"].copy()
warning_parts.extend(optimal["recommendations"])
warning_msg = "\n".join(warning_parts)
# Display warning to user (processing continues)
gr.Warning(warning_msg)
print(f"[FlashVSR] Resource warnings:\n{warning_msg}")
# Reset peak VRAM stats for accurate tracking
reset_peak_vram_stats()
# Record start time for processing summary
import time as time_module
processing_start_time = time_module.time()
progress(0.15, desc=f"Loading {selected_variant.upper()} pipeline (model: {model_ver})...")
# Load pipeline (this will download models if needed)
# IMPORTANT: Reinitialize fresh each time to avoid state leakage (matches upstream)
try:
# Clean up any existing pipeline to avoid memory issues
if self.current_pipeline is not None:
del self.current_pipeline
self.current_pipeline = None
clean_vram()
pipeline = load_pipeline(
variant=selected_variant,
device=device,
torch_dtype=torch_dtype,
model_version=model_ver
)
self.current_pipeline = pipeline
self.models_loaded = True
except Exception as e:
gr.Error(f"Error loading pipeline: {str(e)}")
return None
progress(0.25, desc="Preparing input frames...")
# Prepare input tensor - matches upstream FlashVSR_plus approach
# Frame padding to ensure 8n+5 alignment for the pipeline
# Note: frame_count, N0, h0, w0 already defined in pre-flight check
pad_to = next_8n5(frame_count)
add = pad_to - frame_count
if add > 0:
padding_frames = video_tensor[-1:, :, :, :].repeat(add, 1, 1, 1)
video_tensor = torch.cat([video_tensor, padding_frames], dim=0)
# Clean VRAM before processing
clean_vram()
print(f"[FlashVSR] Processing {frame_count} frames...")
progress(0.35, desc="Running upscaling inference...")
# Build common pipe_kwargs matching upstream FlashVSR_plus
# Note: color_fix is handled differently for tiled vs non-tiled modes
pipe_kwargs = {
"prompt": "",
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"seed": 0,
"tiled": t_vae,
"is_full_block": False,
"if_buffer": True,
"kv_ratio": int(kv),
"local_range": int(local_r),
"unload_dit": False, # Don't unload between tiles
"fps": fps, # CRITICAL: Pass fps for temporal consistency
}
final_output_tensor = None
output_frames = None
output_written_directly = False # Flag for Tiny-Long direct file output
# Prepare output path early (needed for Tiny-Long mode)
output_dir = Path("outputs") / "flashvsr"
output_dir.mkdir(parents=True, exist_ok=True)
import time as time_module
timestamp = time_module.strftime("%Y%m%d-%H%M%S")
output_filename = f"flashvsr_{selected_variant}_{scale}_{timestamp}.mp4"
output_path = output_dir / output_filename
try:
if t_dit:
# ============================================================
# TILED DiT PROCESSING - matches upstream FlashVSR_plus
# ============================================================
N, H, W, C = video_tensor.shape
progress(0.35, desc=f"Initializing tiled processing (tile_size={t_size}, overlap={t_overlap})...")
# Validate overlap
if t_overlap > t_size / 2:
gr.Warning("Overlap must be less than half of the tile size!")
return None
# Calculate tile coordinates at ORIGINAL resolution
tile_coords = calculate_tile_coords(H, W, t_size, t_overlap)
num_tiles = len(tile_coords)
print(f"[FlashVSR] Tile-DiT: Processing {num_tiles} tiles at {W}x{H} (output: {W*scale_factor}x{H*scale_factor})")
from tqdm import tqdm as tqdm_progress
# Add color_fix to pipe_kwargs for tiled processing
tile_pipe_kwargs = {**pipe_kwargs, "color_fix": c_fix}
if is_tiny_long:
# ============================================================
# TINY-LONG TILED MODE: Write each tile to temp file, then stitch
# ============================================================
import tempfile
import uuid
temp_dir = Path(tempfile.gettempdir()) / f"flashvsr_tiles_{uuid.uuid4().hex}"
temp_dir.mkdir(parents=True, exist_ok=True)
temp_videos = []
for tile_idx, (x1, y1, x2, y2) in enumerate(tqdm_progress(tile_coords, desc="[FlashVSR] Processing tiles")):
# Task 4.4: Check for cancellation at start of each tile
if self._cancel_flag:
# Task 4.5: Cleanup on cancellation - delete temp files
import shutil
try:
shutil.rmtree(temp_dir)
except:
pass
# Task 4.6: Show cancellation warning
gr.Warning("Upscaling cancelled by user. Partial files deleted.")
return None
progress(
0.35 + 0.40 * (tile_idx / num_tiles),
desc=f"Processing tile {tile_idx+1}/{num_tiles}"
)
# Extract tile from ORIGINAL frames
input_tile = video_tensor[:, y1:y2, x1:x2, :]
# Get input parameters for this tile
th, tw, F = get_input_params(input_tile, scale=scale_factor)
# Use generator for memory-efficient processing
LQ_tile = input_tensor_generator(input_tile, device, scale=scale_factor, dtype=torch_dtype)
# Temp output path for this tile
temp_name = str(temp_dir / f"{tile_idx+1:05d}.mp4")
# Calculate topk_ratio for this tile's resolution
topk_ratio_tile = sparse_r * 768 * 1280 / (th * tw)
# Run pipeline on tile - writes directly to temp file
result = pipeline(
LQ_video=LQ_tile,
num_frames=F,
height=th,
width=tw,
topk_ratio=topk_ratio_tile,
output_path=temp_name,
quality=int(out_quality),
**tile_pipe_kwargs
)
temp_videos.append(temp_name)
# Clean up
del input_tile
clean_vram()
progress(0.75, desc="Stitching tiles...")
# Stitch all tiles together
stitch_video_tiles(
tile_paths=temp_videos,
tile_coords=tile_coords,
final_dims=(W * scale_factor, H * scale_factor),
scale=scale_factor,
overlap=t_overlap,
output_path=str(output_path),
fps=fps,
quality=int(out_quality),
cleanup=True
)
# Clean up temp directory
import shutil
try:
shutil.rmtree(temp_dir)
except:
pass
output_written_directly = True
print("[FlashVSR] Tile-DiT processing complete (Tiny-Long mode).")
else:
# ============================================================
# STANDARD TILED MODE: Accumulate in memory, then save
# ============================================================
num_aligned_frames = largest_8n1_leq(N + 4) - 4
# Create output canvas at SCALED resolution
final_output_canvas = torch.zeros(
(num_aligned_frames, H * scale_factor, W * scale_factor, C),
dtype=torch.float32
)
weight_sum_canvas = torch.zeros_like(final_output_canvas)
for tile_idx, (x1, y1, x2, y2) in enumerate(tqdm_progress(tile_coords, desc="[FlashVSR] Processing tiles")):
# Task 4.4: Check for cancellation at start of each tile
if self._cancel_flag:
# Task 4.5: Cleanup on cancellation - delete canvas tensors
del final_output_canvas, weight_sum_canvas
clean_vram()
# Task 4.6: Show cancellation warning
gr.Warning("Upscaling cancelled by user.")
return None
progress(
0.35 + 0.50 * (tile_idx / num_tiles),
desc=f"Processing tile {tile_idx+1}/{num_tiles}"
)
# Extract tile from ORIGINAL frames (not upscaled)
input_tile = video_tensor[:, y1:y2, x1:x2, :]
# Prepare the tile for the pipeline (bicubic upscale + normalize)
LQ_tile, th, tw, F = prepare_input_tensor(
input_tile, device, scale=scale_factor, dtype=torch_dtype
)
LQ_tile = LQ_tile.to(device)
# Calculate topk_ratio for this tile's resolution
topk_ratio_tile = sparse_r * 768 * 1280 / (th * tw)
# Run pipeline on tile
output_tile_gpu = pipeline(
LQ_video=LQ_tile,
num_frames=F,
height=th,
width=tw,
topk_ratio=topk_ratio_tile,
**tile_pipe_kwargs
)
# Check for pipeline error (returns boolean on failure)
if not isinstance(output_tile_gpu, torch.Tensor):
raise RuntimeError(f"Pipeline returned {type(output_tile_gpu).__name__} instead of tensor. This may indicate an incompatible pipeline variant or internal error.")
# Convert output tile to video frames format
processed_tile_cpu = tensor2video(output_tile_gpu).cpu()
# Create feather mask for blending at SCALED resolution
tile_out_h, tile_out_w = processed_tile_cpu.shape[1], processed_tile_cpu.shape[2]
mask = create_feather_mask(
(tile_out_h, tile_out_w),
t_overlap * scale_factor
).cpu()
# Reshape mask for broadcasting: (1, 1, H, W) -> (1, H, W, 1)
mask = mask.permute(0, 2, 3, 1)
# Calculate output coordinates at SCALED resolution
x1_s, y1_s = x1 * scale_factor, y1 * scale_factor
x2_s = x1_s + tile_out_w
y2_s = y1_s + tile_out_h
# Accumulate weighted tile into canvas
actual_frames = processed_tile_cpu.shape[0]
canvas_frames = final_output_canvas.shape[0]
use_frames = min(actual_frames, canvas_frames)
final_output_canvas[:use_frames, y1_s:y2_s, x1_s:x2_s, :] += processed_tile_cpu[:use_frames] * mask
weight_sum_canvas[:use_frames, y1_s:y2_s, x1_s:x2_s, :] += mask
# Clean up tile to free VRAM
del LQ_tile, output_tile_gpu, processed_tile_cpu, input_tile, mask
clean_vram()
# Use streaming write for tiled canvas output (Task 2.3)
# This avoids the memory spike from normalizing all frames at once
progress(0.85, desc="Saving output video (streaming)...")
log_memory_usage("Before canvas streaming write")
out_h, out_w = write_canvas_streaming(
canvas_tensor=final_output_canvas,
weight_tensor=weight_sum_canvas,
output_path=str(output_path),
fps=fps,
quality=int(out_quality),
frame_count=frame_count
)
# Delete intermediate tensors immediately (Task 2.4)
del final_output_canvas, weight_sum_canvas
clean_vram()
log_memory_usage("After canvas streaming write")
output_written_directly = True
print("[FlashVSR] Tile-DiT processing complete.")
# Clean up pipeline if requested
if do_unload_dit and hasattr(pipeline, 'offload_model'):
pipeline.offload_model(keep_vae=True)
else:
# ============================================================
# STANDARD (NON-TILED) PROCESSING
# ============================================================
# Get input parameters
tH, tW, F = get_input_params(video_tensor, scale_factor)
# Calculate topk_ratio
topk_ratio_adjusted = sparse_r * 768 * 1280 / (tH * tW)
# Add color_fix and unload_dit for non-tiled mode
full_pipe_kwargs = {
**pipe_kwargs,
"color_fix": c_fix,
"unload_dit": do_unload_dit,
}
# Task 4.4: Check for cancellation before starting pipeline
if self._cancel_flag:
clean_vram()
gr.Warning("Upscaling cancelled by user.")
return None
if is_tiny_long:
# ============================================================
# TINY-LONG NON-TILED: Write directly to output file
# ============================================================
# Use generator for memory-efficient processing
LQ_video = input_tensor_generator(video_tensor, device, scale=scale_factor, dtype=torch_dtype)
# Run pipeline with output_path - writes directly to file
result = pipeline(
LQ_video=LQ_video,
num_frames=F,
height=tH,
width=tW,
topk_ratio=topk_ratio_adjusted,
output_path=str(output_path),
quality=int(out_quality),
**full_pipe_kwargs
)
if result == False:
raise RuntimeError("Pipeline returned False, indicating an error during processing. Check console for details.")
output_written_directly = True
print("[FlashVSR] Processing complete (Tiny-Long mode).")
else:
# ============================================================
# STANDARD NON-TILED: Process in memory, then save
# ============================================================
# Prepare full-frame input tensor
LQ_video, tH, tW, F = prepare_input_tensor(
video_tensor, device, scale=scale_factor, dtype=torch_dtype
)
LQ_video = LQ_video.to(device)
# Run full pipeline
output_tensor = pipeline(
LQ_video=LQ_video,
num_frames=F,
height=tH,
width=tW,
topk_ratio=topk_ratio_adjusted,
**full_pipe_kwargs
)
# Check for pipeline error (returns boolean on failure)
if not isinstance(output_tensor, torch.Tensor):
raise RuntimeError(f"Pipeline returned {type(output_tensor).__name__} instead of tensor. Check console for error details.")
# Convert output to video frames
output_frames = tensor2video(output_tensor).cpu()
# Trim to original frame count
output_frames = output_frames[:frame_count]
del pipeline
clean_vram()
except Exception as e:
import traceback
error_trace = traceback.format_exc()
print(f"[FlashVSR] Error during upscaling: {error_trace}")
gr.Error(f"Error during upscaling: {str(e)}")
return None
progress(0.85, desc="Saving output video...")
# Save video (skip if already written directly)
if not output_written_directly:
# output_frames is in (F, H, W, C) format in [0, 1] range
# Use streaming write to avoid memory spike (Task 2.2)
log_memory_usage("Before streaming write")
out_h, out_w = write_frames_streaming(
output_frames_tensor=output_frames,
output_path=str(output_path),
fps=fps,
quality=int(out_quality)
)
# Delete intermediate tensors immediately (Task 2.4)
del output_frames
clean_vram()
log_memory_usage("After streaming write")
# Get output dimensions from file for status message
try:
probe = ffmpeg.probe(str(output_path))
video_stream = next(s for s in probe['streams'] if s['codec_type'] == 'video')
out_w = int(video_stream['width'])
out_h = int(video_stream['height'])
except:
out_w, out_h = 0, 0 # Fallback if probe fails
# Task 4.4: Check for cancellation before audio merge
if self._cancel_flag:
# Task 4.5: Delete partial output file
try:
if output_path.exists():
os.remove(str(output_path))
except:
pass
gr.Warning("Upscaling cancelled by user. Partial files deleted.")
return None
progress(0.95, desc="Merging audio...")
# Try to merge audio from source
try:
probe = ffmpeg.probe(video_path)
audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio']
if audio_streams:
temp_path = str(output_path) + "_temp.mp4"
os.rename(str(output_path), temp_path)
input_video = ffmpeg.input(temp_path)['v']
input_audio = ffmpeg.input(video_path)['a']
ffmpeg.output(
input_video, input_audio, str(output_path),
vcodec='copy', acodec='copy'
).run(overwrite_output=True, quiet=True)
os.remove(temp_path)
except Exception as e:
# Audio merge failed, but video is still usable
print(f"[FlashVSR] Warning: Audio merge failed: {e}")
progress(1.0, desc="Complete!")
# Get output frame count for status
output_frame_count = frame_count # Default to input frame count
try:
probe = ffmpeg.probe(str(output_path))
video_stream = next(s for s in probe['streams'] if s['codec_type'] == 'video')
output_frame_count = int(video_stream.get('nb_frames', frame_count))
except:
pass # Use default frame_count
# Log processing summary with peak VRAM usage
summary_stats = log_processing_summary(
start_time=processing_start_time,
width=w0,
height=h0,
frames=frame_count,
scale=scale_factor,
mode=selected_variant,
tiled_vae=t_vae,
tiled_dit=t_dit
)
# Display success message using gr.Info (Task 5.6)
gr.Info(f"Upscaling complete! {frame_count} frames @ {out_w}x{out_h} in {summary_stats['elapsed_seconds']:.1f}s ({summary_stats['fps_throughput']:.2f} fps). Peak VRAM: {summary_stats['peak_vram_gb']:.2f} GB")
return str(output_path)
except Exception as e:
import traceback
error_trace = traceback.format_exc()
print(f"[FlashVSR] Error: {error_trace}")
gr.Error(f"Error: {str(e)}")
return None
# Task 4.3: Stop button click handler
def on_stop_click():
"""Handle stop button click - sets cancellation flag"""
self._cancel_flag = True
gr.Warning(
"Cancellation requested. Will stop after current operation completes. "
"Note: Cannot interrupt mid-inference. If stuck, restart the application."
)
return gr.update(interactive=False, value="⏳ Stopping...")
# Task 5.2, 5.3: Button state management wrappers
def show_processing_state():
    """
    Switch the buttons into their "job running" layout.

    Returns:
        A pair of Gradio updates ordered (upscale button, stop button): the
        upscale button is hidden and the stop button is revealed.
    """
    hide_upscale = gr.update(visible=False)
    reveal_stop = gr.update(visible=True)
    return hide_upscale, reveal_stop
def show_ready_state():
    """
    Put the buttons back into their idle layout once a run has finished.

    Returns:
        A pair of Gradio updates ordered (upscale button, stop button). The
        upscale button is shown again. The stop button is hidden, re-enabled,
        and given back its "Stop Processing" label, which reverses the disable
        and relabel applied when stop was clicked.
    """
    upscale_update = gr.update(visible=True)
    stop_update = gr.update(
        value="⬛ Stop Processing",
        interactive=True,
        visible=False,
    )
    return upscale_update, stop_update
# Wire up stop button (Task 4.3)
stop_btn.click(
fn=on_stop_click,
inputs=[],
outputs=[stop_btn],
queue=False
)
# Wire up upscale button with button state management (Tasks 5.2, 5.3)
upscale_btn.click(
fn=show_processing_state,
inputs=[],
outputs=[upscale_btn, stop_btn],
queue=False
).then(
fn=upscale_video,
inputs=[
video_input, model_variant, scale_factor,
tiled_vae, tiled_dit, tile_size, overlap,
color_fix, output_quality, output_fps,
dtype, unload_dit, sparse_ratio, kv_ratio, local_range,
model_version
],
outputs=[video_output]
).then(
fn=show_ready_state,
inputs=[],
outputs=[upscale_btn, stop_btn],
queue=False
)
# Save config when settings change
def save_settings(variant, scale, t_vae, t_dit, t_size, t_overlap,
                  c_fix, out_qual, out_fps, data_type, do_unload,
                  sparse_r, kv, local_r, model_ver):
    """
    Persist the current UI control values through ``self.update_config_from_ui``.

    The positional parameters arrive in the same order as the components
    listed in the change-handler ``inputs``. Each one is forwarded under the
    config key named below.

    Returns:
        None, so the change event updates no output component.
    """
    ui_values = dict(
        model_variant=variant,
        scale_factor=scale,
        tiled_vae=t_vae,
        tiled_dit=t_dit,
        tile_size=t_size,
        overlap=t_overlap,
        color_fix=c_fix,
        output_quality=out_qual,
        output_fps=out_fps,
        dtype=data_type,
        unload_dit=do_unload,
        sparse_ratio=sparse_r,
        kv_ratio=kv,
        local_range=local_r,
        model_version=model_ver,
    )
    self.update_config_from_ui(**ui_values)
# Attach change handlers to save config (debounced via change event)
for component in [model_variant, scale_factor, tiled_vae, tiled_dit,
tile_size, overlap, color_fix, output_quality, output_fps,
dtype, unload_dit, sparse_ratio, kv_ratio, local_range,
model_version]:
component.change(
fn=save_settings,
inputs=[
model_variant, scale_factor, tiled_vae, tiled_dit,
tile_size, overlap, color_fix, output_quality, output_fps,
dtype, unload_dit, sparse_ratio, kv_ratio, local_range,
model_version
],
outputs=None
)
return demo
def post_ui_setup(self, components: dict):
    """
    Hook invoked after the main Wan2GP UI has been fully built.

    This is the place where the plugin could wire events to existing
    components, inject extra UI elements, or set up cross-component
    interactions. The current implementation does none of that: it ignores
    ``components`` and exposes nothing.

    Args:
        components: Gradio components from the main UI, keyed by elem_id.

    Returns:
        dict: A new, empty dictionary (no components are exposed).
    """
    exposed: dict = {}
    return exposed
def on_tab_select(self, state):
    """
    React to the FlashVSR Upscaling tab becoming active.

    If a pipeline is cached in ``self.current_pipeline``, this step moves its
    models back to the pipeline's device ahead of time. That cuts the latency
    of the next upscale. It also rebuilds the cross-attention KV cache when
    ``prompt_emb_posi['stats']`` is ``'offload'`` and a saved context is
    available.

    With no cached pipeline it does nothing. Any failure is printed and
    swallowed: this step is only an optimisation, and models are loaded
    again on the first upscale.

    Args:
        state: Application state passed in by Gradio (unused here).
    """
    pipe = self.current_pipeline
    if pipe is None:
        return
    try:
        print("[FlashVSR] Tab selected - loading models to GPU...")
        if hasattr(pipe, 'load_models_to_device'):
            pipe.load_models_to_device(['dit', 'vae', 'TCDecoder'])

        # The deselect handler tags prompt_emb_posi['stats'] as 'offload' after
        # moving models off the GPU; rebuild the KV cache from the saved context.
        prompt_emb = getattr(pipe, 'prompt_emb_posi', None)
        if prompt_emb is not None and prompt_emb.get('stats') == 'offload':
            saved_context = prompt_emb.get('context')
            if saved_context is not None:
                print("[FlashVSR] Re-initializing cross-attention KV cache...")
                pipe.init_cross_kv(context_tensor=saved_context)

        # LQ_proj_in and TCDecoder are moved explicitly, in addition to the
        # bulk load_models_to_device call above.
        dit = getattr(pipe, 'dit', None)
        lq_proj = getattr(dit, 'LQ_proj_in', None) if dit is not None else None
        if lq_proj is not None:
            lq_proj.to(pipe.device)

        tc_decoder = getattr(pipe, 'TCDecoder', None)
        if tc_decoder is not None:
            tc_decoder.to(pipe.device)

        print("[FlashVSR] Models loaded to GPU. Ready for upscaling.")
    except Exception as e:
        # Non-critical: models are loaded on the first upscale anyway.
        print(f"[FlashVSR] Warning: Failed to pre-load models on tab select: {e}")
def on_tab_deselect(self, state):
    """
    React to the user leaving the FlashVSR Upscaling tab by freeing VRAM.

    If a pipeline is cached in ``self.current_pipeline``, its models are moved
    to the CPU. The pipeline's own ``offload_model(keep_vae=False)`` is used
    when it has one. Otherwise the dit cross-attention KV cache is cleared (if
    supported) and ``dit``, ``vae`` and ``TCDecoder`` are moved one at a time.

    Afterwards ``prompt_emb_posi['stats']`` is set to ``'offload'``, which is
    what ``on_tab_select`` checks before rebuilding the KV cache. When CUDA is
    available, the CUDA cache is emptied and allocated memory is printed
    before and after. Any failure is printed and swallowed.

    Args:
        state: Application state passed in by Gradio (unused here).
    """
    pipe = self.current_pipeline
    if pipe is None:
        return
    gib = 1024 ** 3
    try:
        print("[FlashVSR] Tab deselected - offloading models to CPU to free VRAM...")

        vram_before = 0.0
        if torch.cuda.is_available():
            vram_before = torch.cuda.memory_allocated() / gib
            print(f"[FlashVSR] VRAM before offload: {vram_before:.2f} GB")

        if hasattr(pipe, 'offload_model'):
            pipe.offload_model(keep_vae=False)
        else:
            # Fallback for pipelines without offload_model: move each model by hand.
            dit = getattr(pipe, 'dit', None)
            if dit is not None:
                if hasattr(dit, 'clear_cross_kv'):
                    dit.clear_cross_kv()
                dit.to('cpu')
            for attr_name in ('vae', 'TCDecoder'):
                module = getattr(pipe, attr_name, None)
                if module is not None:
                    module.to('cpu')

        # Mark the KV cache as offloaded so the select handler rebuilds it.
        prompt_emb = getattr(pipe, 'prompt_emb_posi', None)
        if prompt_emb is not None:
            prompt_emb['stats'] = 'offload'

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            vram_after = torch.cuda.memory_allocated() / gib
            freed = vram_before - vram_after
            print(f"[FlashVSR] VRAM after offload: {vram_after:.2f} GB (freed {freed:.2f} GB)")

        print("[FlashVSR] Models offloaded to CPU. VRAM freed for other tasks.")
    except Exception as e:
        # Non-critical, but VRAM may not have been released.
        print(f"[FlashVSR] Warning: Failed to offload models on tab deselect: {e}")