Spaces:
Sleeping
Sleeping
primepake
committed on
Commit
·
ea8cd35
1
Parent(s):
34bf06f
Edit cli s3
Browse files
speech/tools/S3Tokenizer/s3tokenizer/cli.py
CHANGED
|
@@ -15,18 +15,18 @@
|
|
| 15 |
cpu:
|
| 16 |
|
| 17 |
s3tokenizer --root_path /path/to/audio/files \
|
| 18 |
-
--model
|
| 19 |
--device "cpu" \
|
| 20 |
--batch_size 32
|
| 21 |
|
| 22 |
gpu:
|
| 23 |
|
| 24 |
-
torchrun --nproc_per_node=
|
| 25 |
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
| 26 |
-
`which s3tokenizer` --root_path /
|
| 27 |
-
--model
|
| 28 |
--device "cuda" \
|
| 29 |
-
--batch_size
|
| 30 |
|
| 31 |
"""
|
| 32 |
|
|
@@ -44,13 +44,60 @@ import s3tokenizer
|
|
| 44 |
|
| 45 |
class AudioDataset(Dataset):
|
| 46 |
|
| 47 |
-
def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3']
|
|
|
|
| 48 |
self.data = []
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# Sort for consistent ordering
|
| 56 |
self.data.sort()
|
|
@@ -59,18 +106,39 @@ class AudioDataset(Dataset):
|
|
| 59 |
raise ValueError(f"No audio files found in {root_path}")
|
| 60 |
|
| 61 |
print(f"Found {len(self.data)} audio files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def __len__(self):
|
| 64 |
return len(self.data)
|
| 65 |
|
| 66 |
def __getitem__(self, idx):
|
| 67 |
file_path = self.data[idx]
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def collate_fn(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
file_paths = [item[0] for item in batch]
|
| 75 |
mels = [item[1] for item in batch]
|
| 76 |
mels, mels_lens = s3tokenizer.padding(mels)
|
|
@@ -123,6 +191,27 @@ def get_args():
|
|
| 123 |
nargs='+',
|
| 124 |
default=['.wav', '.flac', '.mp3'],
|
| 125 |
help='audio file extensions to process')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
args = parser.parse_args()
|
| 127 |
return args
|
| 128 |
|
|
@@ -135,8 +224,6 @@ def save_tokens(file_path, codes, codes_len):
|
|
| 135 |
|
| 136 |
# Extract only valid codes (up to codes_len)
|
| 137 |
valid_codes = codes[:codes_len]
|
| 138 |
-
# convert valid codes to list
|
| 139 |
-
valid_codes = valid_codes.tolist()
|
| 140 |
|
| 141 |
# Save as tensor
|
| 142 |
torch.save(valid_codes, output_path)
|
|
@@ -155,7 +242,78 @@ def main():
|
|
| 155 |
|
| 156 |
device = torch.device(args.device)
|
| 157 |
model = s3tokenizer.load_model(args.model).to(device)
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
if args.device == "cuda":
|
| 161 |
model = torch.nn.parallel.DistributedDataParallel(
|
|
@@ -180,28 +338,50 @@ def main():
|
|
| 180 |
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
| 181 |
|
| 182 |
processed_count = 0
|
|
|
|
|
|
|
|
|
|
| 183 |
for file_paths, mels, mels_lens in dataloader:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
codes, codes_lens = model(mels.to(device), mels_lens.to(device))
|
| 185 |
|
| 186 |
# Process each file in the batch
|
| 187 |
for i, file_path in enumerate(file_paths):
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
if rank == 0:
|
| 200 |
-
progress_bar.update(world_size * len(file_paths))
|
| 201 |
|
| 202 |
if rank == 0:
|
| 203 |
progress_bar.close()
|
| 204 |
-
print(f"\nProcessed {processed_count} files on rank {rank}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
if args.device == "cuda":
|
| 207 |
dist.barrier()
|
|
|
|
| 15 |
cpu:
|
| 16 |
|
| 17 |
s3tokenizer --root_path /path/to/audio/files \
|
| 18 |
+
--model speech_tokenizer_v1 \
|
| 19 |
--device "cpu" \
|
| 20 |
--batch_size 32
|
| 21 |
|
| 22 |
gpu:
|
| 23 |
|
| 24 |
+
torchrun --nproc_per_node=8 --nnodes=1 \
|
| 25 |
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
| 26 |
+
`which s3tokenizer` --root_path /path/to/audio/files \
|
| 27 |
+
--model speech_tokenizer_v1 \
|
| 28 |
--device "cuda" \
|
| 29 |
+
--batch_size 32
|
| 30 |
|
| 31 |
"""
|
| 32 |
|
|
|
|
| 44 |
|
| 45 |
class AudioDataset(Dataset):
|
| 46 |
|
| 47 |
+
def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3'],
|
| 48 |
+
use_cache=True, cache_file=None, max_workers=8):
|
| 49 |
self.data = []
|
| 50 |
|
| 51 |
+
# Define cache file path
|
| 52 |
+
if cache_file is None:
|
| 53 |
+
cache_file = Path(root_path) / '.audio_file_cache.pkl'
|
| 54 |
+
else:
|
| 55 |
+
cache_file = Path(cache_file)
|
| 56 |
+
|
| 57 |
+
# Try to load from cache first
|
| 58 |
+
if use_cache and cache_file.exists():
|
| 59 |
+
import pickle
|
| 60 |
+
print(f"Loading file list from cache: {cache_file}")
|
| 61 |
+
try:
|
| 62 |
+
with open(cache_file, 'rb') as f:
|
| 63 |
+
self.data = pickle.load(f)
|
| 64 |
+
print(f"Loaded {len(self.data)} files from cache")
|
| 65 |
+
return
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"Failed to load cache: {e}, scanning directory...")
|
| 68 |
+
|
| 69 |
+
# Method 1: Use os.walk() which is typically faster than pathlib
|
| 70 |
+
print(f"Scanning directory: {root_path}")
|
| 71 |
+
print(f"Looking for extensions: {extensions}")
|
| 72 |
+
|
| 73 |
+
import os
|
| 74 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 75 |
+
|
| 76 |
+
def scan_directory(args):
|
| 77 |
+
dirpath, extensions = args
|
| 78 |
+
files = []
|
| 79 |
+
try:
|
| 80 |
+
with os.scandir(dirpath) as entries:
|
| 81 |
+
for entry in entries:
|
| 82 |
+
if entry.is_file() and any(entry.name.endswith(ext) for ext in extensions):
|
| 83 |
+
files.append(Path(entry.path))
|
| 84 |
+
except PermissionError:
|
| 85 |
+
pass
|
| 86 |
+
return files
|
| 87 |
+
|
| 88 |
+
# Collect all directories first
|
| 89 |
+
all_dirs = [root_path]
|
| 90 |
+
for dirpath, dirnames, _ in os.walk(root_path):
|
| 91 |
+
all_dirs.extend(os.path.join(dirpath, d) for d in dirnames)
|
| 92 |
+
|
| 93 |
+
# Process directories in parallel
|
| 94 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 95 |
+
futures = [executor.submit(scan_directory, (d, extensions)) for d in all_dirs]
|
| 96 |
+
|
| 97 |
+
with tqdm(total=len(all_dirs), desc="Scanning directories") as pbar:
|
| 98 |
+
for future in as_completed(futures):
|
| 99 |
+
self.data.extend(future.result())
|
| 100 |
+
pbar.update(1)
|
| 101 |
|
| 102 |
# Sort for consistent ordering
|
| 103 |
self.data.sort()
|
|
|
|
| 106 |
raise ValueError(f"No audio files found in {root_path}")
|
| 107 |
|
| 108 |
print(f"Found {len(self.data)} audio files")
|
| 109 |
+
|
| 110 |
+
# Save to cache
|
| 111 |
+
if use_cache:
|
| 112 |
+
try:
|
| 113 |
+
import pickle
|
| 114 |
+
print(f"Saving file list to cache: {cache_file}")
|
| 115 |
+
cache_file.parent.mkdir(exist_ok=True)
|
| 116 |
+
with open(cache_file, 'wb') as f:
|
| 117 |
+
pickle.dump(self.data, f)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"Failed to save cache: {e}")
|
| 120 |
|
| 121 |
def __len__(self):
|
| 122 |
return len(self.data)
|
| 123 |
|
| 124 |
def __getitem__(self, idx):
|
| 125 |
file_path = self.data[idx]
|
| 126 |
+
try:
|
| 127 |
+
audio = s3tokenizer.load_audio(str(file_path))
|
| 128 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 129 |
+
return file_path, mel
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"Error processing {file_path}: {e}")
|
| 132 |
+
return None, None
|
| 133 |
|
| 134 |
|
| 135 |
def collate_fn(batch):
|
| 136 |
+
# Filter out None entries (failed files)
|
| 137 |
+
batch = [item for item in batch if item[0] is not None]
|
| 138 |
+
|
| 139 |
+
if len(batch) == 0:
|
| 140 |
+
return [], None, None
|
| 141 |
+
|
| 142 |
file_paths = [item[0] for item in batch]
|
| 143 |
mels = [item[1] for item in batch]
|
| 144 |
mels, mels_lens = s3tokenizer.padding(mels)
|
|
|
|
| 191 |
nargs='+',
|
| 192 |
default=['.wav', '.flac', '.mp3'],
|
| 193 |
help='audio file extensions to process')
|
| 194 |
+
parser.add_argument('--use_cache',
|
| 195 |
+
action='store_true',
|
| 196 |
+
help='use cached file list to avoid re-scanning')
|
| 197 |
+
parser.add_argument('--no_cache',
|
| 198 |
+
action='store_true',
|
| 199 |
+
help='force re-scan even if cache exists')
|
| 200 |
+
parser.add_argument('--cache_file',
|
| 201 |
+
type=str,
|
| 202 |
+
default=None,
|
| 203 |
+
help='path to cache file (default: root_path/.audio_file_cache.pkl)')
|
| 204 |
+
parser.add_argument('--scan_workers',
|
| 205 |
+
type=int,
|
| 206 |
+
default=8,
|
| 207 |
+
help='number of workers for directory scanning')
|
| 208 |
+
parser.add_argument('--file_list',
|
| 209 |
+
type=str,
|
| 210 |
+
default=None,
|
| 211 |
+
help='path to pre-generated file list (one file per line)')
|
| 212 |
+
parser.add_argument('--skip_existing',
|
| 213 |
+
action='store_true',
|
| 214 |
+
help='skip files that already have _fsq.pt output')
|
| 215 |
args = parser.parse_args()
|
| 216 |
return args
|
| 217 |
|
|
|
|
| 224 |
|
| 225 |
# Extract only valid codes (up to codes_len)
|
| 226 |
valid_codes = codes[:codes_len]
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# Save as tensor
|
| 229 |
torch.save(valid_codes, output_path)
|
|
|
|
| 242 |
|
| 243 |
device = torch.device(args.device)
|
| 244 |
model = s3tokenizer.load_model(args.model).to(device)
|
| 245 |
+
|
| 246 |
+
# Handle different data loading methods
|
| 247 |
+
if args.file_list:
|
| 248 |
+
# Option 3: Load from pre-generated file list
|
| 249 |
+
print(f"Loading file list from: {args.file_list}")
|
| 250 |
+
with open(args.file_list, 'r') as f:
|
| 251 |
+
file_paths = [Path(line.strip()) for line in f if line.strip()]
|
| 252 |
+
|
| 253 |
+
# Filter by extensions if specified
|
| 254 |
+
if args.extensions:
|
| 255 |
+
file_paths = [f for f in file_paths if any(str(f).endswith(ext) for ext in args.extensions)]
|
| 256 |
+
|
| 257 |
+
# Create a simple dataset
|
| 258 |
+
class FileListDataset(Dataset):
|
| 259 |
+
def __init__(self, file_paths, skip_existing=False):
|
| 260 |
+
self.data = []
|
| 261 |
+
skipped_existing = 0
|
| 262 |
+
for fp in file_paths:
|
| 263 |
+
if skip_existing:
|
| 264 |
+
output_path = fp.with_suffix('').with_suffix('.pt')
|
| 265 |
+
output_path = output_path.parent / f"{output_path.stem}_fsq.pt"
|
| 266 |
+
if output_path.exists():
|
| 267 |
+
skipped_existing += 1
|
| 268 |
+
continue
|
| 269 |
+
self.data.append(fp)
|
| 270 |
+
print(f"Will process {len(self.data)} files")
|
| 271 |
+
if skip_existing and skipped_existing > 0:
|
| 272 |
+
print(f"Skipped {skipped_existing} already processed files")
|
| 273 |
+
|
| 274 |
+
def __len__(self):
|
| 275 |
+
return len(self.data)
|
| 276 |
+
|
| 277 |
+
def __getitem__(self, idx):
|
| 278 |
+
file_path = self.data[idx]
|
| 279 |
+
try:
|
| 280 |
+
# Check if file exists
|
| 281 |
+
if not file_path.exists():
|
| 282 |
+
print(f"File not found: {file_path}")
|
| 283 |
+
return None, None
|
| 284 |
+
|
| 285 |
+
# Check if it's a file (not directory)
|
| 286 |
+
if not file_path.is_file():
|
| 287 |
+
print(f"Not a file: {file_path}")
|
| 288 |
+
return None, None
|
| 289 |
+
|
| 290 |
+
# Try to load audio
|
| 291 |
+
audio = s3tokenizer.load_audio(str(file_path))
|
| 292 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 293 |
+
return file_path, mel
|
| 294 |
+
except Exception as e:
|
| 295 |
+
print(f"Error processing {file_path}: {e}")
|
| 296 |
+
return None, None
|
| 297 |
+
|
| 298 |
+
dataset = FileListDataset(file_paths, skip_existing=args.skip_existing)
|
| 299 |
+
else:
|
| 300 |
+
# Use the enhanced AudioDataset with caching
|
| 301 |
+
dataset = AudioDataset(
|
| 302 |
+
args.root_path,
|
| 303 |
+
args.extensions,
|
| 304 |
+
use_cache=not args.no_cache,
|
| 305 |
+
cache_file=args.cache_file,
|
| 306 |
+
max_workers=args.scan_workers
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Filter out existing files if requested
|
| 310 |
+
if args.skip_existing:
|
| 311 |
+
original_count = len(dataset.data)
|
| 312 |
+
dataset.data = [
|
| 313 |
+
fp for fp in dataset.data
|
| 314 |
+
if not (fp.parent / f"{fp.stem}_fsq.pt").exists()
|
| 315 |
+
]
|
| 316 |
+
print(f"Skipping {original_count - len(dataset.data)} already processed files")
|
| 317 |
|
| 318 |
if args.device == "cuda":
|
| 319 |
model = torch.nn.parallel.DistributedDataParallel(
|
|
|
|
| 338 |
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
| 339 |
|
| 340 |
processed_count = 0
|
| 341 |
+
failed_count = 0
|
| 342 |
+
failed_files = []
|
| 343 |
+
|
| 344 |
for file_paths, mels, mels_lens in dataloader:
|
| 345 |
+
# Skip empty batches (all files failed)
|
| 346 |
+
if len(file_paths) == 0:
|
| 347 |
+
continue
|
| 348 |
+
|
| 349 |
codes, codes_lens = model(mels.to(device), mels_lens.to(device))
|
| 350 |
|
| 351 |
# Process each file in the batch
|
| 352 |
for i, file_path in enumerate(file_paths):
|
| 353 |
+
try:
|
| 354 |
+
code = codes[i]
|
| 355 |
+
code_len = codes_lens[i].item()
|
| 356 |
+
|
| 357 |
+
# Save tokens as .pt file
|
| 358 |
+
output_path = save_tokens(file_path, code, code_len)
|
| 359 |
+
|
| 360 |
+
if rank == 0 and processed_count < 10: # Only show first 10 to avoid spam
|
| 361 |
+
tqdm.write(f"Saved: {file_path} -> {output_path}")
|
| 362 |
+
|
| 363 |
+
processed_count += 1
|
| 364 |
+
except Exception as e:
|
| 365 |
+
failed_count += 1
|
| 366 |
+
failed_files.append(str(file_path))
|
| 367 |
+
if rank == 0:
|
| 368 |
+
tqdm.write(f"Failed to save {file_path}: {e}")
|
| 369 |
|
| 370 |
if rank == 0:
|
| 371 |
+
progress_bar.update(world_size * (len(file_paths) + failed_count))
|
| 372 |
|
| 373 |
if rank == 0:
|
| 374 |
progress_bar.close()
|
| 375 |
+
print(f"\nProcessed {processed_count} files successfully on rank {rank}")
|
| 376 |
+
if failed_count > 0:
|
| 377 |
+
print(f"Failed to process {failed_count} files")
|
| 378 |
+
|
| 379 |
+
# Save failed files list
|
| 380 |
+
failed_list_path = Path(args.root_path if not args.file_list else ".") / "failed_files.txt"
|
| 381 |
+
with open(failed_list_path, 'w') as f:
|
| 382 |
+
for ff in failed_files:
|
| 383 |
+
f.write(f"{ff}\n")
|
| 384 |
+
print(f"Failed files saved to: {failed_list_path}")
|
| 385 |
|
| 386 |
if args.device == "cuda":
|
| 387 |
dist.barrier()
|