ChuxiJ commited on
Commit
b857422
·
1 Parent(s): 82758c5

reduce vRAM usage for vae decode

Browse files
Files changed (1) hide show
  1. acestep/handler.py +92 -9
acestep/handler.py CHANGED
@@ -1995,7 +1995,7 @@ class AceStepHandler:
1995
 
1996
  return outputs
1997
 
1998
- def tiled_decode(self, latents, chunk_size=512, overlap=64):
1999
  """
2000
  Decode latents using tiling to reduce VRAM usage.
2001
  Uses overlap-discard strategy to avoid boundary artifacts.
@@ -2004,6 +2004,7 @@ class AceStepHandler:
2004
  latents: [Batch, Channels, Length]
2005
  chunk_size: Size of latent chunk to process at once
2006
  overlap: Overlap size in latent frames
 
2007
  """
2008
  B, C, T = latents.shape
2009
 
@@ -2015,14 +2016,21 @@ class AceStepHandler:
2015
  stride = chunk_size - 2 * overlap
2016
  if stride <= 0:
2017
  raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
2018
-
2019
- decoded_audio_list = []
2020
-
2021
- # We need to determine upsample factor to trim audio correctly
2022
- upsample_factor = None
2023
 
2024
  num_steps = math.ceil(T / stride)
2025
 
 
 
 
 
 
 
 
 
 
 
 
 
2026
  for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
2027
  # Core range in latents
2028
  core_start = i * stride
@@ -2045,16 +2053,16 @@ class AceStepHandler:
2045
 
2046
  # Calculate trim amounts in audio samples
2047
  # How much overlap was added at the start?
2048
- added_start = core_start - win_start # latent frames
2049
  trim_start = int(round(added_start * upsample_factor))
2050
 
2051
  # How much overlap was added at the end?
2052
- added_end = win_end - core_end # latent frames
2053
  trim_end = int(round(added_end * upsample_factor))
2054
 
2055
  # Trim audio
2056
  audio_len = audio_chunk.shape[-1]
2057
- end_idx = audio_len - trim_end
2058
 
2059
  audio_core = audio_chunk[:, :, trim_start:end_idx]
2060
  decoded_audio_list.append(audio_core)
@@ -2062,6 +2070,81 @@ class AceStepHandler:
2062
  # Concatenate
2063
  final_audio = torch.cat(decoded_audio_list, dim=-1)
2064
  return final_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2065
 
2066
  def generate_music(
2067
  self,
 
1995
 
1996
  return outputs
1997
 
1998
+ def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False):
1999
  """
2000
  Decode latents using tiling to reduce VRAM usage.
2001
  Uses overlap-discard strategy to avoid boundary artifacts.
 
2004
  latents: [Batch, Channels, Length]
2005
  chunk_size: Size of latent chunk to process at once
2006
  overlap: Overlap size in latent frames
2007
+ offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM
2008
  """
2009
  B, C, T = latents.shape
2010
 
 
2016
  stride = chunk_size - 2 * overlap
2017
  if stride <= 0:
2018
  raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}")
 
 
 
 
 
2019
 
2020
  num_steps = math.ceil(T / stride)
2021
 
2022
+ if offload_wav_to_cpu:
2023
+ # Optimized path: offload wav to CPU immediately to save VRAM
2024
+ return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps)
2025
+ else:
2026
+ # Default path: keep everything on GPU
2027
+ return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps)
2028
+
2029
+ def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps):
2030
+ """Standard tiled decode keeping all data on GPU."""
2031
+ decoded_audio_list = []
2032
+ upsample_factor = None
2033
+
2034
  for i in tqdm(range(num_steps), desc="Decoding audio chunks"):
2035
  # Core range in latents
2036
  core_start = i * stride
 
2053
 
2054
  # Calculate trim amounts in audio samples
2055
  # How much overlap was added at the start?
2056
+ added_start = core_start - win_start # latent frames
2057
  trim_start = int(round(added_start * upsample_factor))
2058
 
2059
  # How much overlap was added at the end?
