Spaces:
Running on Zero
Running on Zero
reduce vRAM usage for vae decode
Browse files- acestep/handler.py +92 -9
acestep/handler.py
CHANGED
|
@@ -1995,7 +1995,7 @@ class AceStepHandler:
|
|
| 1995 |
|
| 1996 |
return outputs
|
| 1997 |
|
| 1998 |
-
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
| 1999 |
"""
|
| 2000 |
Decode latents using tiling to reduce VRAM usage.
|
| 2001 |
Uses overlap-discard strategy to avoid boundary artifacts.
|
|
@@ -2004,6 +2004,7 @@ class AceStepHandler:
|
|
| 2004 |
latents: [Batch, Channels, Length]
|
| 2005 |
chunk_size: Size of latent chunk to process at once
|
| 2006 |
overlap: Overlap size in latent frames
|
|
|
|
| 2007 |
"""
|
| 2008 |
B, C, T = latents.shape
|
| 2009 |
|
|
@@ -2015,14 +2016,21 @@ class AceStepHandler:
|
|
| 2015 |
stride = chunk_size - 2 * overlap
|
| 2016 |
if stride <= 0:
|
| 2017 |
raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
|
| 2018 |
-
|
| 2019 |
-
decoded_audio_list = []
|
| 2020 |
-
|
| 2021 |
-
# We need to determine upsample factor to trim audio correctly
|
| 2022 |
-
upsample_factor = None
|
| 2023 |
|
| 2024 |
num_steps = math.ceil(T / stride)
|
| 2025 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2026 |
for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
|
| 2027 |
# Core range in latents
|
| 2028 |
core_start = i * stride
|
|
@@ -2045,16 +2053,16 @@ class AceStepHandler:
|
|
| 2045 |
|
| 2046 |
# Calculate trim amounts in audio samples
|
| 2047 |
# How much overlap was added at the start?
|
| 2048 |
-
added_start = core_start - win_start
|
| 2049 |
trim_start = int(round(added_start * upsample_factor))
|
| 2050 |
|
| 2051 |
# How much overlap was added at the end?
|
| 2052 |
-
added_end = win_end - core_end
|
| 2053 |
trim_end = int(round(added_end * upsample_factor))
|
| 2054 |
|
| 2055 |
# Trim audio
|
| 2056 |
audio_len = audio_chunk.shape[-1]
|
| 2057 |
-
end_idx = audio_len - trim_end
|
| 2058 |
|
| 2059 |
audio_core = audio_chunk[:, :, trim_start:end_idx]
|
| 2060 |
decoded_audio_list.append(audio_core)
|
|
@@ -2062,6 +2070,81 @@ class AceStepHandler:
|
|
| 2062 |
# Concatenate
|
| 2063 |
final_audio = torch.cat(decoded_audio_list, dim=-1)
|
| 2064 |
return final_audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2065 |
|
| 2066 |
def generate_music(
|
| 2067 |
self,
|
|
|
|
| 1995 |
|
| 1996 |
return outputs
|
| 1997 |
|
| 1998 |
+
def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
|
| 1999 |
"""
|
| 2000 |
Decode latents using tiling to reduce VRAM usage.
|
| 2001 |
Uses overlap-discard strategy to avoid boundary artifacts.
|
|
|
|
| 2004 |
latents: [Batch, Channels, Length]
|
| 2005 |
chunk_size: Size of latent chunk to process at once
|
| 2006 |
overlap: Overlap size in latent frames
|
| 2007 |
+
offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM
|
| 2008 |
"""
|
| 2009 |
B, C, T = latents.shape
|
| 2010 |
|
|
|
|
| 2016 |
stride = chunk_size - 2 * overlap
|
| 2017 |
if stride <= 0:
|
| 2018 |
raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2019 |
|
| 2020 |
num_steps = math.ceil(T / stride)
|
| 2021 |
|
| 2022 |
+
if offload_wav_to_cpu:
|
| 2023 |
+
# Optimized path: offload wav to CPU immediately to save VRAM
|
| 2024 |
+
return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps)
|
| 2025 |
+
else:
|
| 2026 |
+
# Default path: keep everything on GPU
|
| 2027 |
+
return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps)
|
| 2028 |
+
|
| 2029 |
+
def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps):
|
| 2030 |
+
"""Standard tiled decode keeping all data on GPU."""
|
| 2031 |
+
decoded_audio_list = []
|
| 2032 |
+
upsample_factor = None
|
| 2033 |
+
|
| 2034 |
for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
|
| 2035 |
# Core range in latents
|
| 2036 |
core_start = i * stride
|
|
|
|
| 2053 |
|
| 2054 |
# Calculate trim amounts in audio samples
|
| 2055 |
# How much overlap was added at the start?
|
| 2056 |
+
added_start = core_start - win_start # latent frames
|
| 2057 |
trim_start = int(round(added_start * upsample_factor))
|
| 2058 |
|
| 2059 |
# How much overlap was added at the end?
|
| 2060 |
+
added_end = win_end - core_end # latent frames
|
| 2061 |
trim_end = int(round(added_end * upsample_factor))
|
| 2062 |
|
| 2063 |
# Trim audio
|
| 2064 |
audio_len = audio_chunk.shape[-1]
|
| 2065 |
+
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
|
| 2066 |
|
| 2067 |
audio_core = audio_chunk[:, :, trim_start:end_idx]
|
| 2068 |
decoded_audio_list.append(audio_core)
|
|
|
|
| 2070 |
# Concatenate
|
| 2071 |
final_audio = torch.cat(decoded_audio_list, dim=-1)
|
| 2072 |
return final_audio
|
| 2073 |
+
|
| 2074 |
+
def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps):
|
| 2075 |
+
"""Optimized tiled decode that offloads to CPU immediately to save VRAM."""
|
| 2076 |
+
# First pass: decode first chunk to get upsample_factor and audio channels
|
| 2077 |
+
first_core_start = 0
|
| 2078 |
+
first_core_end = min(stride, T)
|
| 2079 |
+
first_win_start = 0
|
| 2080 |
+
first_win_end = min(T, first_core_end + overlap)
|
| 2081 |
+
|
| 2082 |
+
first_latent_chunk = latents[:, :, first_win_start:first_win_end]
|
| 2083 |
+
first_audio_chunk = self.vae.decode(first_latent_chunk).sample
|
| 2084 |
+
|
| 2085 |
+
upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1]
|
| 2086 |
+
audio_channels = first_audio_chunk.shape[1]
|
| 2087 |
+
|
| 2088 |
+
# Calculate total audio length and pre-allocate CPU tensor
|
| 2089 |
+
total_audio_length = int(round(T * upsample_factor))
|
| 2090 |
+
final_audio = torch.zeros(B, audio_channels, total_audio_length,
|
| 2091 |
+
dtype=first_audio_chunk.dtype, device='cpu')
|
| 2092 |
+
|
| 2093 |
+
# Process first chunk: trim and copy to CPU
|
| 2094 |
+
first_added_end = first_win_end - first_core_end
|
| 2095 |
+
first_trim_end = int(round(first_added_end * upsample_factor))
|
| 2096 |
+
first_audio_len = first_audio_chunk.shape[-1]
|
| 2097 |
+
first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len
|
| 2098 |
+
|
| 2099 |
+
first_audio_core = first_audio_chunk[:, :, :first_end_idx]
|
| 2100 |
+
audio_write_pos = first_audio_core.shape[-1]
|
| 2101 |
+
final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
|
| 2102 |
+
|
| 2103 |
+
# Free GPU memory
|
| 2104 |
+
del first_audio_chunk, first_audio_core, first_latent_chunk
|
| 2105 |
+
|
| 2106 |
+
# Process remaining chunks
|
| 2107 |
+
for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"):
|
| 2108 |
+
# Core range in latents
|
| 2109 |
+
core_start = i * stride
|
| 2110 |
+
core_end = min(core_start + stride, T)
|
| 2111 |
+
|
| 2112 |
+
# Window range (with overlap)
|
| 2113 |
+
win_start = max(0, core_start - overlap)
|
| 2114 |
+
win_end = min(T, core_end + overlap)
|
| 2115 |
+
|
| 2116 |
+
# Extract chunk
|
| 2117 |
+
latent_chunk = latents[:, :, win_start:win_end]
|
| 2118 |
+
|
| 2119 |
+
# Decode on GPU
|
| 2120 |
+
# [Batch, Channels, AudioSamples]
|
| 2121 |
+
audio_chunk = self.vae.decode(latent_chunk).sample
|
| 2122 |
+
|
| 2123 |
+
# Calculate trim amounts in audio samples
|
| 2124 |
+
added_start = core_start - win_start # latent frames
|
| 2125 |
+
trim_start = int(round(added_start * upsample_factor))
|
| 2126 |
+
|
| 2127 |
+
added_end = win_end - core_end # latent frames
|
| 2128 |
+
trim_end = int(round(added_end * upsample_factor))
|
| 2129 |
+
|
| 2130 |
+
# Trim audio
|
| 2131 |
+
audio_len = audio_chunk.shape[-1]
|
| 2132 |
+
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
|
| 2133 |
+
|
| 2134 |
+
audio_core = audio_chunk[:, :, trim_start:end_idx]
|
| 2135 |
+
|
| 2136 |
+
# Copy to pre-allocated CPU tensor
|
| 2137 |
+
core_len = audio_core.shape[-1]
|
| 2138 |
+
final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu()
|
| 2139 |
+
audio_write_pos += core_len
|
| 2140 |
+
|
| 2141 |
+
# Free GPU memory immediately
|
| 2142 |
+
del audio_chunk, audio_core, latent_chunk
|
| 2143 |
+
|
| 2144 |
+
# Trim to actual length (in case of rounding differences)
|
| 2145 |
+
final_audio = final_audio[:, :, :audio_write_pos]
|
| 2146 |
+
|
| 2147 |
+
return final_audio
|
| 2148 |
|
| 2149 |
def generate_music(
|
| 2150 |
self,
|