File size: 4,331 Bytes
1daf6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import os
import soundfile as sf
from diffusers.models import AutoencoderOobleck
from tqdm import tqdm
import torch.nn.functional as F

def process_audio(audio_path, target_sr=48000):
    """Load an audio file and return a float32 tensor shaped [1, 2, samples].

    Mono input is duplicated to stereo; any channels beyond the first two are
    dropped. The waveform is resampled to ``target_sr`` if needed and clamped
    to [-1, 1]. Returns None (after printing the error) if loading fails.
    """
    try:
        samples_np, src_sr = sf.read(audio_path, dtype='float32')

        # soundfile yields [samples] or [samples, channels]; convert to
        # [channels, samples] for torch.
        waveform = (
            torch.from_numpy(samples_np).unsqueeze(0)
            if samples_np.ndim == 1
            else torch.from_numpy(samples_np.T)
        )

        # Duplicate a mono channel, then keep at most two channels.
        if waveform.shape[0] == 1:
            waveform = torch.cat([waveform, waveform], dim=0)
        waveform = waveform[:2]

        # Crude resample via linear interpolation when rates differ.
        # NOTE(review): no anti-aliasing filter on downsampling — presumably
        # acceptable for calibration data; confirm if fidelity matters.
        if src_sr != target_sr:
            target_len = int(waveform.shape[-1] * (target_sr / src_sr))
            waveform = F.interpolate(
                waveform.unsqueeze(0),
                size=target_len,
                mode='linear',
                align_corners=False,
            ).squeeze(0)

        # Clamp to the valid range and add a batch dim: [1, 2, samples].
        return torch.clamp(waveform, -1.0, 1.0).unsqueeze(0)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

def main():
    """Encode calibration audio into fixed-size VAE latent chunks.

    Reads every ``*.flac`` under ``<project>/data/quant_data``, encodes
    non-overlapping chunks with the Oobleck VAE, and saves the collected
    latent tensors (a list, each ``[1, 64, 512]``) to
    ``<project>/data/calibration_latents.pt``.
    """
    print("Initializing Calibration Data Preparation...")

    # Paths are resolved relative to this script's parent directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    data_dir = os.path.join(project_root, "data", "quant_data")
    output_path = os.path.join(project_root, "data", "calibration_latents.pt")
    vae_path = os.path.join(project_root, "checkpoints", "vae")

    if not os.path.exists(data_dir):
        print(f"Error: Data directory not found at {data_dir}")
        return

    print(f"Loading VAE from {vae_path}...")
    try:
        vae = AutoencoderOobleck.from_pretrained(vae_path)
    except Exception as e:
        print(f"Failed to load VAE: {e}")
        return

    # Device preference: XPU if present, else CUDA, else CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = "xpu"

    print(f"Using device: {device}")
    vae = vae.to(device)
    vae.eval()

    # Sort for a deterministic, reproducible calibration set —
    # os.listdir order is arbitrary and filesystem-dependent.
    audio_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.flac'))
    print(f"Found {len(audio_files)} audio files.")

    all_chunks = []
    chunk_size = 512  # latent frames per calibration chunk
    # Audio samples consumed per latent frame (assumed VAE temporal
    # stride — TODO confirm against the VAE config).
    samples_per_latent = 1920
    audio_chunk_size = chunk_size * samples_per_latent

    pbar = tqdm(audio_files, desc="Processing audio")
    for audio_file in pbar:
        file_path = os.path.join(data_dir, audio_file)
        full_audio = process_audio(file_path)

        if full_audio is None:
            continue  # unreadable file; process_audio already reported it

        num_samples = full_audio.shape[-1]

        # Non-overlapping, full-size chunks only; trailing partials dropped.
        for start_idx in range(0, num_samples, audio_chunk_size):
            end_idx = start_idx + audio_chunk_size
            if end_idx > num_samples:
                break  # skip incomplete chunks

            audio_input = full_audio[:, :, start_idx:end_idx].to(device)

            try:
                with torch.no_grad():
                    # VAE encode expects [batch, channels, samples] and
                    # returns a distribution we sample latents from.
                    posterior = vae.encode(audio_input).latent_dist
                    latents = posterior.sample()  # [1, 64, LatentLength]

                    # Should be exactly chunk_size; trim defensively.
                    if latents.shape[-1] >= chunk_size:
                        all_chunks.append(latents[:, :, :chunk_size].cpu())

                    pbar.set_postfix({"chunks": len(all_chunks)})

            except Exception as e:
                print(f"Error encoding chunk {start_idx}-{end_idx} of {audio_file}: {e}")
                # Free accelerator memory only on the device actually in use
                # (previously cuda.empty_cache() ran even on CPU/XPU runs).
                if device == "cuda":
                    torch.cuda.empty_cache()
                elif device == "xpu":
                    torch.xpu.empty_cache()

    print(f"Collected {len(all_chunks)} chunks of size {chunk_size}.")

    if len(all_chunks) > 0:
        print(f"Saving to {output_path}...")
        torch.save(all_chunks, output_path)
        print("Done.")
    else:
        print("No chunks collected.")

# Run calibration-data preparation only when executed as a script.
if __name__ == "__main__":
    main()