# Provenance: copied from a Hugging Face Space file view (Space status:
# "Running on A100"). File size: 4,331 bytes. Revision: 1daf6b4.
import torch
import os
import soundfile as sf
from diffusers.models import AutoencoderOobleck
from tqdm import tqdm
import torch.nn.functional as F
def process_audio(audio_path, target_sr=48000):
    """Load an audio file and return it as a stereo tensor at ``target_sr``.

    The file is decoded with ``soundfile`` as float32, laid out as
    ``[channels, samples]``, forced to exactly two channels (mono is
    duplicated, channels beyond the second are dropped), resampled by linear
    interpolation when its sample rate differs from ``target_sr``, and
    finally clamped to ``[-1.0, 1.0]``.

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sample rate, in Hz, of the returned waveform.

    Returns:
        A float32 tensor of shape ``[1, 2, samples]``, or ``None`` when the
        file could not be read or processed (the error is printed, not
        raised, so a single bad file does not abort a batch run).
    """
    try:
        # soundfile yields [samples] for mono and [samples, channels] otherwise.
        audio_np, sr = sf.read(audio_path, dtype='float32')

        # Convert to torch in [channels, samples] layout.
        if audio_np.ndim == 1:
            audio = torch.from_numpy(audio_np).unsqueeze(0)
        else:
            audio = torch.from_numpy(audio_np.T)

        # Ensure stereo: duplicate a mono channel, then keep at most two.
        if audio.shape[0] == 1:
            audio = torch.cat([audio, audio], dim=0)
        audio = audio[:2]

        # Resample if needed.
        if sr != target_sr:
            # Integer arithmetic on purpose: the float product
            # samples * (target_sr / sr) can land a hair below an exact
            # integer and lose one sample when truncated.
            new_length = int(audio.shape[-1] * target_sr // sr)
            # NOTE(review): linear interpolation applies no anti-aliasing
            # filter, so downsampling may alias; consider a band-limited
            # resampler if calibration quality matters.
            audio = F.interpolate(
                audio.unsqueeze(0),
                size=new_length,
                mode='linear',
                align_corners=False,
            ).squeeze(0)

        audio = torch.clamp(audio, -1.0, 1.0)
        return audio.unsqueeze(0)  # Add batch dim: [1, 2, samples]
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller skip
        # this file.
        print(f"Error processing {audio_path}: {e}")
        return None
def main():
    """Encode calibration audio into fixed-size VAE latent chunks and save them.

    Paths are resolved relative to the parent of the directory containing
    this script:

    * input:  ``data/quant_data/*.flac``
    * model:  ``checkpoints/vae`` (loaded with ``AutoencoderOobleck``)
    * output: ``data/calibration_latents.pt``

    Each audio file is loaded with ``process_audio``, cut into
    non-overlapping windows of ``512 * 1920`` samples (a trailing partial
    window is discarded), encoded by the VAE, and the sampled latents are
    truncated to 512 frames and collected on the CPU. The resulting list of
    tensors is written with ``torch.save``.

    Returns:
        None. Problems (missing data directory, VAE load failure, per-chunk
        encoding errors) are reported on stdout rather than raised.
    """
    print("Initializing Calibration Data Preparation...")

    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # An available XPU backend takes precedence over CUDA/CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"
    print(f"Using device: {device}")

    vae = vae.to(device)
    vae.eval()

    # os.listdir returns entries in arbitrary order; sort so the calibration
    # set is assembled in the same order on every run and every machine.
    audio_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.flac'))
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # Latent frames
    # Audio samples per latent frame.
    # NOTE(review): presumably the VAE's total downsampling factor; confirm
    # against the checkpoint config.
    samples_per_latent = 1920
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)
        if full_audio is None:
            continue  # process_audio already reported the failure

        # Split audio into full-size chunks; the upper bound makes the range
        # stop before any window that would run past the end, so incomplete
        # trailing chunks are skipped.
        num_samples = full_audio.shape[-1]
        last_start = num_samples - audio_chunk_size
        for start_idx in range(0, last_start + 1, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size
            audio_input = full_audio[:, :, start_idx:end_idx].to(device)

            try:
                with torch.no_grad():
                    # VAE encode expects [Batch, Channels, Samples] and
                    # returns a distribution object; .sample() draws from it,
                    # so saved latents are stochastic across runs.
                    posterior = vae.encode(audio_input).latent_dist
                    latents = posterior.sample()  # [1, 64, LatentLength]

                # It should be exactly chunk_size, but be safe: keep only
                # chunks that are long enough and trim them to chunk_size.
                if latents.shape[-1] >= chunk_size:
                    all_chunks.append(latents[:, :, :chunk_size].cpu())
                    pbar.set_postfix({"chunks": len(all_chunks)})
            except Exception as e:
                # Best-effort: report and move on to the next chunk.
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")

        # Release cached accelerator memory between files, using the cache
        # API of the backend that is actually in use.
        # NOTE(review): the source's indentation was lost, so placing this
        # once per file is an assumption; confirm the intended frequency.
        if device == "cuda":
            torch.cuda.empty_cache()
        elif device == "xpu":
            torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if all_chunks:
        print(f"Saving to {output_path}...")
        torch.save(all_chunks, output_path)
        print("Done.")
    else:
        print("No chunks collected.")
# Script entry point: run the calibration pipeline only when this file is
# executed directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()
|