"""Representation Chizzler: a Gradio app for two-stage speech cleanup.

Stage 1 runs Silero VAD to keep only speech segments; stage 2 denoises the
result with MP-SENet. The app supports ad-hoc single-file processing and
batch-cleaning a Hugging Face dataset with shard-level caching/resume
(locally and optionally on the Hub), tuned for ZeroGPU Spaces preemption.
"""

import io
import json
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
import urllib.request
import warnings
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
from datasets import (
    Audio,
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    Value,
    concatenate_datasets,
    load_dataset,
)
from dotenv import load_dotenv
from huggingface_hub import HfApi, hf_hub_download
from rich.console import Console

try:
    import spaces
except Exception:  # pragma: no cover - spaces is only available on HF Spaces

    class _SpacesFallback:
        """No-op stand-in for the `spaces` package when running off-Spaces.

        Its `GPU` method mimics `spaces.GPU(...)` as a decorator factory that
        returns the function unchanged.
        """

        def GPU(self, *args, **kwargs):
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()

# Initialize console for pretty printing
console = Console()

# Target sample rate for all processing; both VAD and MP-SENet consume 16 kHz.
DEFAULT_SAMPLE_RATE = 16000
# Extensions used to heuristically spot audio-path string columns.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
# Loudness-normalization defaults (RMS-based, in dBFS / dB).
DEFAULT_TARGET_DBFS = -20.0
DEFAULT_MAX_BOOST_DB = 20.0
DEFAULT_MAX_ATTEN_DB = 10.0
# On HF Spaces (SPACE_ID set) auto-resume after ZeroGPU preemption by default.
DEFAULT_AUTO_RESUME = bool(os.getenv("SPACE_ID"))
# Conservative shard limits applied when running under ZeroGPU.
DEFAULT_ZERO_GPU_SHARD_SIZE = int(
    os.getenv("CHIZZLER_ZERO_GPU_SHARD_SIZE", "25")
)
DEFAULT_ZERO_GPU_MAX_SHARDS = int(
    os.getenv("CHIZZLER_ZERO_GPU_MAX_SHARDS", "1")
)
SPACE_ID = os.getenv("SPACE_ID")

# Silence a known-noisy Gradio warning about LoginButton construction.
warnings.filterwarnings(
    "ignore",
    message=(
        "LoginButton created outside of a Blocks context\\. "
        "May not work unless you call its `activate\\(\\)` method manually\\."
    ),
)


def log_progress(message: str, level: int = 1, enabled: bool = True) -> None:
    """Log a progress message with timestamp and indentation."""
    if not enabled:
        return
    indent = " " * (level - 1)
    timestamp = datetime.now().strftime("%H:%M:%S")
    console.print(f"[dim]{timestamp}[/dim] {indent}[bold blue]>[/bold blue] {message}")
    sys.stdout.flush()


# Load environment variables
load_dotenv()

# Optional GPU-duration hint for `spaces.GPU`; 0 means "no explicit duration".
_GPU_DURATION = os.getenv("CHIZZLER_GPU_DURATION")
_GPU_DURATION_MAX = os.getenv("CHIZZLER_GPU_DURATION_MAX")
if _GPU_DURATION is not None:
    try:
        DEFAULT_GPU_DURATION = int(_GPU_DURATION)
    except ValueError:
        DEFAULT_GPU_DURATION = 0
else:
    DEFAULT_GPU_DURATION = 0
if _GPU_DURATION_MAX is not None:
    try:
        max_duration = int(_GPU_DURATION_MAX)
        if max_duration > 0:
            # Cap the requested duration at the configured maximum.
            DEFAULT_GPU_DURATION = min(DEFAULT_GPU_DURATION, max_duration)
    except ValueError:
        pass


def gpu_decorator(duration: int):
    """Return a `spaces.GPU` decorator, passing `duration` when supported.

    Falls back to the bare `spaces.GPU` when the installed `spaces` version
    does not accept a `duration` keyword (TypeError) or when duration <= 0.
    """
    if duration and duration > 0:
        try:
            return spaces.GPU(duration=duration)
        except TypeError:
            return spaces.GPU
    return spaces.GPU


def get_hf_token() -> Optional[str]:
    """Return the first Hugging Face token found in the environment, if any."""
    return (
        os.getenv("HF_TOKEN")
        or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        or os.getenv("HF_API_TOKEN")
    )


def normalize_dataset_id(value: str) -> str:
    """Normalize a dataset reference (ID or hf.co URL) to `owner/name`.

    Accepts plain IDs, dataset URLs, or profile URLs; strips query strings,
    fragments, and trailing slashes. Returns "" for empty input.
    """
    if not value:
        return ""
    value = value.strip()
    if value.startswith("http"):
        if "datasets/" in value:
            value = value.split("datasets/", 1)[1]
        elif "huggingface.co/" in value:
            value = value.split("huggingface.co/", 1)[1]
    value = value.split("?")[0].split("#")[0].strip("/")
    parts = [part for part in value.split("/") if part]
    if len(parts) >= 2:
        # Keep only the first two path segments: owner/name.
        return "/".join(parts[:2])
    return value


CURRENT_DIR = Path(__file__).parent.resolve()
# Where the MP-SENet checkout lives (or will be cloned to).
DEFAULT_MP_SENET_DIR = Path(os.getenv("MPSENET_DIR", CURRENT_DIR / "MP-SENet"))
MPSENET_GIT_REPO = os.getenv(
    "MPSENET_GIT_REPO", "https://github.com/yxlu-0102/MP-SENet.git"
)
# Local root for cached processed shards.
CACHE_DIR = Path(os.getenv("CHIZZLER_CACHE_DIR", CURRENT_DIR / "chizzler_cache"))
_ENV_MAX_SHARDS = os.getenv("CHIZZLER_MAX_SHARDS_PER_RUN")
if _ENV_MAX_SHARDS is not None:
    # NOTE(review): unguarded int() — a malformed env value raises ValueError
    # at import time, unlike the try/except used for the GPU-duration vars.
    DEFAULT_MAX_SHARDS_PER_RUN = int(_ENV_MAX_SHARDS)
else:
    # On Spaces default to 1 shard per run (ZeroGPU time limits); 0 = no limit.
    DEFAULT_MAX_SHARDS_PER_RUN = 1 if os.getenv("SPACE_ID") else 0
_ENV_CACHE_TO_HUB = os.getenv("CHIZZLER_CACHE_TO_HUB")
if _ENV_CACHE_TO_HUB is None:
    # Default: mirror shard cache to the Hub only when running on a Space.
    DEFAULT_CACHE_TO_HUB = bool(os.getenv("SPACE_ID"))
else:
    DEFAULT_CACHE_TO_HUB = _ENV_CACHE_TO_HUB.strip().lower() in ("1", "true", "yes")


