Reza2kn committed on
Commit
fba9ebe
·
verified ·
1 Parent(s): 2b6e54a

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +216 -70
.gitignore CHANGED
@@ -26,6 +26,7 @@ build/
26
  *.wav
27
  *.mp3
28
  *.flac
 
29
  *.ogg
30
 
31
  # macOS
 
26
  *.wav
27
  *.mp3
28
  *.flac
29
+ chizzler_cache/
30
  *.ogg
31
 
32
  # macOS
app.py CHANGED
@@ -1,5 +1,8 @@
 
1
  import json
 
2
  import os
 
3
  import subprocess
4
  import sys
5
  import tempfile
@@ -13,7 +16,16 @@ import numpy as np
13
  import soundfile as sf
14
  import torch
15
  import torchaudio
16
- from datasets import Audio, Dataset, DatasetDict, load_dataset
 
 
 
 
 
 
 
 
 
17
  from dotenv import load_dotenv
18
  from huggingface_hub import HfApi, hf_hub_download
19
  from rich.console import Console
@@ -79,6 +91,7 @@ DEFAULT_MP_SENET_DIR = Path(os.getenv("MPSENET_DIR", CURRENT_DIR / "MP-SENet"))
79
  MPSENET_GIT_REPO = os.getenv(
80
  "MPSENET_GIT_REPO", "https://github.com/yxlu-0102/MP-SENet.git"
81
  )
 
82
 
83
 
84
  def ensure_mpsenet_repo() -> Path:
@@ -410,37 +423,75 @@ def process_audio_file(
410
  return audio_path, vad_path, denoised_path, details
411
 
412
 
413
- def prepare_waveform_from_audio(audio_dict: dict) -> Tuple[torch.Tensor, int]:
414
- if not audio_dict:
415
- raise ValueError("Empty audio entry.")
416
- array = audio_dict.get("array")
417
- sample_rate = audio_dict.get("sampling_rate", DEFAULT_SAMPLE_RATE)
418
- waveform = torch.tensor(array, dtype=torch.float32)
419
  waveform = ensure_mono(waveform)
420
  if sample_rate != DEFAULT_SAMPLE_RATE:
 
 
 
 
 
421
  waveform, sample_rate = resample_waveform(
422
  waveform, sample_rate, DEFAULT_SAMPLE_RATE
423
  )
424
  return waveform, sample_rate
425
 
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  def infer_audio_column(dataset_obj) -> Optional[str]:
428
  sample_ds = dataset_obj
429
- if isinstance(dataset_obj, DatasetDict):
430
  sample_ds = next(iter(dataset_obj.values()))
431
- if isinstance(sample_ds, Dataset):
432
  for column, feature in sample_ds.features.items():
433
  if isinstance(feature, Audio):
434
  return column
435
- if len(sample_ds) > 0:
436
- sample = sample_ds[0]
437
- for column, value in sample.items():
438
- if isinstance(value, dict) and (
439
- "array" in value or "path" in value or "bytes" in value
440
- ):
441
- return column
442
- if isinstance(value, str) and value.lower().endswith(AUDIO_EXTENSIONS):
443
- return column
444
  return None
445
 
446
 
@@ -463,6 +514,8 @@ def process_dataset_and_push(
463
  vad_threshold: float,
464
  max_silence_gap: float,
465
  max_examples: Optional[float],
 
 
466
  progress=gr.Progress(),
467
  ) -> str:
468
  token = get_hf_token()
@@ -477,9 +530,11 @@ def process_dataset_and_push(
477
  split = split.strip()
478
  audio_column = audio_column.strip()
479
  output_repo = normalize_dataset_id(output_repo) if output_repo else ""
 
 
480
 
481
  log_progress(f"Loading dataset: {dataset_id}")
482
- progress(0, desc="Loading dataset...")
483
  if split and split.lower() != "all":
484
  dataset_obj = load_dataset(
485
  dataset_id, name=config, split=split, token=token
@@ -492,6 +547,7 @@ def process_dataset_and_push(
492
  if isinstance(dataset_obj, Dataset)
493
  else dataset_obj
494
  )
 
495
 
496
  if not audio_column:
497
  audio_column = infer_audio_column(dataset_dict) or ""
@@ -502,68 +558,150 @@ def process_dataset_and_push(
502
  )
503
 
504
  processed_splits = {}
 
 
 
505
  for split_name, split_ds in dataset_dict.items():
506
- if audio_column not in split_ds.column_names:
 
 
 
507
  return f"Audio column '{audio_column}' not found in split '{split_name}'."
508
 
509
- split_ds = split_ds.cast_column(
510
- audio_column, Audio(sampling_rate=DEFAULT_SAMPLE_RATE)
511
- )
 
 
 
512
 
513
- if max_examples and max_examples > 0:
514
- limit = min(int(max_examples), len(split_ds))
515
- split_ds = split_ds.select(range(limit))
516
 
517
- total = len(split_ds)
518
- update_every = max(1, total // 100) if total else 1
 
 
 
519
 
520
- def map_fn(example, idx):
521
- try:
522
- waveform, sample_rate = prepare_waveform_from_audio(
523
- example[audio_column]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  )
525
- except Exception:
526
- return {audio_column: example[audio_column]}
527
-
528
- vad_waveform, denoised_waveform, _, has_speech = process_waveform(
529
- waveform,
530
- sample_rate,
531
- threshold=vad_threshold,
532
- max_gap=max_silence_gap,
533
- log=False,
534
- )
535
 
536
- output_waveform = (
537
- denoised_waveform
538
- if has_speech and denoised_waveform is not None
539
- else waveform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  )
541
- output_np = (
542
- output_waveform.squeeze()
543
- .detach()
544
- .cpu()
545
- .numpy()
546
- .astype(np.float32)
 
 
 
547
  )
548
 
549
- if total and (idx % update_every == 0 or idx == total - 1):
550
- progress(
551
- (idx + 1) / total,
552
- desc=f"Processing {split_name}: {idx + 1}/{total}",
553
- )
554
- return {
555
- audio_column: {
556
- "array": output_np,
557
- "sampling_rate": DEFAULT_SAMPLE_RATE,
558
- }
559
- }
560
-
561
- processed_split = split_ds.map(
562
- map_fn,
563
- with_indices=True,
564
- desc=f"Chizzling {split_name}",
565
- num_proc=1,
566
- )
567
  processed_splits[split_name] = processed_split
568
 
569
  processed_dataset = (
@@ -649,6 +787,12 @@ with gr.Blocks(title="Representation Chizzler") as demo:
649
  max_examples_input = gr.Number(
650
  label="Max examples per split (optional)", value=None
651
  )
 
 
 
 
 
 
652
  vad_slider_ds = gr.Slider(
653
  minimum=0.1,
654
  maximum=0.9,
@@ -678,6 +822,8 @@ with gr.Blocks(title="Representation Chizzler") as demo:
678
  vad_slider_ds,
679
  gap_slider_ds,
680
  max_examples_input,
 
 
681
  ],
682
  outputs=[status_box],
683
  concurrency_limit=1,
 
1
+ import io
2
  import json
3
+ import math
4
  import os
5
+ import shutil
6
  import subprocess
7
  import sys
8
  import tempfile
 
16
  import soundfile as sf
17
  import torch
18
  import torchaudio
19
+ from datasets import (
20
+ Audio,
21
+ Dataset,
22
+ DatasetDict,
23
+ IterableDataset,
24
+ IterableDatasetDict,
25
+ Value,
26
+ concatenate_datasets,
27
+ load_dataset,
28
+ )
29
  from dotenv import load_dotenv
30
  from huggingface_hub import HfApi, hf_hub_download
31
  from rich.console import Console
 
91
  MPSENET_GIT_REPO = os.getenv(
92
  "MPSENET_GIT_REPO", "https://github.com/yxlu-0102/MP-SENet.git"
93
  )
94
+ CACHE_DIR = Path(os.getenv("CHIZZLER_CACHE_DIR", CURRENT_DIR / "chizzler_cache"))
95
 
96
 
97
  def ensure_mpsenet_repo() -> Path:
 
423
  return audio_path, vad_path, denoised_path, details
424
 
425
 
426
def load_audio_bytes(audio_bytes: bytes, log: bool = False) -> Tuple[torch.Tensor, int]:
    """Decode an in-memory audio payload into a waveform tensor.

    The bytes are read with ``soundfile``, passed through ``ensure_mono`` and,
    when the decoded rate differs from ``DEFAULT_SAMPLE_RATE``, resampled via
    ``resample_waveform``.

    Args:
        audio_bytes: Encoded audio file contents.
        log: Forwarded to ``log_progress`` as ``enabled`` for the resampling
            message.

    Returns:
        A ``(waveform, sample_rate)`` pair.
    """
    buffer = io.BytesIO(audio_bytes)
    samples, sample_rate = sf.read(buffer, always_2d=True, dtype="float32")

    # ``always_2d`` yields (frames, channels); transpose to channels-first.
    channels_first = torch.from_numpy(samples.T)
    waveform = ensure_mono(channels_first)

    if sample_rate == DEFAULT_SAMPLE_RATE:
        return waveform, sample_rate

    log_progress(
        f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
        2,
        enabled=log,
    )
    waveform, sample_rate = resample_waveform(
        waveform, sample_rate, DEFAULT_SAMPLE_RATE
    )
    return waveform, sample_rate
442
 
443
 
444
def prepare_waveform_from_entry(entry, log: bool = False) -> Tuple[torch.Tensor, int]:
    """Turn one dataset audio cell into a ``(waveform, sample_rate)`` pair.

    Supported inputs:
      * a dict with a decoded ``"array"`` (plus optional ``"sampling_rate"``),
      * a dict with ``"path"`` and/or ``"bytes"`` (undecoded audio),
      * a plain string, treated as a file path.

    Args:
        entry: The raw value stored in the audio column.
        log: Forwarded to the loaders to enable progress messages.

    Returns:
        The waveform tensor and its sample rate.

    Raises:
        ValueError: If ``entry`` is ``None`` or has no usable audio source.
    """
    if entry is None:
        raise ValueError("Empty audio entry.")

    if isinstance(entry, dict):
        array = entry.get("array")
        if array is not None:
            sample_rate = entry.get("sampling_rate", DEFAULT_SAMPLE_RATE)
            waveform = ensure_mono(torch.tensor(array, dtype=torch.float32))
            if sample_rate != DEFAULT_SAMPLE_RATE:
                waveform, sample_rate = resample_waveform(
                    waveform, sample_rate, DEFAULT_SAMPLE_RATE
                )
            return waveform, sample_rate

        path = entry.get("path")
        audio_bytes = entry.get("bytes")

        # Undecoded entries can carry a bare file name in "path" next to the
        # embedded "bytes"; only trust the path when it is a real local file,
        # otherwise the embedded payload is the reliable source.
        if path and os.path.isfile(path):
            return load_audio_file(path, log=log)
        if audio_bytes:
            return load_audio_bytes(audio_bytes, log=log)
        if path:
            # No bytes to fall back on: let the file loader report the problem.
            return load_audio_file(path, log=log)

    if isinstance(entry, str):
        return load_audio_file(entry, log=log)

    raise ValueError("Unsupported audio entry format.")
469
+
470
+
471
def get_dataset_cache_dir(dataset_id: str, config: Optional[str]) -> Path:
    """Return the cache directory under ``CACHE_DIR`` for a dataset/config pair.

    Slashes in ``dataset_id`` become ``__`` so the id maps to a single
    directory name; a non-empty ``config`` is appended after another ``__``.
    """
    name_parts = [dataset_id.replace("/", "__")]
    if config:
        name_parts.append(f"{config}")
    return CACHE_DIR / "__".join(name_parts)
476
+
477
+
478
def infer_audio_column(dataset_obj) -> Optional[str]:
    """Guess which column of a dataset holds audio.

    For a dict of splits, the first split is inspected. A column whose declared
    feature is ``Audio`` wins; otherwise, for a non-empty in-memory ``Dataset``,
    the first row is examined for a dict with ``array``/``path``/``bytes`` keys
    or a string ending in one of ``AUDIO_EXTENSIONS``.

    Args:
        dataset_obj: A dataset or a dict of dataset splits.

    Returns:
        The column name, or ``None`` when nothing looks like audio.
    """
    sample_ds = dataset_obj
    if isinstance(dataset_obj, (DatasetDict, IterableDatasetDict)):
        # An empty dict of splits has nothing to inspect.
        sample_ds = next(iter(dataset_obj.values()), None)
        if sample_ds is None:
            return None

    # Streaming datasets may report ``features`` as None when the schema is
    # unknown, so guard before iterating.
    features = getattr(sample_ds, "features", None)
    if features:
        for column, feature in features.items():
            if isinstance(feature, Audio):
                return column

    # Row-level sniffing needs random access, so only do it for in-memory sets.
    if isinstance(sample_ds, Dataset) and len(sample_ds) > 0:
        for column, value in sample_ds[0].items():
            if isinstance(value, dict) and (
                "array" in value or "path" in value or "bytes" in value
            ):
                return column
            if isinstance(value, str) and value.lower().endswith(AUDIO_EXTENSIONS):
                return column
    return None
496
 
497
 
 
514
  vad_threshold: float,
515
  max_silence_gap: float,
516
  max_examples: Optional[float],
517
+ resume_processing: bool,
518
+ shard_size: Optional[float],
519
  progress=gr.Progress(),
520
  ) -> str:
521
  token = get_hf_token()
 
530
  split = split.strip()
531
  audio_column = audio_column.strip()
532
  output_repo = normalize_dataset_id(output_repo) if output_repo else ""
533
+ max_examples_int = int(max_examples) if max_examples and max_examples > 0 else None
534
+ shard_size_int = int(shard_size) if shard_size and shard_size > 0 else 1000
535
 
536
  log_progress(f"Loading dataset: {dataset_id}")
537
+ progress(0, desc="Downloading dataset...")
538
  if split and split.lower() != "all":
539
  dataset_obj = load_dataset(
540
  dataset_id, name=config, split=split, token=token
 
547
  if isinstance(dataset_obj, Dataset)
548
  else dataset_obj
549
  )
550
+ progress(0.01, desc="Preparing splits...")
551
 
552
  if not audio_column:
553
  audio_column = infer_audio_column(dataset_dict) or ""
 
558
  )
559
 
560
  processed_splits = {}
561
+ cache_root = get_dataset_cache_dir(dataset_id, config)
562
+ cache_root.mkdir(parents=True, exist_ok=True)
563
+
564
  for split_name, split_ds in dataset_dict.items():
565
+ if (
566
+ hasattr(split_ds, "column_names")
567
+ and audio_column not in split_ds.column_names
568
+ ):
569
  return f"Audio column '{audio_column}' not found in split '{split_name}'."
570
 
571
+ try:
572
+ split_ds = split_ds.cast_column(audio_column, Audio(decode=False))
573
+ except Exception:
574
+ split_ds = split_ds.cast_column(
575
+ audio_column, Audio(sampling_rate=DEFAULT_SAMPLE_RATE, decode=False)
576
+ )
577
 
578
+ total = len(split_ds) if isinstance(split_ds, Dataset) else None
579
+ if max_examples_int and total is not None:
580
+ total = min(total, max_examples_int)
581
 
582
+ update_every = max(1, (total or max_examples_int or 100) // 100)
583
+ split_cache_dir = cache_root / split_name
584
+ if not resume_processing and split_cache_dir.exists():
585
+ shutil.rmtree(split_cache_dir)
586
+ split_cache_dir.mkdir(parents=True, exist_ok=True)
587
 
588
+ features = split_ds.features.copy()
589
+ features[audio_column] = Audio(
590
+ sampling_rate=DEFAULT_SAMPLE_RATE, decode=False
591
+ )
592
+ features["chizzler_ok"] = Value("bool")
593
+ features["chizzler_error"] = Value("string")
594
+
595
+ def make_map_fn(offset: int = 0):
596
+ def map_fn(example, idx):
597
+ entry = example.get(audio_column)
598
+ ok = True
599
+ error_message = ""
600
+ try:
601
+ waveform, sample_rate = prepare_waveform_from_entry(
602
+ entry, log=False
603
+ )
604
+ vad_waveform, denoised_waveform, _, has_speech = process_waveform(
605
+ waveform,
606
+ sample_rate,
607
+ threshold=vad_threshold,
608
+ max_gap=max_silence_gap,
609
+ log=False,
610
+ )
611
+
612
+ output_waveform = (
613
+ denoised_waveform
614
+ if has_speech and denoised_waveform is not None
615
+ else waveform
616
+ )
617
+ output_np = (
618
+ output_waveform.squeeze()
619
+ .detach()
620
+ .cpu()
621
+ .numpy()
622
+ .astype(np.float32)
623
+ )
624
+ except Exception as exc:
625
+ ok = False
626
+ error_message = str(exc)
627
+ output_np = np.zeros(1, dtype=np.float32)
628
+
629
+ example[audio_column] = {
630
+ "array": output_np,
631
+ "sampling_rate": DEFAULT_SAMPLE_RATE,
632
+ }
633
+ example["chizzler_ok"] = ok
634
+ example["chizzler_error"] = error_message
635
+
636
+ global_idx = offset + idx + 1
637
+ if total:
638
+ if global_idx % update_every == 0 or global_idx == total:
639
+ progress(
640
+ global_idx / total,
641
+ desc=(
642
+ f"Processing {split_name}: {global_idx}/{total}"
643
+ ),
644
+ )
645
+ else:
646
+ if global_idx % update_every == 0:
647
+ progress(
648
+ 0,
649
+ desc=f"Processing {split_name}: {global_idx} examples",
650
+ )
651
+ return example
652
+
653
+ return map_fn
654
+
655
+ if total:
656
+ num_shards = math.ceil(total / shard_size_int)
657
+ shards = []
658
+ for shard_idx in range(num_shards):
659
+ start = shard_idx * shard_size_int
660
+ end = min(total, start + shard_size_int)
661
+ cache_file = split_cache_dir / (
662
+ f"{split_name}-{start:07d}-{end:07d}.arrow"
663
  )
 
 
 
 
 
 
 
 
 
 
664
 
665
+ if resume_processing and cache_file.exists():
666
+ processed_shard = Dataset.from_file(str(cache_file))
667
+ progress(
668
+ end / total,
669
+ desc=f"Processing {split_name}: {end}/{total}",
670
+ )
671
+ else:
672
+ shard_ds = split_ds.select(list(range(start, end)))
673
+ processed_shard = shard_ds.map(
674
+ make_map_fn(offset=start),
675
+ with_indices=True,
676
+ load_from_cache_file=False,
677
+ cache_file_name=str(cache_file),
678
+ writer_batch_size=50,
679
+ num_proc=None,
680
+ features=features,
681
+ desc=(
682
+ f"Chizzling {split_name} "
683
+ f"({shard_idx + 1}/{num_shards})"
684
+ ),
685
+ )
686
+
687
+ shards.append(processed_shard)
688
+
689
+ processed_split = (
690
+ concatenate_datasets(shards)
691
+ if len(shards) > 1
692
+ else shards[0]
693
  )
694
+ else:
695
+ processed_split = split_ds.map(
696
+ make_map_fn(offset=0),
697
+ with_indices=True,
698
+ load_from_cache_file=False,
699
+ writer_batch_size=50,
700
+ num_proc=None,
701
+ features=features,
702
+ desc=f"Chizzling {split_name}",
703
  )
704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  processed_splits[split_name] = processed_split
706
 
707
  processed_dataset = (
 
787
  max_examples_input = gr.Number(
788
  label="Max examples per split (optional)", value=None
789
  )
790
+ resume_checkbox = gr.Checkbox(
791
+ label="Resume from cached shards", value=True
792
+ )
793
+ shard_size_input = gr.Number(
794
+ label="Shard size (examples)", value=1000
795
+ )
796
  vad_slider_ds = gr.Slider(
797
  minimum=0.1,
798
  maximum=0.9,
 
822
  vad_slider_ds,
823
  gap_slider_ds,
824
  max_examples_input,
825
+ resume_checkbox,
826
+ shard_size_input,
827
  ],
828
  outputs=[status_box],
829
  concurrency_limit=1,