# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ Example Usage | |
| cpu: | |
| s3tokenizer --root_path /path/to/audio/files \ | |
| --model speech_tokenizer_v1 \ | |
| --device "cpu" \ | |
| --batch_size 32 | |
| gpu: | |
| torchrun --nproc_per_node=8 --nnodes=1 \ | |
| --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ | |
| `which s3tokenizer` --root_path /path/to/audio/files \ | |
| --model speech_tokenizer_v1 \ | |
| --device "cuda" \ | |
| --batch_size 32 | |
| """ | |
import argparse
import os
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm

import s3tokenizer


class AudioDataset(Dataset):

    def __init__(self, root_path, extensions=('.wav', '.flac', '.mp3'),
                 use_cache=True, cache_file=None, max_workers=8):
        self.data = []
        # Default cache file lives alongside the audio files
        if cache_file is None:
            cache_file = os.path.join(root_path, '.audio_file_cache.pkl')
        # Try to load from cache first
        if use_cache and os.path.exists(cache_file):
            print(f"Loading file list from cache: {cache_file}")
            try:
                with open(cache_file, 'rb') as f:
                    self.data = pickle.load(f)
                print(f"Loaded {len(self.data)} files from cache")
                return
            except Exception as e:
                print(f"Failed to load cache: {e}, scanning directory...")
        # Method 1: scan with os.walk()/os.scandir(), which is typically
        # faster than pathlib
        print(f"Scanning directory: {root_path}")
        print(f"Looking for extensions: {extensions}")

        def scan_directory(args):
            dirpath, extensions = args
            files = []
            try:
                with os.scandir(dirpath) as entries:
                    for entry in entries:
                        if entry.is_file() and any(
                                entry.name.endswith(ext)
                                for ext in extensions):
                            files.append(entry.path)
            except PermissionError:
                pass
            return files

        # Collect all directories first
        all_dirs = [root_path]
        for dirpath, dirnames, _ in os.walk(root_path):
            all_dirs.extend(os.path.join(dirpath, d) for d in dirnames)

        # Scan directories in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(scan_directory, (d, extensions))
                for d in all_dirs
            ]
            with tqdm(total=len(all_dirs),
                      desc="Scanning directories") as pbar:
                for future in as_completed(futures):
                    self.data.extend(future.result())
                    pbar.update(1)

        # Sort for consistent ordering
        self.data.sort()
        if len(self.data) == 0:
            raise ValueError(f"No audio files found in {root_path}")
        print(f"Found {len(self.data)} audio files")

        # Save to cache
        if use_cache:
            try:
                print(f"Saving file list to cache: {cache_file}")
                # Ensure parent directory exists
                cache_dir = os.path.dirname(cache_file)
                if cache_dir and not os.path.exists(cache_dir):
                    os.makedirs(cache_dir, exist_ok=True)
                with open(cache_file, 'wb') as f:
                    pickle.dump(self.data, f)
            except Exception as e:
                print(f"Failed to save cache: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path = self.data[idx]
        try:
            audio = s3tokenizer.load_audio(file_path)
            mel = s3tokenizer.log_mel_spectrogram(audio)
            return file_path, mel
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return None, None
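

# Example usage of AudioDataset (hypothetical directory; the file-list cache
# defaults to <root_path>/.audio_file_cache.pkl, see __init__ above):
#
#   ds = AudioDataset("/data/audio", extensions=('.wav',), max_workers=16)
#   file_path, mel = ds[0]  # returns (None, None) if the file failed to load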


def collate_fn(batch):
    # Filter out None entries (failed files)
    batch = [item for item in batch if item[0] is not None]
    if len(batch) == 0:
        return [], None, None
    file_paths = [item[0] for item in batch]
    mels = [item[1] for item in batch]
    mels, mels_lens = s3tokenizer.padding(mels)
    return file_paths, mels, mels_lens
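

# Sketch of what collate_fn produces (reusing the hypothetical `ds` above):
#
#   file_paths, mels, mels_lens = collate_fn([ds[0], ds[1]])
#   # mels is the time-padded batch from s3tokenizer.padding;
#   # mels_lens holds each item's unpadded length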


def init_distributed():
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    print('Inference on multiple gpus, this gpu {}'.format(local_rank) +
          ', rank {}, world_size {}'.format(rank, world_size))
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
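

# Note: WORLD_SIZE, LOCAL_RANK, and RANK are set automatically by torchrun
# (see the module docstring). For a quick single-process smoke test you could
# export them by hand, e.g. (illustrative command; script name assumed):
#
#   WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 \
#   MASTER_ADDR=localhost MASTER_PORT=29500 \
#   python extract_tokens.py --device cuda ...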


def get_args():
    parser = argparse.ArgumentParser(description='extract speech codes')
    parser.add_argument('--model',
                        required=True,
                        type=str,
                        choices=[
                            "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
                            "speech_tokenizer_v2_25hz"
                        ],
                        help='model version')
    parser.add_argument('--root_path',
                        required=True,
                        type=str,
                        help='root directory containing audio files')
    parser.add_argument('--device',
                        required=True,
                        type=str,
                        choices=["cuda", "cpu"],
                        help='device for inference')
    parser.add_argument('--batch_size',
                        required=True,
                        type=int,
                        help='batch size (per-device) for inference')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='workers for dataloader')
    parser.add_argument('--prefetch',
                        type=int,
                        default=5,
                        help='prefetch for dataloader')
    parser.add_argument('--extensions',
                        nargs='+',
                        default=['.wav', '.flac', '.mp3'],
                        help='audio file extensions to process')
    parser.add_argument('--use_cache',
                        action='store_true',
                        help='use cached file list to avoid re-scanning '
                        '(this is already the default; see --no_cache)')
    parser.add_argument('--no_cache',
                        action='store_true',
                        help='force re-scan even if cache exists')
    parser.add_argument('--cache_file',
                        type=str,
                        default=None,
                        help='path to cache file '
                        '(default: root_path/.audio_file_cache.pkl)')
    parser.add_argument('--scan_workers',
                        type=int,
                        default=8,
                        help='number of workers for directory scanning')
    parser.add_argument('--file_list',
                        type=str,
                        default=None,
                        help='path to pre-generated file list '
                        '(one file per line)')
    parser.add_argument('--skip_existing',
                        action='store_true',
                        help='skip files that already have _fsq.pt output')
    args = parser.parse_args()
    return args


def save_tokens(file_path, codes, codes_len):
    """Save tokens as a .pt file with an _fsq suffix next to the audio."""
    # Remove the audio extension and add _fsq.pt
    base_name = os.path.splitext(file_path)[0]
    output_path = f"{base_name}_fsq.pt"
    # Keep only the valid (unpadded) codes and move them to CPU so the
    # saved file can be loaded on CPU-only machines
    valid_codes = codes[:codes_len].cpu()
    torch.save(valid_codes, output_path)
    return output_path
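

# The saved tokens can be read back directly with torch.load, e.g.
# (illustrative path):
#
#   codes = torch.load("/data/audio/utt0001_fsq.pt")  # first codes_len codes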


def main():
    args = get_args()

    if args.device == "cuda":
        assert torch.cuda.is_available()
        world_size, local_rank, rank = init_distributed()
    else:
        world_size, local_rank, rank = 1, 0, 0
    device = torch.device(args.device)
    model = s3tokenizer.load_model(args.model).to(device)
    model.eval()

    # Handle the different data loading methods
    if args.file_list:
        # Option 3: load from a pre-generated file list
        print(f"Loading file list from: {args.file_list}")
        with open(args.file_list, 'r') as f:
            file_paths = [line.strip() for line in f if line.strip()]
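
        # Expected --file_list format: one audio path per line, e.g.
        # (illustrative paths):
        #
        #   /data/audio/utt0001.wav
        #   /data/audio/spk1/utt0002.flac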

        # A simple dataset over an explicit list of paths
        class FileListDataset(Dataset):

            def __init__(self, file_paths, skip_existing=False):
                self.data = []
                skipped_existing = 0
                for fp in file_paths:
                    if skip_existing:
                        # Derive the output path for any supported
                        # extension, not just .wav
                        output_path = os.path.splitext(fp)[0] + '_fsq.pt'
                        if os.path.exists(output_path):
                            print(f'*******skip file {output_path}')
                            skipped_existing += 1
                            continue
                    self.data.append(fp)
                print(f"Will process {len(self.data)} files")
                if skip_existing and skipped_existing > 0:
                    print(f"Skipped {skipped_existing} "
                          "already processed files")

            def __len__(self):
                return len(self.data)

            def __getitem__(self, idx):
                file_path = self.data[idx]
                try:
                    if not os.path.exists(file_path):
                        print(f"File not found: {file_path}")
                        return None, None
                    audio = s3tokenizer.load_audio(file_path)
                    mel = s3tokenizer.log_mel_spectrogram(audio)
                    return file_path, mel
                except Exception as e:
                    print(f"Error processing {file_path}: {e}")
                    return None, None

        dataset = FileListDataset(file_paths,
                                  skip_existing=args.skip_existing)
    else:
        # Use the enhanced AudioDataset with caching
        dataset = AudioDataset(args.root_path,
                               args.extensions,
                               use_cache=not args.no_cache,
                               cache_file=args.cache_file,
                               max_workers=args.scan_workers)
        # Filter out already-processed files if requested
        if args.skip_existing:
            original_count = len(dataset.data)
            dataset.data = [
                fp for fp in dataset.data
                if not os.path.exists(os.path.splitext(fp)[0] + "_fsq.pt")
            ]
            print(f"Skipping {original_count - len(dataset.data)} "
                  "already processed files")

    if args.device == "cuda":
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank])
        sampler = DistributedSampler(dataset,
                                     num_replicas=world_size,
                                     rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            sampler=sampler,
                            shuffle=False,
                            num_workers=args.num_workers,
                            prefetch_factor=args.prefetch,
                            collate_fn=collate_fn)

    total_steps = len(dataset)
    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    processed_count = 0
    failed_count = 0
    failed_files = []
    for file_paths, mels, mels_lens in dataloader:
        # Skip empty batches (every file in the batch failed to load)
        if len(file_paths) == 0:
            continue
        with torch.no_grad():
            codes, codes_lens = model(mels.to(device), mels_lens.to(device))
        # Save each file in the batch
        for i, file_path in enumerate(file_paths):
            try:
                code = codes[i]
                code_len = codes_lens[i].item()
                output_path = save_tokens(file_path, code, code_len)
                # Only show the first 10 to avoid spamming the log
                if rank == 0 and processed_count < 10:
                    tqdm.write(f"Saved: {file_path} -> {output_path}")
                processed_count += 1
            except Exception as e:
                failed_count += 1
                failed_files.append(file_path)
                if rank == 0:
                    tqdm.write(f"Failed to save {file_path}: {e}")
        if rank == 0:
            # failed_count is cumulative, so count only this batch's files;
            # scale by world_size to approximate global progress
            progress_bar.update(world_size * len(file_paths))

    if rank == 0:
        progress_bar.close()
        print(f"\nProcessed {processed_count} files successfully "
              f"on rank {rank}")
        if failed_count > 0:
            print(f"Failed to process {failed_count} files")
            # Save the list of failed files
            failed_list_path = os.path.join(
                args.root_path if not args.file_list else ".",
                "failed_files.txt")
            with open(failed_list_path, 'w') as f:
                for ff in failed_files:
                    f.write(f"{ff}\n")
            print(f"Failed files saved to: {failed_list_path}")

    if args.device == "cuda":
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    main()