|
|
import torch |
|
|
import os |
|
|
import soundfile as sf |
|
|
from diffusers.models import AutoencoderOobleck |
|
|
from tqdm import tqdm |
|
|
import torch.nn.functional as F |
|
|
|
|
|
def process_audio(audio_path, target_sr=48000):
    """Load an audio file as a two-channel tensor at ``target_sr``.

    The file is decoded with soundfile as float32. One-dimensional (mono)
    data is duplicated into two channels, anything with more than two
    channels keeps only the first two, and the signal is resampled with
    linear interpolation when its sample rate differs from ``target_sr``.

    Args:
        audio_path: Path of the audio file to read.
        target_sr: Sample rate the returned audio should have.

    Returns:
        A tensor with a leading batch dimension of 1, i.e.
        ``(1, channels, samples)``, with values clamped to ``[-1.0, 1.0]``;
        ``None`` if anything goes wrong (the error is printed).
    """
    try:
        samples, source_sr = sf.read(audio_path, dtype='float32')

        # Bring the data into channels-first layout.
        # NOTE(review): assumes soundfile returns 2-D data as
        # (frames, channels) -- confirm against the soundfile docs.
        if samples.ndim != 1:
            waveform = torch.from_numpy(samples.T)
        else:
            waveform = torch.from_numpy(samples).unsqueeze(0)

        # A single channel is copied so the result is always stereo.
        if waveform.shape[0] == 1:
            waveform = torch.cat([waveform, waveform], dim=0)

        # Keep at most the first two channels.
        waveform = waveform[:2]

        if sr_mismatch := (source_sr != target_sr):
            scale = target_sr / source_sr
            resampled_len = int(waveform.shape[-1] * scale)
            # F.interpolate expects a batch dimension; add it, then drop it.
            batched = waveform.unsqueeze(0)
            batched = F.interpolate(
                batched,
                size=resampled_len,
                mode='linear',
                align_corners=False,
            )
            waveform = batched.squeeze(0)

        waveform = waveform.clamp(-1.0, 1.0)
        return waveform.unsqueeze(0)

    except Exception as e:
        # Best effort: report the failure and let the caller skip this file.
        print(f"Error processing {audio_path}: {e}")
        return None
|
|
|
|
|
def main():
    """Encode calibration audio into fixed-size VAE latent chunks and save them.

    Paths are resolved relative to the parent of the directory that contains
    this script:

    * input audio:  ``data/quant_data/*.flac``
    * VAE weights:  ``checkpoints/vae``
    * output file:  ``data/calibration_latents.pt``

    Each file is loaded with ``process_audio`` and cut into consecutive
    windows of ``chunk_size * samples_per_latent`` samples (a trailing
    partial window is dropped). Every window is encoded with the VAE and the
    first ``chunk_size`` latent frames are collected. The collected list of
    tensors is written with ``torch.save``.

    Failures (missing data directory, VAE load error, per-file or per-chunk
    errors) are printed and do not raise.
    """
    print("Initializing Calibration Data Preparation...")

    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # An available XPU backend takes precedence over CUDA/CPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"

    print(f"Using device: {device}")
    vae = vae.to(device)
    vae.eval()

    # os.listdir returns entries in arbitrary order; sort so the saved
    # calibration set is reproducible, and match the extension
    # case-insensitively so files named ``*.FLAC`` are not skipped.
    audio_files = sorted(
        f for f in os.listdir(data_dir) if f.lower().endswith('.flac')
    )
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # latent frames kept per calibration chunk
    # NOTE(review): presumably the number of audio samples the VAE consumes
    # per latent frame -- confirm against the checkpoint's config.
    samples_per_latent = 1920
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)

        # process_audio already reported the error; skip the file.
        if full_audio is None:
            continue

        num_samples = full_audio.shape[-1]

        # Only full windows are encoded; a trailing partial window is dropped.
        last_start = num_samples - audio_chunk_size
        for start_idx in range(0, last_start + 1, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size
            audio_input = full_audio[:, :, start_idx:end_idx].to(device)

            try:
                with torch.no_grad():
                    posterior = vae.encode(audio_input).latent_dist
                    latents = posterior.sample()

                    if latents.shape[-1] >= chunk_size:
                        # clone(): on a CPU device ``.cpu()`` returns the
                        # slice itself, which is a view, and torch.save
                        # serializes a view's entire underlying storage.
                        chunk = latents[:, :, :chunk_size].cpu().clone()
                        all_chunks.append(chunk)

                    pbar.set_postfix({"chunks": len(all_chunks)})

            except Exception as e:
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
                # Release cached allocator memory on the active backend
                # before moving on to the next chunk.
                if device == "cuda":
                    torch.cuda.empty_cache()
                if device == "xpu":
                    torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if all_chunks:
        print(f"Saving to {output_path}...")
        torch.save(all_chunks, output_path)
        print("Done.")
    else:
        print("No chunks collected.")
|
|
|
|
|
# Run the calibration data preparation only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|