def ensure_mpsenet_repo() -> Path:
    """Return the MP-SENet checkout directory, cloning it if permitted.

    Auto-clones when MPSENET_AUTO_DOWNLOAD=1 or when running on a Space.

    Raises:
        RuntimeError: if the git clone fails.
        FileNotFoundError: if the repo is absent and auto-download is off.
    """
    if DEFAULT_MP_SENET_DIR.exists():
        return DEFAULT_MP_SENET_DIR
    auto_download = os.getenv("MPSENET_AUTO_DOWNLOAD") == "1" or os.getenv("SPACE_ID")
    if auto_download:
        log_progress("MP-SENet repo not found. Cloning...", 2)
        try:
            subprocess.run(
                [
                    "git",
                    "clone",
                    "--depth=1",
                    MPSENET_GIT_REPO,
                    str(DEFAULT_MP_SENET_DIR),
                ],
                check=True,
            )
        except Exception as exc:
            raise RuntimeError(
                "Failed to clone MP-SENet. Clone it manually or set MPSENET_DIR."
            ) from exc
        return DEFAULT_MP_SENET_DIR
    raise FileNotFoundError(
        "MP-SENet repo not found. Clone it into MP-SENet/ or set MPSENET_DIR."
    )


def resolve_mpsenet_files(mp_senet_dir: Path) -> Tuple[Path, Path]:
    """Locate the MP-SENet config and checkpoint files.

    Prefers `best_ckpt/` inside the checkout; otherwise downloads from the
    Hub repo named by MPSENET_REPO.

    Returns:
        (config_path, checkpoint_path)

    Raises:
        FileNotFoundError: when neither local files nor MPSENET_REPO exist.
    """
    config_path = mp_senet_dir / "best_ckpt" / "config.json"
    ckpt_path = mp_senet_dir / "best_ckpt" / "g_best_dns"
    if config_path.exists() and ckpt_path.exists():
        return config_path, ckpt_path
    repo_id = os.getenv("MPSENET_REPO")
    if repo_id:
        config_filename = os.getenv("MPSENET_CONFIG_FILENAME", "config.json")
        ckpt_filename = os.getenv("MPSENET_CKPT_FILENAME", "g_best_dns")
        config_path = Path(
            hf_hub_download(repo_id=repo_id, filename=config_filename)
        )
        ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_filename))
        return config_path, ckpt_path
    raise FileNotFoundError(
        "MP-SENet checkpoint files not found. Place best_ckpt/config.json and "
        "best_ckpt/g_best_dns under MP-SENet/ or set MPSENET_REPO."
    )


# Make the MP-SENet checkout importable, then import its modules.
mp_senet_dir = ensure_mpsenet_repo()
sys.path.append(str(mp_senet_dir))

from dataset import mag_pha_istft, mag_pha_stft  # noqa: E402
from env import AttrDict  # noqa: E402
from models.model import MPNet  # noqa: E402


def select_device() -> torch.device:
    """Pick the torch device: CHIZZLER_DEVICE override, else CUDA/MPS/CPU."""
    override = os.getenv("CHIZZLER_DEVICE", "").strip().lower()
    if override:
        return torch.device(override)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def initialize_models(device_override: Optional[torch.device] = None):
    """Load Silero VAD and MP-SENet onto the chosen device.

    Returns:
        (vad_model, vad_utils, mpnet_model, mpnet_config, device)
    """
    log_progress("Initializing models...")
    device = device_override or select_device()
    log_progress(f"Using {device.type.upper()} for all operations", 2)
    log_progress("Loading Silero VAD model...", 2)
    model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        trust_repo=True,
    )
    vad_model = model.to(device)
    log_progress("Loading MP-SENet model...", 2)
    config_path, ckpt_path = resolve_mpsenet_files(mp_senet_dir)
    with open(config_path, "r") as f:
        config = AttrDict(json.loads(f.read()))
    mpnet_model = MPNet(config).to(device)
    state = torch.load(ckpt_path, map_location=device)
    # Checkpoints may wrap the weights under "generator" or "state_dict".
    if isinstance(state, dict):
        if "generator" in state:
            state = state["generator"]
        elif "state_dict" in state:
            state = state["state_dict"]
    mpnet_model.load_state_dict(state)
    mpnet_model.eval()
    return vad_model, utils, mpnet_model, config, device


# Lazily-initialized global model cache (populated by get_models()).
vad_model = None
vad_utils = None
mpnet_model = None
mpnet_config = None
device = None


def get_models():
    """Return the cached models, loading or moving them to the current device.

    Re-evaluates select_device() on every call so models follow device
    changes (e.g. when a ZeroGPU worker attaches a GPU).
    """
    global vad_model, vad_utils, mpnet_model, mpnet_config, device
    desired_device = select_device()
    if vad_model is None or mpnet_model is None or mpnet_config is None:
        vad_model, vad_utils, mpnet_model, mpnet_config, device = (
            initialize_models(desired_device)
        )
        return vad_model, vad_utils, mpnet_model, mpnet_config, device
    if device is None or str(device) != str(desired_device):
        log_progress(f"Moving models to {desired_device}...", 2)
        vad_model = vad_model.to(desired_device)
        mpnet_model = mpnet_model.to(desired_device)
        device = desired_device
    return vad_model, vad_utils, mpnet_model, mpnet_config, device


