File size: 3,117 Bytes
b3c4dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torchaudio
from pathlib import Path
from tqdm import tqdm
import torch
import argparse
import json
from model.ear_vae import EAR_VAE

def main(args):
    """Batch-reconstruct audio files through a pretrained EAR-VAE.

    Loads the model once, then for every regular file under ``args.indir``:
    loads the audio, resamples to 44.1 kHz if needed, upmixes mono to
    stereo, runs ``model.inference`` and writes the result as a WAV into
    ``args.outdir``. Per-file failures are logged and skipped.

    Args:
        args: argparse namespace with ``indir``, ``model_path``, ``outdir``,
            ``device`` and ``config`` attributes (see the CLI definition).
    """
    indir = args.indir
    model_path = args.model_path
    outdir = args.outdir
    device = args.device
    config_path = args.config

    print(f"Input directory: {indir}")
    print(f"Model path: {model_path}")
    print(f"Output directory: {outdir}")
    print(f"Device: {device}")
    print(f"Config path: {config_path}")

    input_path = Path(indir)
    output_path_dir = Path(outdir)
    output_path_dir.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'r') as f:
        vae_gan_model_config = json.load(f)

    print("Loading model...")
    model = EAR_VAE(model_config=vae_gan_model_config).to(device)

    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    print("Model loaded successfully.")

    # rglob("*") also yields directories; keep regular files only so every
    # subdirectory does not burn a spurious load error.
    audios = [p for p in input_path.rglob("*") if p.is_file()]
    print(f"Found {len(audios)} audio files to process.")

    # One Resample transform per source rate, built lazily and reused,
    # instead of re-instantiating a transform for every file.
    resamplers = {}

    # Output names embed the weight file's stem; compute it once.
    model_stem = Path(model_path).stem

    with torch.no_grad():
        for audio_path in tqdm(audios, desc="Processing audio files"):
            try:
                gt_y, sr = torchaudio.load(audio_path, backend="ffmpeg")

                # Guarantee a (channels, time) layout.
                if gt_y.dim() == 1:
                    gt_y = gt_y.unsqueeze(0)

                # Move to the target device BEFORE resampling. The original
                # code placed the resampler on `device` while the waveform
                # was still on CPU, which fails whenever device is CUDA.
                gt_y = gt_y.to(device, torch.float32)

                # Resample to the model's fixed 44.1 kHz rate if necessary.
                if sr != 44100:
                    if sr not in resamplers:
                        resamplers[sr] = torchaudio.transforms.Resample(sr, 44100).to(device)
                    gt_y = resamplers[sr](gt_y)

                # Model expects stereo; duplicate the channel for mono input.
                if gt_y.shape[0] == 1:
                    gt_y = torch.cat([gt_y, gt_y], dim=0)

                # Add batch dimension -> (1, channels, time).
                gt_y = gt_y.unsqueeze(0)

                fake_audio = model.inference(gt_y)

                output_filename = f"{audio_path.stem}_{model_stem}.wav"
                output_path = output_path_dir / output_filename

                fake_audio_processed = fake_audio.squeeze(0).cpu()
                torchaudio.save(output_path, fake_audio_processed, sample_rate=44100, backend="ffmpeg")
            except Exception as e:
                # Best-effort batch processing: report the failure and
                # continue with the remaining files.
                print(f"Error processing {audio_path}: {e}")

                continue


if __name__ == '__main__':
    # Pick CUDA by default when available; the user can still override
    # with --device (e.g. "cuda:1" or "cpu").
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cli = argparse.ArgumentParser(description="Run VAE-GAN audio inference.")
    cli.add_argument('--indir', type=str, default='./data',
                     help='Input directory for audio files.')
    cli.add_argument('--model_path', type=str, default='./pretrained_weight/ear_vae_44k.pyt',
                     help='Path to the pretrained model weight.')
    cli.add_argument('--outdir', type=str, default='./results',
                     help='Output directory for generated audio files.')
    cli.add_argument('--device', type=str, default=default_device,
                     help='Device to run the model on (e.g., "cuda:0" or "cpu").')
    cli.add_argument('--config', type=str, default='./config/model_config.json',
                     help='Path to the model config file.')

    main(cli.parse_args())