2060
+ added_end = win_end - core_end # latent frames
2061
  trim_end = int(round(added_end * upsample_factor))
2062
 
2063
  # Trim audio
2064
  audio_len = audio_chunk.shape[-1]
2065
+ end_idx = audio_len - trim_end if trim_end > 0 else audio_len
2066
 
2067
  audio_core = audio_chunk[:, :, trim_start:end_idx]
2068
  decoded_audio_list.append(audio_core)
 
2070
  # Concatenate
2071
  final_audio = torch.cat(decoded_audio_list, dim=-1)
2072
  return final_audio
2073
+
2074
+ def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps):
2075
+ """Optimized tiled decode that offloads to CPU immediately to save VRAM."""
2076
+ # First pass: decode first chunk to get upsample_factor and audio channels
2077
+ first_core_start = 0
2078
+ first_core_end = min(stride, T)
2079
+ first_win_start = 0
2080
+ first_win_end = min(T, first_core_end + overlap)
2081
+
2082
+ first_latent_chunk = latents[:, :, first_win_start:first_win_end]
2083
+ first_audio_chunk = self.vae.decode(first_latent_chunk).sample
2084
+
2085
+ upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1]
2086
+ audio_channels = first_audio_chunk.shape[1]
2087
+
2088
+ # Calculate total audio length and pre-allocate CPU tensor
2089
+ total_audio_length = int(round(T * upsample_factor))
2090
+ final_audio = torch.zeros(B, audio_channels, total_audio_length,
2091
+ dtype=first_audio_chunk.dtype, device='cpu')
2092
+
2093
+ # Process first chunk: trim and copy to CPU
2094
+ first_added_end = first_win_end - first_core_end
2095
+ first_trim_end = int(round(first_added_end * upsample_factor))
2096
+ first_audio_len = first_audio_chunk.shape[-1]
2097
+ first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len
2098
+
2099
+ first_audio_core = first_audio_chunk[:, :, :first_end_idx]
2100
+ audio_write_pos = first_audio_core.shape[-1]
2101
+ final_audio[:, :, :audio_write_pos] = first_audio_core.cpu()
2102
+
2103
+ # Free GPU memory
2104
+ del first_audio_chunk, first_audio_core, first_latent_chunk
2105
+
2106
+ # Process remaining chunks
2107
+ for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"):
2108
+ # Core range in latents
2109
+ core_start = i * stride
2110
+ core_end = min(core_start + stride, T)
2111
+
2112
+ # Window range (with overlap)
2113
+ win_start = max(0, core_start - overlap)
2114
+ win_end = min(T, core_end + overlap)
2115
+
2116
+ # Extract chunk
2117
+ latent_chunk = latents[:, :, win_start:win_end]
2118
+
2119
+ # Decode on GPU
2120
+ # [Batch, Channels, AudioSamples]
2121
+ audio_chunk = self.vae.decode(latent_chunk).sample
2122
+
2123
+ # Calculate trim amounts in audio samples
2124
+ added_start = core_start - win_start # latent frames
2125
+ trim_start = int(round(added_start * upsample_factor))
2126
+
2127
+ added_end = win_end - core_end # latent frames
2128
+ trim_end = int(round(added_end * upsample_factor))
2129
+
2130
+ # Trim audio
2131
+ audio_len = audio_chunk.shape[-1]
2132
+ end_idx = audio_len - trim_end if trim_end > 0 else audio_len
2133
+
2134
+ audio_core = audio_chunk[:, :, trim_start:end_idx]
2135
+
2136
+ # Copy to pre-allocated CPU tensor
2137
+ core_len = audio_core.shape[-1]
2138
+ final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu()
2139
+ audio_write_pos += core_len
2140
+
2141
+ # Free GPU memory immediately
2142
+ del audio_chunk, audio_core, latent_chunk
2143
+
2144
+ # Trim to actual length (in case of rounding differences)
2145
+ final_audio = final_audio[:, :, :audio_write_pos]
2146
+
2147
+ return final_audio
2148
 
2149
  def generate_music(
2150
  self,