def ensure_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Coerce a waveform to shape (1, samples), averaging multi-channel audio.

    A 2-D input with more rows than columns is assumed to be
    (samples, channels) and is transposed first — a heuristic that can
    misfire on clips shorter than the channel count.
    """
    if waveform.dim() == 1:
        return waveform.unsqueeze(0)
    if waveform.dim() == 2 and waveform.size(0) > waveform.size(1):
        waveform = waveform.transpose(0, 1)
    if waveform.size(0) > 1:
        return torch.mean(waveform, dim=0, keepdim=True)
    return waveform


def resample_waveform(
    waveform: torch.Tensor, sample_rate: int, target_rate: int = DEFAULT_SAMPLE_RATE
) -> Tuple[torch.Tensor, int]:
    """Resample to `target_rate`; no-op (same objects) when rates match."""
    if sample_rate == target_rate:
        return waveform, sample_rate
    resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
    return resampler(waveform), target_rate


def load_audio_file(file_path: str, log: bool = True) -> Tuple[torch.Tensor, int]:
    """Load an audio file as a mono 16 kHz tensor.

    Tries torchaudio, then soundfile, then librosa as fallbacks.

    Returns:
        (waveform of shape (1, samples), sample_rate == DEFAULT_SAMPLE_RATE)
    """
    log_progress(f"Loading audio: {Path(file_path).name}", enabled=log)
    waveform = None
    sample_rate = None
    try:
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = ensure_mono(waveform)
    except Exception as exc:
        log_progress(f"torchaudio load failed: {exc}", 2, enabled=log)
    if waveform is None or sample_rate is None:
        try:
            data, sample_rate = sf.read(
                file_path, always_2d=True, dtype="float32"
            )
            # soundfile returns (samples, channels); transpose to (channels, samples).
            waveform = torch.from_numpy(data.T)
        except Exception as exc:
            log_progress(f"soundfile load failed: {exc}", 2, enabled=log)
            data, sample_rate = librosa.load(
                file_path, sr=None, mono=False, dtype=np.float32
            )
            if data.ndim == 1:
                data = data[None, :]
            waveform = torch.from_numpy(data)
        waveform = ensure_mono(waveform)
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate


def get_speech_timestamps(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    log: bool = True,
) -> List[dict]:
    """Run Silero VAD and return speech segments as dicts with second units.

    Each dict has "start"/"end" keys in seconds (return_seconds=True).
    """
    log_progress("Detecting speech segments...", enabled=log)
    vad_model, vad_utils, _, _, _ = get_models()
    # Silero's torch.hub utils tuple; first element is the timestamp helper.
    (get_speech_timestamps_fn, _, _, _, _) = vad_utils
    speech_timestamps = get_speech_timestamps_fn(
        waveform,
        vad_model,
        threshold=threshold,
        return_seconds=True,
    )
    log_progress(f"Found {len(speech_timestamps)} speech segments", 2, enabled=log)
    return speech_timestamps


def merge_close_segments(segments: List[dict], max_gap: float = 4.0) -> List[dict]:
    """Merge consecutive speech segments separated by <= max_gap seconds.

    Input segments are assumed chronologically ordered (as Silero returns
    them). Input dicts are copied, not mutated.
    """
    if not segments:
        return segments
    merged = []
    current_segment = segments[0].copy()
    for segment in segments[1:]:
        gap_duration = segment["start"] - current_segment["end"]
        if gap_duration <= max_gap:
            current_segment["end"] = segment["end"]
        else:
            merged.append(current_segment)
            current_segment = segment.copy()
    merged.append(current_segment)
    return merged


def extract_speech_waveform(
    waveform: torch.Tensor, sample_rate: int, segments: List[dict]
) -> Optional[torch.Tensor]:
    """Concatenate the waveform slices covered by `segments` (seconds).

    Returns None when there are no segments or all are empty after clamping
    to the waveform bounds.
    """
    if not segments:
        return None
    parts = []
    total_samples = waveform.size(1)
    for segment in segments:
        start = max(0, int(segment["start"] * sample_rate))
        end = min(total_samples, int(segment["end"] * sample_rate))
        if end > start:
            parts.append(waveform[:, start:end])
    if not parts:
        return None
    return torch.cat(parts, dim=1)


def denoise_audio_chunk(
    audio_tensor: torch.Tensor,
    mpnet_model: torch.nn.Module,
    mpnet_config: AttrDict,
    chunk_size: int = 5 * DEFAULT_SAMPLE_RATE,
) -> torch.Tensor:
    """Denoise audio with MP-SENet in fixed-size chunks (default 5 s).

    Each chunk is RMS-normalized before the STFT and the inverse gain is
    applied after the iSTFT, matching MP-SENet's inference recipe.
    """
    chunks = []
    for i in range(0, audio_tensor.size(1), chunk_size):
        chunk = audio_tensor[:, i : min(i + chunk_size, audio_tensor.size(1))]
        energy = torch.sum(chunk**2.0, dim=1)
        norm_factor = torch.sqrt(chunk.size(1) / (energy + 1e-8))
        chunk = chunk * norm_factor.unsqueeze(1)
        with torch.no_grad():
            noisy_amp, noisy_pha, _ = mag_pha_stft(
                chunk,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
            amp_g, pha_g, _ = mpnet_model(noisy_amp, noisy_pha)
            audio_g = mag_pha_istft(
                amp_g,
                pha_g,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
            # Undo the per-chunk normalization applied above.
            audio_g = audio_g / norm_factor.unsqueeze(1)
            chunks.append(audio_g)
        # Free intermediates eagerly to bound peak memory on long inputs.
        del chunk, noisy_amp, noisy_pha, amp_g, pha_g
    return torch.cat(chunks, dim=1)


def process_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    log: bool = True,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str, bool]:
    """Run the full VAD + denoise pipeline on one waveform.

    Returns:
        (vad_waveform, denoised_waveform, details_text, has_speech).
        The first two are None when no speech was found.
    """
    vad_model, vad_utils, mpnet_model, mpnet_config, device = get_models()
    if waveform.device != device:
        waveform = waveform.to(device)
    log_progress("Stage 1: Voice Activity Detection", 2, enabled=log)
    speech_timestamps = get_speech_timestamps(
        waveform, sample_rate, threshold=threshold, log=log
    )
    merged_timestamps = merge_close_segments(speech_timestamps, max_gap)
    details = ["Processing details:"]
    if not merged_timestamps:
        details.append("No speech detected in the audio.")
        return None, None, "\n".join(details), False
    for i, segment in enumerate(merged_timestamps, 1):
        duration = segment["end"] - segment["start"]
        details.append(
            f"Segment {i}/{len(merged_timestamps)}: "
            f"{segment['start']:.1f}s to {segment['end']:.1f}s "
            f"(duration: {duration:.1f}s)"
        )
    vad_waveform = extract_speech_waveform(
        waveform, sample_rate, merged_timestamps
    )
    if vad_waveform is None or vad_waveform.numel() == 0:
        details.append("No speech detected after merging segments.")
        return None, None, "\n".join(details), False
    vad_duration = vad_waveform.size(1) / sample_rate
    original_duration = waveform.size(1) / sample_rate
    if original_duration > 0:
        reduction = (1 - vad_duration / original_duration) * 100
    else:
        reduction = 0.0
    details.append(f"VAD output duration: {vad_duration:.1f}s")
    details.append(f"Reduced by: {reduction:.1f}%")
    log_progress("Stage 2: MP-SENet denoising", 2, enabled=log)
    with torch.no_grad():
        denoised_waveform = denoise_audio_chunk(
            vad_waveform, mpnet_model, mpnet_config
        )
    return vad_waveform, denoised_waveform, "\n".join(details), True


def process_audio_file(
    audio_path: str,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    normalize_audio: bool = True,
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Tuple[str, str, str, str]:
    """Process one file end-to-end and write results to temp WAV files.

    Returns:
        (original_path, vad_path, denoised_path, details). When no speech
        is found, the original path is returned for all three audio slots.
    """
    log_progress(f"Processing: {Path(audio_path).name}")
    waveform, sample_rate = load_audio_file(audio_path)
    _, _, _, mpnet_config, _ = get_models()
    vad_waveform, denoised_waveform, details, has_speech = process_waveform(
        waveform, sample_rate, threshold=threshold, max_gap=max_gap, log=True
    )
    if not has_speech or vad_waveform is None or denoised_waveform is None:
        return audio_path, audio_path, audio_path, details
    if normalize_audio:
        vad_waveform = normalize_waveform(
            vad_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
        denoised_waveform = normalize_waveform(
            denoised_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
    # delete=False: the paths are handed to Gradio, which reads them later.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vad_file, \
            tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as denoised_file:
        vad_path = vad_file.name
        denoised_path = denoised_file.name
    sf.write(
        vad_path,
        vad_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    sf.write(
        denoised_path,
        denoised_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    return audio_path, vad_path, denoised_path, details


def load_audio_bytes(audio_bytes: bytes, log: bool = False) -> Tuple[torch.Tensor, int]:
    """Decode in-memory audio bytes to a mono 16 kHz (1, samples) tensor."""
    data, sample_rate = sf.read(
        io.BytesIO(audio_bytes), always_2d=True, dtype="float32"
    )
    waveform = torch.from_numpy(data.T)
    waveform = ensure_mono(waveform)
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate


def normalize_waveform(
    waveform: Optional[torch.Tensor],
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Optional[torch.Tensor]:
    """RMS-normalize toward `target_dbfs`, clamping gain and final samples.

    Returns the input unchanged for None/empty/near-silent/non-finite input.
    Output samples are clipped to [-1.0, 1.0].
    """
    if waveform is None or waveform.numel() == 0:
        return waveform
    rms = torch.sqrt(torch.mean(waveform**2))
    if not torch.isfinite(rms) or rms <= 1e-8:
        return waveform
    current_db = 20.0 * torch.log10(rms)
    gain_db = target_dbfs - current_db
    # Limit how aggressively quiet audio is boosted / loud audio attenuated.
    gain_db = torch.clamp(gain_db, -max_atten_db, max_boost_db)
    gain = torch.pow(torch.tensor(10.0, device=waveform.device), gain_db / 20.0)
    normalized = waveform * gain
    return torch.clamp(normalized, -1.0, 1.0)


def _is_http_url(value: str) -> bool:
    """True when `value` is an http(s) URL."""
    return value.startswith("http://") or value.startswith("https://")


def _parse_hf_dataset_uri(uri: str) -> Optional[Tuple[str, str, Optional[str]]]:
    """Parse an `hf://datasets/...` URI into (repo_id, file_path, revision).

    NOTE(review): `rest.split("/", 1)` keeps only the first path segment as
    repo_id, so for `hf://datasets/owner/name/file` this yields repo_id
    "owner" and file_path "name/file" — verify against the URIs this app
    actually receives; full `owner/name` repos look mishandled here.
    """
    prefix = "hf://datasets/"
    if not uri.startswith(prefix):
        return None
    rest = uri[len(prefix) :]
    if "/" not in rest:
        return None
    repo_id, file_path = rest.split("/", 1)
    revision = None
    if "@" in repo_id:
        repo_id, revision = repo_id.split("@", 1)
    return repo_id, file_path, revision


def load_audio_url(url: str, token: Optional[str], log: bool = False) -> Tuple[torch.Tensor, int]:
    """Download audio from a URL (with HF auth for huggingface.co) and decode."""
    headers = {}
    if token and "huggingface.co" in url:
        headers["Authorization"] = f"Bearer {token}"
    request = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(request) as response:
        data = response.read()
    return load_audio_bytes(data, log=log)


def resolve_audio_path(
    path: str, dataset_id: Optional[str], token: Optional[str]
) -> str:
    """Resolve an audio reference to a local path where possible.

    Order: existing local path -> hf:// dataset URI -> relative path inside
    `dataset_id` on the Hub. Falls back to returning `path` unchanged on any
    download failure (callers then treat it as a URL or local path).
    """
    if os.path.exists(path):
        return path
    parsed = _parse_hf_dataset_uri(path)
    if parsed:
        repo_id, filename, revision = parsed
        try:
            return hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=filename,
                revision=revision,
                token=token,
            )
        except Exception:
            return path
    if dataset_id and not os.path.isabs(path):
        try:
            return hf_hub_download(
                repo_id=dataset_id,
                repo_type="dataset",
                filename=path,
                token=token,
            )
        except Exception:
            return path
    return path


def prepare_waveform_from_entry(
    entry,
    log: bool = False,
    dataset_id: Optional[str] = None,
    token: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Convert a `datasets` audio entry to a mono 16 kHz tensor.

    Handles: objects exposing `get_all_samples()` (torchcodec-style decoded
    audio), dicts with "array"/"bytes"/"path", and plain path strings.

    Raises:
        ValueError: for None or unrecognized entry shapes.
    """
    if entry is None:
        raise ValueError("Empty audio entry.")
    if hasattr(entry, "get_all_samples"):
        samples = entry.get_all_samples()
        waveform = samples.data
        sample_rate = samples.sample_rate
        waveform = ensure_mono(waveform)
        if sample_rate != DEFAULT_SAMPLE_RATE:
            waveform, sample_rate = resample_waveform(
                waveform, sample_rate, DEFAULT_SAMPLE_RATE
            )
        return waveform, sample_rate
    if isinstance(entry, dict):
        if entry.get("array") is not None:
            sample_rate = entry.get("sampling_rate", DEFAULT_SAMPLE_RATE)
            waveform = torch.tensor(entry["array"], dtype=torch.float32)
            waveform = ensure_mono(waveform)
            if sample_rate != DEFAULT_SAMPLE_RATE:
                waveform, sample_rate = resample_waveform(
                    waveform, sample_rate, DEFAULT_SAMPLE_RATE
                )
            return waveform, sample_rate
        if entry.get("bytes"):
            audio_bytes = entry["bytes"]
            if not isinstance(audio_bytes, (bytes, bytearray)):
                audio_bytes = bytes(audio_bytes)
            return load_audio_bytes(audio_bytes, log=log)
        if entry.get("path"):
            path = resolve_audio_path(entry["path"], dataset_id, token)
            if _is_http_url(path):
                return load_audio_url(path, token, log=log)
            return load_audio_file(path, log=log)
    if isinstance(entry, str):
        path = resolve_audio_path(entry, dataset_id, token)
        if _is_http_url(path):
            return load_audio_url(path, token, log=log)
        return load_audio_file(path, log=log)
    raise ValueError("Unsupported audio entry format.")


