primepake committed
Commit ea8cd35 · Parent(s): 34bf06f

Edit cli s3

speech/tools/S3Tokenizer/s3tokenizer/cli.py CHANGED
@@ -15,18 +15,18 @@
 cpu:
 
     s3tokenizer --root_path /path/to/audio/files \
-        --model speech_tokenizer_v2_25hz \
+        --model speech_tokenizer_v1 \
         --device "cpu" \
         --batch_size 32
 
 gpu:
 
-    torchrun --nproc_per_node=1 --nnodes=1 \
+    torchrun --nproc_per_node=8 --nnodes=1 \
         --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
-        `which s3tokenizer` --root_path /data/dataset \
-        --model speech_tokenizer_v2_25hz \
+        `which s3tokenizer` --root_path /path/to/audio/files \
+        --model speech_tokenizer_v1 \
         --device "cuda" \
-        --batch_size 64
+        --batch_size 32
 
 """
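For context on the gpu invocation above: torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK into the environment of every worker it spawns, which is how a script launched this way can decide which shard of the data each process handles. A minimal sketch (variable names are illustrative, not taken from cli.py):

    import os

    rank = int(os.environ.get("RANK", 0))              # global worker index, set by torchrun
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total workers across all nodes
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node

    print(f"worker {rank}/{world_size} on local GPU {local_rank}")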
@@ -44,13 +44,60 @@ import s3tokenizer
 
 class AudioDataset(Dataset):
 
-    def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3']):
+    def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3'],
+                 use_cache=True, cache_file=None, max_workers=8):
         self.data = []
 
-        # Recursively find all audio files
-        root = Path(root_path)
-        for ext in extensions:
-            self.data.extend(root.rglob(f'*{ext}'))
+        # Define cache file path
+        if cache_file is None:
+            cache_file = Path(root_path) / '.audio_file_cache.pkl'
+        else:
+            cache_file = Path(cache_file)
+
+        # Try to load from cache first
+        if use_cache and cache_file.exists():
+            import pickle
+            print(f"Loading file list from cache: {cache_file}")
+            try:
+                with open(cache_file, 'rb') as f:
+                    self.data = pickle.load(f)
+                print(f"Loaded {len(self.data)} files from cache")
+                return
+            except Exception as e:
+                print(f"Failed to load cache: {e}, scanning directory...")
+
+        # Fall back to scanning with os.scandir(), which is typically
+        # faster than pathlib's rglob on large trees
+        print(f"Scanning directory: {root_path}")
+        print(f"Looking for extensions: {extensions}")
+
+        import os
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+
+        def scan_directory(args):
+            dirpath, extensions = args
+            files = []
+            try:
+                with os.scandir(dirpath) as entries:
+                    for entry in entries:
+                        if entry.is_file() and any(entry.name.endswith(ext) for ext in extensions):
+                            files.append(Path(entry.path))
+            except PermissionError:
+                pass
+            return files
+
+        # Collect all directories first
+        all_dirs = [root_path]
+        for dirpath, dirnames, _ in os.walk(root_path):
+            all_dirs.extend(os.path.join(dirpath, d) for d in dirnames)
+
+        # Scan directories in parallel
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [executor.submit(scan_directory, (d, extensions)) for d in all_dirs]
+
+            with tqdm(total=len(all_dirs), desc="Scanning directories") as pbar:
+                for future in as_completed(futures):
+                    self.data.extend(future.result())
+                    pbar.update(1)
 
         # Sort for consistent ordering
         self.data.sort()
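The scan result is memoized in a pickle next to the audio root, so repeated runs over a large tree skip the walk entirely. A standalone round-trip of that cache, assuming the default cache location (the root path below is a placeholder):

    import pickle
    from pathlib import Path

    root = Path("/path/to/audio/files")        # placeholder; same root as --root_path
    cache = root / ".audio_file_cache.pkl"     # default location from __init__ above

    if cache.exists():
        with open(cache, "rb") as f:
            files = pickle.load(f)             # list of Path objects, as pickled above
        print(f"{len(files)} cached paths; first: {files[0] if files else 'n/a'}")

The cache goes stale when files are added or removed after the first scan; the --no_cache flag introduced further down forces a re-scan in that case.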
@@ -59,18 +106,39 @@ class AudioDataset(Dataset):
             raise ValueError(f"No audio files found in {root_path}")
 
         print(f"Found {len(self.data)} audio files")
+
+        # Save to cache
+        if use_cache:
+            try:
+                import pickle
+                print(f"Saving file list to cache: {cache_file}")
+                cache_file.parent.mkdir(exist_ok=True)
+                with open(cache_file, 'wb') as f:
+                    pickle.dump(self.data, f)
+            except Exception as e:
+                print(f"Failed to save cache: {e}")
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
         file_path = self.data[idx]
-        audio = s3tokenizer.load_audio(str(file_path))
-        mel = s3tokenizer.log_mel_spectrogram(audio)
-        return file_path, mel
+        try:
+            audio = s3tokenizer.load_audio(str(file_path))
+            mel = s3tokenizer.log_mel_spectrogram(audio)
+            return file_path, mel
+        except Exception as e:
+            print(f"Error processing {file_path}: {e}")
+            return None, None
 
 
 def collate_fn(batch):
+    # Filter out None entries (failed files)
+    batch = [item for item in batch if item[0] is not None]
+
+    if len(batch) == 0:
+        return [], None, None
+
     file_paths = [item[0] for item in batch]
     mels = [item[1] for item in batch]
     mels, mels_lens = s3tokenizer.padding(mels)
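Because __getitem__ now returns (None, None) instead of raising, a corrupt file costs one dropped sample rather than a crashed DataLoader worker; collate_fn then strips those sentinels before padding. A toy check of just the filtering step (fake mel shapes, no s3tokenizer required):

    import torch

    batch = [
        ("a.wav", torch.randn(50, 128)),   # loaded fine
        (None, None),                      # failed to decode
        ("b.wav", torch.randn(42, 128)),
    ]

    kept = [item for item in batch if item[0] is not None]
    assert [p for p, _ in kept] == ["a.wav", "b.wav"]
    print(f"kept {len(kept)} of {len(batch)} samples")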
@@ -123,6 +191,27 @@ def get_args():
                         nargs='+',
                         default=['.wav', '.flac', '.mp3'],
                         help='audio file extensions to process')
+    parser.add_argument('--use_cache',
+                        action='store_true',
+                        help='use cached file list to avoid re-scanning')
+    parser.add_argument('--no_cache',
+                        action='store_true',
+                        help='force re-scan even if cache exists')
+    parser.add_argument('--cache_file',
+                        type=str,
+                        default=None,
+                        help='path to cache file (default: root_path/.audio_file_cache.pkl)')
+    parser.add_argument('--scan_workers',
+                        type=int,
+                        default=8,
+                        help='number of workers for directory scanning')
+    parser.add_argument('--file_list',
+                        type=str,
+                        default=None,
+                        help='path to pre-generated file list (one file per line)')
+    parser.add_argument('--skip_existing',
+                        action='store_true',
+                        help='skip files that already have _fsq.pt output')
     args = parser.parse_args()
     return args
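The new --file_list flag expects a plain text file with one audio path per line. One way to pre-generate such a list with the standard library (the root path and output name below are placeholders):

    from pathlib import Path

    root = Path("/path/to/audio/files")    # placeholder; match the intended --root_path
    with open("file_list.txt", "w") as f:
        for ext in (".wav", ".flac", ".mp3"):
            for p in sorted(root.rglob(f"*{ext}")):
                f.write(f"{p}\n")

The list can then be passed via --file_list file_list.txt, bypassing both the directory scan and the cache entirely.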
@@ -135,8 +224,6 @@ def save_tokens(file_path, codes, codes_len):
 
     # Extract only valid codes (up to codes_len)
     valid_codes = codes[:codes_len]
-    # convert valid codes to list
-    valid_codes = valid_codes.tolist()
 
     # Save as tensor
     torch.save(valid_codes, output_path)
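With the .tolist() conversion dropped, torch.save now writes the token tensor itself, so downstream readers get a tensor back directly. Reading one output file (the _fsq.pt naming follows the --skip_existing help text; the filename here is a placeholder):

    import torch

    codes = torch.load("utt0001_fsq.pt")   # placeholder filename
    print(codes.shape, codes.dtype)        # 1-D tensor of token ids, truncated to codes_len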
@@ -155,7 +242,78 @@ def main():
 
     device = torch.device(args.device)
     model = s3tokenizer.load_model(args.model).to(device)
-    dataset = AudioDataset(args.root_path, args.extensions)
+
+    # Handle the different data loading methods
+    if args.file_list:
+        # Load paths from a pre-generated file list
+        print(f"Loading file list from: {args.file_list}")
+        with open(args.file_list, 'r') as f:
+            file_paths = [Path(line.strip()) for line in f if line.strip()]
+
+        # Filter by extensions if specified
+        if args.extensions:
+            file_paths = [f for f in file_paths
+                          if any(str(f).endswith(ext) for ext in args.extensions)]
+
+        # Create a simple dataset over the explicit path list
+        class FileListDataset(Dataset):
+
+            def __init__(self, file_paths, skip_existing=False):
+                self.data = []
+                skipped_existing = 0
+                for fp in file_paths:
+                    if skip_existing:
+                        # Token file sits next to the audio as <stem>_fsq.pt
+                        output_path = fp.parent / f"{fp.stem}_fsq.pt"
+                        if output_path.exists():
+                            skipped_existing += 1
+                            continue
+                    self.data.append(fp)
+                print(f"Will process {len(self.data)} files")
+                if skip_existing and skipped_existing > 0:
+                    print(f"Skipped {skipped_existing} already processed files")
+
+            def __len__(self):
+                return len(self.data)
+
+            def __getitem__(self, idx):
+                file_path = self.data[idx]
+                try:
+                    # Check that the path exists and is a regular file
+                    if not file_path.is_file():
+                        print(f"Missing or not a file: {file_path}")
+                        return None, None
+
+                    # Try to load the audio
+                    audio = s3tokenizer.load_audio(str(file_path))
+                    mel = s3tokenizer.log_mel_spectrogram(audio)
+                    return file_path, mel
+                except Exception as e:
+                    print(f"Error processing {file_path}: {e}")
+                    return None, None
+
+        dataset = FileListDataset(file_paths, skip_existing=args.skip_existing)
+    else:
+        # Use the enhanced AudioDataset with caching
+        dataset = AudioDataset(
+            args.root_path,
+            args.extensions,
+            use_cache=not args.no_cache,
+            cache_file=args.cache_file,
+            max_workers=args.scan_workers
+        )
+
+        # Filter out already-processed files if requested
+        if args.skip_existing:
+            original_count = len(dataset.data)
+            dataset.data = [
+                fp for fp in dataset.data
+                if not (fp.parent / f"{fp.stem}_fsq.pt").exists()
+            ]
+            print(f"Skipping {original_count - len(dataset.data)} already processed files")
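The skip-existing check in both branches reduces to a single path derivation: the token file sits next to the source audio with its extension replaced by _fsq.pt. In isolation (paths are placeholders):

    from pathlib import Path

    fp = Path("/data/speech/utt0001.wav")          # placeholder input path
    out = fp.parent / f"{fp.stem}_fsq.pt"
    assert str(out) == "/data/speech/utt0001_fsq.pt"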
@@ -180,28 +338,50 @@ def main():
     progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
 
     processed_count = 0
+    failed_count = 0
+    failed_files = []
+
     for file_paths, mels, mels_lens in dataloader:
+        # Skip empty batches (every file in the batch failed to load)
+        if len(file_paths) == 0:
+            continue
+
         codes, codes_lens = model(mels.to(device), mels_lens.to(device))
 
         # Process each file in the batch
         for i, file_path in enumerate(file_paths):
-            code = codes[i]
-            code_len = codes_lens[i].item()
-
-            # Save tokens as .pt file
-            output_path = save_tokens(file_path, code, code_len)
-
-            if rank == 0:
-                tqdm.write(f"Saved: {file_path} -> {output_path}")
-
-            processed_count += len(file_paths)
+            try:
+                code = codes[i]
+                code_len = codes_lens[i].item()
+
+                # Save tokens as .pt file
+                output_path = save_tokens(file_path, code, code_len)
+
+                if rank == 0 and processed_count < 10:  # only echo the first few to avoid spam
+                    tqdm.write(f"Saved: {file_path} -> {output_path}")
+
+                processed_count += 1
+            except Exception as e:
+                failed_count += 1
+                failed_files.append(str(file_path))
+                if rank == 0:
+                    tqdm.write(f"Failed to save {file_path}: {e}")
 
         if rank == 0:
+            # Failed saves are already counted in len(file_paths), so the
+            # batch size alone is the right increment here
             progress_bar.update(world_size * len(file_paths))
 
     if rank == 0:
         progress_bar.close()
-        print(f"\nProcessed {processed_count} files on rank {rank}")
+        print(f"\nProcessed {processed_count} files successfully on rank {rank}")
+        if failed_count > 0:
+            print(f"Failed to process {failed_count} files")
+
+            # Save the list of failed files for later inspection
+            failed_list_path = Path(args.root_path if not args.file_list else ".") / "failed_files.txt"
+            with open(failed_list_path, 'w') as f:
+                for ff in failed_files:
+                    f.write(f"{ff}\n")
+            print(f"Failed files saved to: {failed_list_path}")
 
     if args.device == "cuda":
         dist.barrier()
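The closing summary reports per-rank counts only, and the progress bar extrapolates by world_size, which is approximate when shards finish unevenly. If exact global totals were wanted, the counts could be reduced over the process group before printing; a sketch, assuming init_process_group has already run (as it must for the DDP setup above):

    import torch
    import torch.distributed as dist

    def global_sum(local_count: int, device: torch.device) -> int:
        # Sum a scalar across all ranks; every rank receives the total back
        t = torch.tensor([local_count], dtype=torch.long, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return int(t.item())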