def get_dataset_cache_dir(dataset_id: str, config: Optional[str]) -> Path:
    """Local cache directory for a dataset (+ optional config)."""
    slug = dataset_id.replace("/", "__")
    if config:
        slug = f"{slug}__{config}"
    return CACHE_DIR / slug


def get_cache_slug(dataset_id: str, config: Optional[str]) -> str:
    """Filesystem/Hub-safe slug for a dataset (+ optional config)."""
    slug = dataset_id.replace("/", "__")
    if config:
        slug = f"{slug}__{config}"
    return slug


def get_hub_cache_prefix(
    dataset_id: str, config: Optional[str], split_name: str
) -> str:
    """Path prefix under which this split's shards are cached on the Hub."""
    slug = get_cache_slug(dataset_id, config)
    return f"chizzler_cache/{slug}/{split_name}"


def load_split_meta(
    split_cache_dir: Path,
    hub_cache_prefix: str,
    cache_on_hub: bool,
    repo_id: str,
    token: Optional[str],
) -> Optional[dict]:
    """Load a split's shard metadata (_meta.json), locally then from the Hub.

    Returns None when no metadata can be found.
    """
    meta_file = split_cache_dir / "_meta.json"
    if meta_file.exists():
        return json.loads(meta_file.read_text())
    if cache_on_hub:
        try:
            meta_path = hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=f"{hub_cache_prefix}/_meta.json",
                token=token,
            )
            return json.loads(Path(meta_path).read_text())
        except Exception:
            return None
    return None


def infer_audio_column(dataset_obj) -> Optional[str]:
    """Guess the audio column of a dataset (dict or single split).

    Prefers an `Audio` feature; falls back to inspecting the first example
    for audio-shaped dicts or audio-extension path strings.
    """
    sample_ds = dataset_obj
    if isinstance(dataset_obj, (DatasetDict, IterableDatasetDict)):
        sample_ds = next(iter(dataset_obj.values()))
    if hasattr(sample_ds, "features"):
        for column, feature in sample_ds.features.items():
            if isinstance(feature, Audio):
                return column
    if isinstance(sample_ds, Dataset) and len(sample_ds) > 0:
        sample = sample_ds[0]
        for column, value in sample.items():
            if isinstance(value, dict) and (
                "array" in value or "path" in value or "bytes" in value
            ):
                return column
            if isinstance(value, str) and value.lower().endswith(AUDIO_EXTENSIONS):
                return column
    return None


def default_output_repo(source_id: str, username: str) -> str:
    """Derive the default output repo id: <user>/<name>-representation-chizzler."""
    name = source_id.split("/")[-1]
    suffix = "representation-chizzler"
    if not name.endswith(suffix):
        name = f"{name}-{suffix}"
    return f"{username}/{name}"


def _apply_zero_gpu_limits(
    shard_size: int, max_shards: Optional[int]
) -> Tuple[int, Optional[int]]:
    """Cap shard size/count when running on a Space (ZeroGPU time budget)."""
    if not os.getenv("SPACE_ID"):
        return shard_size, max_shards
    adjusted_shard_size = min(shard_size, DEFAULT_ZERO_GPU_SHARD_SIZE)
    if max_shards is None:
        adjusted_max_shards = DEFAULT_ZERO_GPU_MAX_SHARDS
    else:
        adjusted_max_shards = min(max_shards, DEFAULT_ZERO_GPU_MAX_SHARDS)
    return adjusted_shard_size, adjusted_max_shards


@gpu_decorator(DEFAULT_GPU_DURATION)
def _process_dataset_and_push_gpu(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    vad_threshold: float,
    max_silence_gap: float,
    normalize_audio: bool,
    target_dbfs: float,
    max_boost_db: float,
    max_atten_db: float,
    max_examples: Optional[float],
    resume_processing: bool,
    shard_size: Optional[float],
    cache_on_hub: bool,
    max_shards_per_run: Optional[float],
    request: gr.Request | None = None,
    progress=gr.Progress(),
) -> str:
    """Clean a Hub dataset shard-by-shard and push the result (GPU worker).

    Processes each split in `shard_size`-example shards, caching finished
    shards locally (and on the Hub when `cache_on_hub`) so interrupted runs
    can resume. Returns a human-readable status string; when a
    `max_shards_per_run` budget stops the run early, the string contains
    "Resume with cached shards" so the wrapper can auto-continue.
    """
    token = get_hf_token()
    if not token:
        return "Missing HF token. Set HF_TOKEN as a secret or env var."
    dataset_id = normalize_dataset_id(dataset_id)
    if not dataset_id:
        return "Provide a dataset ID or URL."
    # Ensure models are loaded on the correct device before heavy processing.
    get_models()
    config = config.strip() or None
    split = split.strip()
    audio_column = audio_column.strip()
    output_repo = normalize_dataset_id(output_repo) if output_repo else ""
    cache_on_hub = bool(cache_on_hub)
    normalize_audio = bool(normalize_audio)
    # Gradio Number inputs arrive as floats; coerce to usable ints.
    max_examples_int = int(max_examples) if max_examples and max_examples > 0 else None
    shard_size_int = int(shard_size) if shard_size and shard_size > 0 else 1000
    max_shards_int = (
        int(max_shards_per_run)
        if max_shards_per_run and max_shards_per_run > 0
        else None
    )
    if os.getenv("SPACE_ID"):
        adjusted_shard_size, adjusted_max_shards = _apply_zero_gpu_limits(
            shard_size_int, max_shards_int
        )
        if adjusted_shard_size != shard_size_int:
            log_progress(
                f"ZeroGPU safe mode: shard size capped at {adjusted_shard_size}",
                2,
            )
            shard_size_int = adjusted_shard_size
        if adjusted_max_shards != max_shards_int:
            log_progress(
                f"ZeroGPU safe mode: max shards per run capped at {adjusted_max_shards}",
                2,
            )
            max_shards_int = adjusted_max_shards
    api = HfApi(token=token)
    username = api.whoami()["name"]
    repo_id = output_repo or default_output_repo(dataset_id, username)
    if cache_on_hub:
        api.create_repo(
            repo_id, repo_type="dataset", private=private_repo, exist_ok=True
        )
        log_progress(
            f"Caching shards to Hub repo: {repo_id}", 2
        )
    log_progress(f"Loading dataset: {dataset_id}")
    progress(0, desc="Downloading dataset...")
    if split and split.lower() != "all":
        dataset_obj = load_dataset(
            dataset_id, name=config, split=split, token=token
        )
        dataset_dict = DatasetDict({split: dataset_obj})
    else:
        dataset_obj = load_dataset(dataset_id, name=config, token=token)
        dataset_dict = (
            DatasetDict({"train": dataset_obj})
            if isinstance(dataset_obj, Dataset)
            else dataset_obj
        )
    progress(0.01, desc="Preparing splits...")
    if not audio_column:
        audio_column = infer_audio_column(dataset_dict) or ""
    if not audio_column:
        return (
            "Could not infer audio column. Please specify the audio column "
            "name manually."
        )
    processed_splits = {}
    shards_processed = 0  # shards computed this run
    cached_shards = 0  # shards reused from local/Hub cache
    total_shards = 0
    incomplete = False  # set when the per-run shard budget is exhausted
    repo_files = set()
    if resume_processing and cache_on_hub:
        try:
            repo_files = set(
                api.list_repo_files(repo_id, repo_type="dataset")
            )
        except Exception:
            repo_files = set()
    cache_root = get_dataset_cache_dir(dataset_id, config)
    cache_root.mkdir(parents=True, exist_ok=True)
    for split_name, split_ds in dataset_dict.items():
        if (
            hasattr(split_ds, "column_names")
            and audio_column not in split_ds.column_names
        ):
            return f"Audio column '{audio_column}' not found in split '{split_name}'."
        # Decode at 16 kHz when supported; plain Audio() as a fallback.
        try:
            split_ds = split_ds.cast_column(
                audio_column, Audio(sampling_rate=DEFAULT_SAMPLE_RATE)
            )
        except Exception:
            split_ds = split_ds.cast_column(audio_column, Audio())
        # total is None for iterable (unsized) datasets.
        total = len(split_ds) if isinstance(split_ds, Dataset) else None
        if max_examples_int and total is not None:
            total = min(total, max_examples_int)
        update_every = max(1, (total or max_examples_int or 100) // 100)
        split_cache_dir = cache_root / split_name
        if not resume_processing and split_cache_dir.exists():
            # Fresh run requested: drop any stale shard cache for this split.
            shutil.rmtree(split_cache_dir)
        split_cache_dir.mkdir(parents=True, exist_ok=True)
        hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name)
        features = split_ds.features.copy()
        features[audio_column] = Audio(
            sampling_rate=DEFAULT_SAMPLE_RATE
        )
        # Per-example status columns added by the map function below.
        features["chizzler_ok"] = Value("bool")
        features["chizzler_error"] = Value("string")

        def make_map_fn(offset: int = 0):
            """Build a map() callback; `offset` converts shard-local indices
            to global ones for progress reporting."""

            def map_fn(example, idx):
                entry = example.get(audio_column)
                ok = True
                error_message = ""
                try:
                    waveform, sample_rate = prepare_waveform_from_entry(
                        entry, log=False, dataset_id=dataset_id, token=token
                    )
                    vad_waveform, denoised_waveform, _, has_speech = process_waveform(
                        waveform,
                        sample_rate,
                        threshold=vad_threshold,
                        max_gap=max_silence_gap,
                        log=False,
                    )
                    # Keep the original audio when no speech was detected.
                    output_waveform = (
                        denoised_waveform
                        if has_speech and denoised_waveform is not None
                        else waveform
                    )
                    if normalize_audio:
                        output_waveform = normalize_waveform(
                            output_waveform,
                            target_dbfs=target_dbfs,
                            max_boost_db=max_boost_db,
                            max_atten_db=max_atten_db,
                        )
                    output_np = (
                        output_waveform.squeeze()
                        .detach()
                        .cpu()
                        .numpy()
                        .astype(np.float32)
                    )
                    if output_np.size == 0:
                        ok = False
                        error_message = (
                            "Empty output waveform; using original audio."
                        )
                        output_np = (
                            waveform.squeeze()
                            .detach()
                            .cpu()
                            .numpy()
                            .astype(np.float32)
                        )
                    output_entry = {
                        "array": output_np,
                        "sampling_rate": DEFAULT_SAMPLE_RATE,
                    }
                except Exception as exc:
                    # Record the failure but keep the row so row counts match.
                    ok = False
                    error_message = str(exc)
                    output_entry = entry if entry is not None else {
                        "array": np.zeros(1, dtype=np.float32),
                        "sampling_rate": DEFAULT_SAMPLE_RATE,
                    }
                example[audio_column] = output_entry
                example["chizzler_ok"] = ok
                example["chizzler_error"] = error_message
                global_idx = offset + idx + 1
                if total:
                    if global_idx % update_every == 0 or global_idx == total:
                        progress(
                            global_idx / total,
                            desc=(
                                f"Processing {split_name}: {global_idx}/{total}"
                            ),
                        )
                else:
                    if global_idx % update_every == 0:
                        progress(
                            0,
                            desc=f"Processing {split_name}: {global_idx} examples",
                        )
                return example

            return map_fn

        if total:
            # Sized split: process in resumable shards.
            num_shards = math.ceil(total / shard_size_int)
            total_shards += num_shards
            meta = {
                "dataset_id": dataset_id,
                "config": config or "",
                "split": split_name,
                "audio_column": audio_column,
                "total": total,
                "shard_size": shard_size_int,
                "num_shards": num_shards,
            }
            meta_file = split_cache_dir / "_meta.json"
            meta_file.write_text(json.dumps(meta, indent=2))
            if cache_on_hub:
                api.upload_file(
                    path_or_fileobj=str(meta_file),
                    path_in_repo=f"{hub_cache_prefix}/_meta.json",
                    repo_id=repo_id,
                    repo_type="dataset",
                )
            shards = []
            for shard_idx in range(num_shards):
                start = shard_idx * shard_size_int
                end = min(total, start + shard_size_int)
                cache_file = split_cache_dir / (
                    f"{split_name}-{start:07d}-{end:07d}.arrow"
                )
                hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}"
                if resume_processing and cache_file.exists():
                    # Reuse the locally cached shard.
                    processed_shard = Dataset.from_file(str(cache_file))
                    progress(
                        end / total,
                        desc=f"Processing {split_name}: {end}/{total}",
                    )
                    cached_shards += 1
                elif resume_processing and cache_on_hub and hub_cache_path in repo_files:
                    # Reuse the shard previously uploaded to the Hub.
                    cache_path = hf_hub_download(
                        repo_id=repo_id,
                        repo_type="dataset",
                        filename=hub_cache_path,
                        token=token,
                    )
                    processed_shard = Dataset.from_file(cache_path)
                    progress(
                        end / total,
                        desc=f"Processing {split_name}: {end}/{total}",
                    )
                    cached_shards += 1
                else:
                    if max_shards_int and shards_processed >= max_shards_int:
                        # Per-run budget exhausted; stop and report partial.
                        incomplete = True
                        break
                    indices = range(start, end)
                    try:
                        shard_ds = split_ds.select(indices)
                    except Exception:
                        shard_ds = split_ds.select(list(indices))
                    processed_shard = shard_ds.map(
                        make_map_fn(offset=start),
                        with_indices=True,
                        load_from_cache_file=False,
                        cache_file_name=str(cache_file),
                        writer_batch_size=50,
                        num_proc=None,
                        features=features,
                        desc=(
                            f"Chizzling {split_name} "
                            f"({shard_idx + 1}/{num_shards})"
                        ),
                    )
                    shards_processed += 1
                    if cache_on_hub:
                        api.upload_file(
                            path_or_fileobj=str(cache_file),
                            path_in_repo=hub_cache_path,
                            repo_id=repo_id,
                            repo_type="dataset",
                        )
                        repo_files.add(hub_cache_path)
                shards.append(processed_shard)
            if incomplete:
                break
            processed_split = (
                concatenate_datasets(shards) if len(shards) > 1 else shards[0]
            )
        else:
            # Unsized split: single pass, counted as one shard of budget.
            if max_shards_int and shards_processed >= max_shards_int:
                incomplete = True
                break
            processed_split = split_ds.map(
                make_map_fn(offset=0),
                with_indices=True,
                load_from_cache_file=False,
                writer_batch_size=50,
                num_proc=None,
                features=features,
                desc=f"Chizzling {split_name}",
            )
            shards_processed += 1
        processed_splits[split_name] = processed_split
    if incomplete:
        total_done = cached_shards + shards_processed
        progress_note = (
            f" ({total_done}/{total_shards} shards ready)" if total_shards else ""
        )
        return (
            f"Processed {shards_processed} new shard(s)"
            f"{f', cached {cached_shards}' if cached_shards else ''}"
            f"{progress_note}."
            " Resume with cached shards to continue."
        )
    processed_dataset = (
        DatasetDict(processed_splits)
        if len(processed_splits) > 1
        else next(iter(processed_splits.values()))
    )
    progress(0, desc="Uploading to the Hub...")
    processed_dataset.push_to_hub(repo_id, private=private_repo, token=token)
    progress(1.0, desc="Upload complete.")
    return (
        f"Uploaded cleaned dataset to {repo_id} "
        f"(audio column: {audio_column})."
    )


def process_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    vad_threshold: float,
    max_silence_gap: float,
    normalize_audio: bool,
    target_dbfs: float,
    max_boost_db: float,
    max_atten_db: float,
    max_examples: Optional[float],
    resume_processing: bool,
    auto_resume: bool,
    shard_size: Optional[float],
    cache_on_hub: bool,
    max_shards_per_run: Optional[float],
    request: gr.Request | None = None,
    progress=gr.Progress(),
) -> str:
    """Retry wrapper around the GPU worker.

    When `auto_resume` is set, re-invokes the worker after ZeroGPU
    preemptions ("GPU task aborted") and after partial runs (worker result
    contains "Resume with cached shards") until the run completes.
    """
    if SPACE_ID and request is not None:
        headers = getattr(request, "headers", None)
        token_header = None
        if headers and hasattr(headers, "get"):
            token_header = headers.get("x-ip-token")
        if not token_header:
            log_progress(
                "ZeroGPU auth header missing. Use the Space on huggingface.co "
                "to attach your login to ZeroGPU quota.",
                2,
            )
    attempts = 0
    while True:
        try:
            result = _process_dataset_and_push_gpu(
                dataset_id,
                config,
                split,
                audio_column,
                output_repo,
                private_repo,
                vad_threshold,
                max_silence_gap,
                normalize_audio,
                target_dbfs,
                max_boost_db,
                max_atten_db,
                max_examples,
                resume_processing,
                shard_size,
                cache_on_hub,
                max_shards_per_run,
                request=request,
                progress=progress,
            )
        except Exception as exc:
            message = str(exc)
            if "ZeroGPU proxy token expired" in message:
                return (
                    "ZeroGPU login token expired. Click Process/Resume again "
                    "to refresh your session."
                )
            if auto_resume and "GPU task aborted" in message:
                attempts += 1
                log_progress(
                    f"ZeroGPU preempted. Retrying (attempt {attempts})...",
                    2,
                )
                time.sleep(2)
                continue
            raise
        if not auto_resume:
            return result
        if "Resume with cached shards" in result:
            attempts += 1
            log_progress(
                f"Auto-resume: continuing (attempt {attempts})...",
                2,
            )
            time.sleep(2)
            continue
        return result


def assemble_cached_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    cache_on_hub: bool,
    progress=gr.Progress(),
) -> str:
    """Rebuild the cleaned dataset purely from cached shards and push it.

    Does no audio processing: every shard for every requested split must
    already exist locally or on the Hub cache, else an error string is
    returned.
    """
    token = get_hf_token()
    if not token:
        return "Missing HF token. Set HF_TOKEN as a secret or env var."
    dataset_id = normalize_dataset_id(dataset_id)
    if not dataset_id:
        return "Provide a dataset ID or URL."
    config = config.strip() or None
    split = split.strip()
    audio_column = audio_column.strip()
    output_repo = normalize_dataset_id(output_repo) if output_repo else ""
    cache_on_hub = bool(cache_on_hub)
    api = HfApi(token=token)
    username = api.whoami()["name"]
    repo_id = output_repo or default_output_repo(dataset_id, username)
    cache_root = get_dataset_cache_dir(dataset_id, config)
    cache_slug = get_cache_slug(dataset_id, config)
    if split and split.lower() != "all":
        split_names = [split]
    else:
        # Discover split names from the Hub cache layout or local cache dirs.
        if cache_on_hub:
            repo_files = api.list_repo_files(repo_id, repo_type="dataset")
            prefix = f"chizzler_cache/{cache_slug}/"
            split_names = sorted(
                {
                    path.split("/")[2]
                    for path in repo_files
                    if path.startswith(prefix) and len(path.split("/")) >= 3
                }
            )
        else:
            split_names = sorted(
                [
                    path.name
                    for path in cache_root.iterdir()
                    if path.is_dir()
                ]
            )
    if not split_names:
        return "No cached shards found. Run processing first."
    repo_files = set()
    if cache_on_hub:
        try:
            repo_files = set(
                api.list_repo_files(repo_id, repo_type="dataset")
            )
        except Exception:
            repo_files = set()
    processed_splits = {}
    for split_name in split_names:
        split_cache_dir = cache_root / split_name
        hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name)
        meta = load_split_meta(
            split_cache_dir, hub_cache_prefix, cache_on_hub, repo_id, token
        )
        if not meta:
            return (
                f"Missing cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        total = int(meta.get("total", 0))
        shard_size = int(meta.get("shard_size", 0))
        num_shards = int(meta.get("num_shards", 0))
        if not total or not shard_size or not num_shards:
            return (
                f"Incomplete cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        shards = []
        missing = []
        for shard_idx in range(num_shards):
            start = shard_idx * shard_size
            end = min(total, start + shard_size)
            cache_file = split_cache_dir / (
                f"{split_name}-{start:07d}-{end:07d}.arrow"
            )
            hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}"
            if cache_file.exists():
                cache_path = str(cache_file)
            elif cache_on_hub and hub_cache_path in repo_files:
                cache_path = hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=hub_cache_path,
                    token=token,
                )
            else:
                missing.append(cache_file.name)
                continue
            shards.append(Dataset.from_file(cache_path))
        if missing:
            return (
                f"Missing {len(missing)} shard(s) for split '{split_name}'. "
                "Run processing with resume enabled."
            )
        processed_splits[split_name] = (
            concatenate_datasets(shards) if len(shards) > 1 else shards[0]
        )
    processed_dataset = (
        DatasetDict(processed_splits)
        if len(processed_splits) > 1
        else next(iter(processed_splits.values()))
    )
    progress(0, desc="Uploading to the Hub...")
    processed_dataset.push_to_hub(repo_id, private=private_repo, token=token)
    progress(1.0, desc="Upload complete.")
    inferred_audio_column = (
        audio_column or infer_audio_column(processed_dataset) or "audio"
    )
    return (
        f"Uploaded cleaned dataset to {repo_id} "
        f"(audio column: {inferred_audio_column})."
    )


@gpu_decorator(DEFAULT_GPU_DURATION)
def _gradio_single_file_gpu(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """GPU-decorated handler for the Single File tab."""
    if audio_file is None:
        return None, None, None, "Please upload an audio file."
    return process_audio_file(
        audio_file,
        threshold=vad_threshold,
        max_gap=max_silence_gap,
        normalize_audio=normalize_audio,
        target_dbfs=target_dbfs,
        max_boost_db=max_boost_db,
        max_atten_db=max_atten_db,
    )


def gradio_single_file(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """Thin non-GPU wrapper wired to the UI; forwards to the GPU handler."""
    return _gradio_single_file_gpu(
        audio_file,
        vad_threshold,
        max_silence_gap,
        normalize_audio,
        target_dbfs,
        max_boost_db,
        max_atten_db,
        request=request,
    )


# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Representation Chizzler") as demo:
    gr.Markdown(
        "# Representation Chizzler\n"
        "Two-stage audio processing: VAD-based speech extraction followed by MP-SENet "
        "denoising. Use the Single File tab for ad-hoc processing or the Dataset tab "
        "to clean and publish a dataset to the Hugging Face Hub."
    )
    with gr.Column():
        with gr.Tabs():
            with gr.Tab("Single File"):
                audio_input = gr.Audio(label="Upload Audio File", type="filepath")
                vad_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="VAD Threshold (higher = stricter voice detection)",
                )
                gap_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.5,
                    label="Max Silence Gap (seconds)",
                )
                normalize_checkbox = gr.Checkbox(
                    label="Normalize volume", value=True
                )
                target_db_slider = gr.Slider(
                    minimum=-35.0,
                    maximum=-10.0,
                    value=DEFAULT_TARGET_DBFS,
                    step=1.0,
                    label="Target loudness (dBFS)",
                )
                max_boost_slider = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    value=DEFAULT_MAX_BOOST_DB,
                    step=1.0,
                    label="Max boost (dB)",
                )
                max_atten_slider = gr.Slider(
                    minimum=0.0,
                    maximum=20.0,
                    value=DEFAULT_MAX_ATTEN_DB,
                    step=1.0,
                    label="Max attenuation (dB)",
                )
                run_button = gr.Button("Process Audio")
                original_audio = gr.Audio(label="Original Audio")
                vad_audio = gr.Audio(label="VAD Processed (Speech Only)")
                denoised_audio = gr.Audio(label="Final Denoised")
                details_box = gr.Textbox(label="Processing Details", lines=10)
                run_button.click(
                    fn=gradio_single_file,
                    inputs=[
                        audio_input,
                        vad_slider,
gap_slider, normalize_checkbox, target_db_slider, max_boost_slider, max_atten_slider, ], outputs=[original_audio, vad_audio, denoised_audio, details_box], concurrency_limit=1, ) with gr.Tab("Dataset to Hub"): with gr.Row(): gr.LoginButton() dataset_id_input = gr.Textbox( label="Dataset ID or URL", value="https://huggingface.co/datasets/MohammadGholizadeh/fleurs-farsi", ) config_input = gr.Textbox(label="Config (optional)", value="") split_input = gr.Textbox(label="Split (optional, or 'all')", value="dev") audio_column_input = gr.Textbox( label="Audio column (optional, auto-detect if empty)", value="" ) output_repo_input = gr.Textbox( label="Output dataset repo (optional)", value="" ) private_checkbox = gr.Checkbox(label="Create private repo", value=False) max_examples_input = gr.Number( label="Max examples per split (optional)", value=None ) resume_checkbox = gr.Checkbox( label="Resume from cached shards", value=True ) auto_resume_checkbox = gr.Checkbox( label="Auto-resume on ZeroGPU preemption", value=DEFAULT_AUTO_RESUME, ) cache_to_hub_checkbox = gr.Checkbox( label="Cache shards on Hub (recommended for ZeroGPU)", value=DEFAULT_CACHE_TO_HUB, ) shard_size_input = gr.Number( label="Shard size (examples)", value=25 ) max_shards_input = gr.Number( label="Max shards per run (ZeroGPU: 1-5, 0 = no limit)", value=DEFAULT_MAX_SHARDS_PER_RUN, ) vad_slider_ds = gr.Slider( minimum=0.1, maximum=0.9, value=0.5, step=0.1, label="VAD Threshold", ) gap_slider_ds = gr.Slider( minimum=1.0, maximum=10.0, value=4.0, step=0.5, label="Max Silence Gap (seconds)", ) normalize_checkbox_ds = gr.Checkbox( label="Normalize volume", value=True ) target_db_slider_ds = gr.Slider( minimum=-35.0, maximum=-10.0, value=DEFAULT_TARGET_DBFS, step=1.0, label="Target loudness (dBFS)", ) max_boost_slider_ds = gr.Slider( minimum=0.0, maximum=30.0, value=DEFAULT_MAX_BOOST_DB, step=1.0, label="Max boost (dB)", ) max_atten_slider_ds = gr.Slider( minimum=0.0, maximum=20.0, value=DEFAULT_MAX_ATTEN_DB, 
step=1.0, label="Max attenuation (dB)", ) process_button = gr.Button( "Process/Resume Dataset (cache & push when complete)" ) assemble_button = gr.Button( "Assemble & Push Cached Dataset" ) status_box = gr.Textbox(label="Status", lines=6) process_button.click( fn=process_dataset_and_push, inputs=[ dataset_id_input, config_input, split_input, audio_column_input, output_repo_input, private_checkbox, vad_slider_ds, gap_slider_ds, normalize_checkbox_ds, target_db_slider_ds, max_boost_slider_ds, max_atten_slider_ds, max_examples_input, resume_checkbox, auto_resume_checkbox, shard_size_input, cache_to_hub_checkbox, max_shards_input, ], outputs=[status_box], concurrency_limit=1, ) assemble_button.click( fn=assemble_cached_dataset_and_push, inputs=[ dataset_id_input, config_input, split_input, audio_column_input, output_repo_input, private_checkbox, cache_to_hub_checkbox, ], outputs=[status_box], concurrency_limit=1, ) demo.queue() if __name__ == "__main__": demo.launch()