# Hugging Face Space — running on ZeroGPU hardware.
| import io | |
| import json | |
| import math | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| import urllib.request | |
| import warnings | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| from datasets import ( | |
| Audio, | |
| Dataset, | |
| DatasetDict, | |
| IterableDataset, | |
| IterableDatasetDict, | |
| Value, | |
| concatenate_datasets, | |
| load_dataset, | |
| ) | |
| from dotenv import load_dotenv | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from rich.console import Console | |
try:
    import spaces
except Exception:  # pragma: no cover - spaces is only available on HF Spaces
    class _SpacesFallback:
        """Minimal stand-in for the `spaces` package when running off-Spaces.

        The real `spaces.GPU` works both as a bare decorator (`@spaces.GPU`)
        and as a decorator factory (`@spaces.GPU(duration=...)`).  The old
        fallback only supported the factory form, so `gpu_decorator(0)` —
        which returns `spaces.GPU` itself — would replace the decorated
        function with the inner `decorator` closure.  Support both forms.
        """

        def GPU(self, *args, **kwargs):
            # Bare-decorator form: @spaces.GPU applied directly to a function.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            # Factory form: @spaces.GPU(duration=...) — return a no-op decorator.
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
# Initialize console for pretty printing
console = Console()

# Audio pipeline defaults.
DEFAULT_SAMPLE_RATE = 16000  # Hz; all audio is resampled to this rate
AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")  # used to sniff audio columns
DEFAULT_TARGET_DBFS = -20.0   # RMS loudness target for normalization
DEFAULT_MAX_BOOST_DB = 20.0   # cap on upward gain applied by normalization
DEFAULT_MAX_ATTEN_DB = 10.0   # cap on downward gain applied by normalization

# SPACE_ID is set by the HF Spaces runtime; resume mode defaults to on there.
DEFAULT_AUTO_RESUME = bool(os.getenv("SPACE_ID"))
# ZeroGPU-safe caps (overridable via env) used by _apply_zero_gpu_limits().
DEFAULT_ZERO_GPU_SHARD_SIZE = int(
    os.getenv("CHIZZLER_ZERO_GPU_SHARD_SIZE", "25")
)
DEFAULT_ZERO_GPU_MAX_SHARDS = int(
    os.getenv("CHIZZLER_ZERO_GPU_MAX_SHARDS", "1")
)
SPACE_ID = os.getenv("SPACE_ID")

# Silence a benign Gradio warning about LoginButton creation order.
warnings.filterwarnings(
    "ignore",
    message=(
        "LoginButton created outside of a Blocks context\\. "
        "May not work unless you call its `activate\\(\\)` method manually\\."
    ),
)
def log_progress(message: str, level: int = 1, enabled: bool = True) -> None:
    """Print a timestamped, level-indented progress line via the rich console.

    `level` 1 means no indent; each extra level adds two spaces. Setting
    `enabled=False` makes the call a no-op.
    """
    if enabled:
        stamp = datetime.now().strftime("%H:%M:%S")
        pad = " " * (level - 1)
        console.print(f"[dim]{stamp}[/dim] {pad}[bold blue]>[/bold blue] {message}")
        # Flush so progress shows immediately in Space logs.
        sys.stdout.flush()
# Load environment variables from a local .env file, if present.
load_dotenv()

_GPU_DURATION = os.getenv("CHIZZLER_GPU_DURATION")
_GPU_DURATION_MAX = os.getenv("CHIZZLER_GPU_DURATION_MAX")

# Requested ZeroGPU duration in seconds; 0 means "use the spaces default".
try:
    DEFAULT_GPU_DURATION = int(_GPU_DURATION) if _GPU_DURATION is not None else 0
except ValueError:
    DEFAULT_GPU_DURATION = 0

# Optional hard ceiling on the requested duration; ignored when unparsable.
if _GPU_DURATION_MAX is not None:
    try:
        _duration_cap = int(_GPU_DURATION_MAX)
    except ValueError:
        pass
    else:
        if _duration_cap > 0:
            DEFAULT_GPU_DURATION = min(DEFAULT_GPU_DURATION, _duration_cap)
def gpu_decorator(duration: int):
    """Return a `spaces.GPU` decorator, requesting `duration` seconds if positive.

    Falls back to the bare `spaces.GPU` when no duration is requested or when
    the installed `spaces` version does not accept a `duration` keyword.
    """
    if not duration or duration <= 0:
        return spaces.GPU
    try:
        return spaces.GPU(duration=duration)
    except TypeError:
        # Older spaces versions have no duration parameter.
        return spaces.GPU
def get_hf_token() -> Optional[str]:
    """Return the Hugging Face token from the first matching env var.

    Checks HF_TOKEN, then HUGGINGFACEHUB_API_TOKEN, then HF_API_TOKEN;
    returns the last lookup's value (possibly None) when none is truthy.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        token = os.getenv("HF_API_TOKEN")
    return token
def normalize_dataset_id(value: str) -> str:
    """Normalize a dataset reference to `owner/name` form.

    Accepts bare IDs ("user/ds"), hub URLs ("https://huggingface.co/datasets/
    user/ds?x=1"), and strips query strings, fragments, and extra path parts.
    Returns "" for empty input.
    """
    if not value:
        return ""
    candidate = value.strip()
    if candidate.startswith("http"):
        # Prefer the datasets/ marker; fall back to the bare hub host.
        for marker in ("datasets/", "huggingface.co/"):
            if marker in candidate:
                candidate = candidate.split(marker, 1)[1]
                break
    candidate = candidate.split("?")[0].split("#")[0].strip("/")
    parts = [part for part in candidate.split("/") if part]
    return "/".join(parts[:2]) if len(parts) >= 2 else candidate
# Filesystem layout and cache configuration, all overridable via env vars.
CURRENT_DIR = Path(__file__).parent.resolve()
DEFAULT_MP_SENET_DIR = Path(os.getenv("MPSENET_DIR", CURRENT_DIR / "MP-SENet"))
MPSENET_GIT_REPO = os.getenv(
    "MPSENET_GIT_REPO", "https://github.com/yxlu-0102/MP-SENet.git"
)
CACHE_DIR = Path(os.getenv("CHIZZLER_CACHE_DIR", CURRENT_DIR / "chizzler_cache"))

# Max shards to process per run: explicit env value wins; otherwise default
# to 1 on Spaces (short ZeroGPU sessions) and unlimited (0) locally.
_ENV_MAX_SHARDS = os.getenv("CHIZZLER_MAX_SHARDS_PER_RUN")
if _ENV_MAX_SHARDS is not None:
    DEFAULT_MAX_SHARDS_PER_RUN = int(_ENV_MAX_SHARDS)
else:
    DEFAULT_MAX_SHARDS_PER_RUN = 1 if os.getenv("SPACE_ID") else 0

# Whether to mirror shard caches to the Hub; defaults to on when on Spaces
# (local disk is ephemeral there).
_ENV_CACHE_TO_HUB = os.getenv("CHIZZLER_CACHE_TO_HUB")
if _ENV_CACHE_TO_HUB is None:
    DEFAULT_CACHE_TO_HUB = bool(os.getenv("SPACE_ID"))
else:
    DEFAULT_CACHE_TO_HUB = _ENV_CACHE_TO_HUB.strip().lower() in ("1", "true", "yes")
def ensure_mpsenet_repo() -> Path:
    """Return the local MP-SENet checkout path, cloning it on demand.

    Cloning happens only when MPSENET_AUTO_DOWNLOAD=1 or when running on
    HF Spaces (SPACE_ID set); otherwise a missing checkout raises.

    Raises:
        RuntimeError: the git clone failed.
        FileNotFoundError: repo missing and auto-download is disabled.
    """
    if DEFAULT_MP_SENET_DIR.exists():
        return DEFAULT_MP_SENET_DIR
    # SPACE_ID is truthy on HF Spaces, so auto-download is implied there.
    auto_download = os.getenv("MPSENET_AUTO_DOWNLOAD") == "1" or os.getenv("SPACE_ID")
    if auto_download:
        log_progress("MP-SENet repo not found. Cloning...", 2)
        try:
            # Shallow clone: only the latest commit is needed for inference.
            subprocess.run(
                [
                    "git",
                    "clone",
                    "--depth=1",
                    MPSENET_GIT_REPO,
                    str(DEFAULT_MP_SENET_DIR),
                ],
                check=True,
            )
        except Exception as exc:
            raise RuntimeError(
                "Failed to clone MP-SENet. Clone it manually or set MPSENET_DIR."
            ) from exc
        return DEFAULT_MP_SENET_DIR
    raise FileNotFoundError(
        "MP-SENet repo not found. Clone it into MP-SENet/ or set MPSENET_DIR."
    )
def resolve_mpsenet_files(mp_senet_dir: Path) -> Tuple[Path, Path]:
    """Locate the MP-SENet config and checkpoint files.

    Prefers files bundled in the local checkout (best_ckpt/); falls back to
    downloading them from the Hub repo named by the MPSENET_REPO env var.

    Returns:
        (config_path, checkpoint_path)

    Raises:
        FileNotFoundError: neither local files nor MPSENET_REPO are available.
    """
    config_path = mp_senet_dir / "best_ckpt" / "config.json"
    ckpt_path = mp_senet_dir / "best_ckpt" / "g_best_dns"
    if config_path.exists() and ckpt_path.exists():
        return config_path, ckpt_path
    repo_id = os.getenv("MPSENET_REPO")
    if repo_id:
        # Filenames inside the Hub repo are also overridable via env.
        config_filename = os.getenv("MPSENET_CONFIG_FILENAME", "config.json")
        ckpt_filename = os.getenv("MPSENET_CKPT_FILENAME", "g_best_dns")
        config_path = Path(
            hf_hub_download(repo_id=repo_id, filename=config_filename)
        )
        ckpt_path = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_filename))
        return config_path, ckpt_path
    raise FileNotFoundError(
        "MP-SENet checkpoint files not found. Place best_ckpt/config.json and "
        "best_ckpt/g_best_dns under MP-SENet/ or set MPSENET_REPO."
    )
# Make the MP-SENet checkout importable, then import its modules.
# NOTE(review): `dataset`, `env`, and `models.model` resolve inside the
# cloned MP-SENet repo, not this project — keep this sys.path hack intact.
mp_senet_dir = ensure_mpsenet_repo()
sys.path.append(str(mp_senet_dir))
from dataset import mag_pha_istft, mag_pha_stft  # noqa: E402
from env import AttrDict  # noqa: E402
from models.model import MPNet  # noqa: E402
def select_device() -> torch.device:
    """Pick the compute device: env override, else CUDA, else MPS, else CPU."""
    forced = os.getenv("CHIZZLER_DEVICE", "").strip().lower()
    if forced:
        return torch.device(forced)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple-silicon fallback; guard getattr for torch builds without MPS.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def initialize_models(device_override: Optional[torch.device] = None):
    """Load the Silero VAD and MP-SENet models onto a single device.

    Args:
        device_override: device to use; defaults to select_device().

    Returns:
        (vad_model, vad_utils, mpnet_model, mpnet_config, device)
    """
    log_progress("Initializing models...")
    device = device_override or select_device()
    log_progress(f"Using {device.type.upper()} for all operations", 2)
    log_progress("Loading Silero VAD model...", 2)
    model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,  # reuse the torch.hub cache between runs
        trust_repo=True,
    )
    vad_model = model.to(device)
    log_progress("Loading MP-SENet model...", 2)
    config_path, ckpt_path = resolve_mpsenet_files(mp_senet_dir)
    with open(config_path, "r") as f:
        config = AttrDict(json.loads(f.read()))
    mpnet_model = MPNet(config).to(device)
    state = torch.load(ckpt_path, map_location=device)
    # Checkpoints may nest the weights under "generator" or "state_dict".
    if isinstance(state, dict):
        if "generator" in state:
            state = state["generator"]
        elif "state_dict" in state:
            state = state["state_dict"]
    mpnet_model.load_state_dict(state)
    mpnet_model.eval()  # inference only; disables dropout/batchnorm updates
    return vad_model, utils, mpnet_model, config, device
# Lazily-initialized global model state; populated/updated by get_models().
vad_model = None     # Silero VAD model
vad_utils = None     # tuple of Silero helper functions from torch.hub
mpnet_model = None   # MP-SENet denoiser
mpnet_config = None  # AttrDict parsed from MP-SENet config.json
device = None        # torch.device the models currently live on
def get_models():
    """Return the shared models, loading them on first use.

    Also migrates already-loaded models to the currently selected device
    (select_device() can change between calls, e.g. CHIZZLER_DEVICE edits).

    Returns:
        (vad_model, vad_utils, mpnet_model, mpnet_config, device)
    """
    global vad_model, vad_utils, mpnet_model, mpnet_config, device
    desired_device = select_device()
    if vad_model is None or mpnet_model is None or mpnet_config is None:
        vad_model, vad_utils, mpnet_model, mpnet_config, device = (
            initialize_models(desired_device)
        )
        return vad_model, vad_utils, mpnet_model, mpnet_config, device
    # Compare as strings so "cuda" vs "cuda:0" style mismatches still trigger
    # only when the textual device spec actually differs.
    if device is None or str(device) != str(desired_device):
        log_progress(f"Moving models to {desired_device}...", 2)
        vad_model = vad_model.to(desired_device)
        mpnet_model = mpnet_model.to(desired_device)
        device = desired_device
    return vad_model, vad_utils, mpnet_model, mpnet_config, device
def ensure_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Coerce a waveform to channels-first mono shape (1, samples).

    A 2-D input with more rows than columns is assumed to be in
    (samples, channels) layout and is transposed first; multiple channels
    are averaged into one.
    """
    if waveform.dim() == 1:
        return waveform.unsqueeze(0)
    # Heuristic: more rows than columns means (samples, channels) layout.
    if waveform.dim() == 2 and waveform.size(0) > waveform.size(1):
        waveform = waveform.t()
    if waveform.size(0) > 1:
        return waveform.mean(dim=0, keepdim=True)
    return waveform
def resample_waveform(
    waveform: torch.Tensor, sample_rate: int, target_rate: int = DEFAULT_SAMPLE_RATE
) -> Tuple[torch.Tensor, int]:
    """Resample to `target_rate`; return the input untouched if already there."""
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(sample_rate, target_rate)(waveform)
        sample_rate = target_rate
    return waveform, sample_rate
def load_audio_file(file_path: str, log: bool = True) -> Tuple[torch.Tensor, int]:
    """Load an audio file as a mono (1, samples) tensor at DEFAULT_SAMPLE_RATE.

    Decoding backends are tried in order — torchaudio, soundfile, librosa —
    so that formats one backend cannot read may still decode via another.
    The last (librosa) attempt is unguarded: its exception propagates.
    """
    log_progress(f"Loading audio: {Path(file_path).name}", enabled=log)
    waveform = None
    sample_rate = None
    try:
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = ensure_mono(waveform)
    except Exception as exc:
        log_progress(f"torchaudio load failed: {exc}", 2, enabled=log)
    if waveform is None or sample_rate is None:
        try:
            # soundfile returns (frames, channels); transpose to
            # (channels, frames) before tensor conversion.
            data, sample_rate = sf.read(
                file_path, always_2d=True, dtype="float32"
            )
            waveform = torch.from_numpy(data.T)
        except Exception as exc:
            log_progress(f"soundfile load failed: {exc}", 2, enabled=log)
            # Last resort: librosa, keeping the native rate and channels.
            data, sample_rate = librosa.load(
                file_path, sr=None, mono=False, dtype=np.float32
            )
            if data.ndim == 1:
                data = data[None, :]
            waveform = torch.from_numpy(data)
    # Safe to re-apply: ensure_mono is a no-op on already-mono input.
    waveform = ensure_mono(waveform)
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate
def get_speech_timestamps(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    log: bool = True,
) -> List[dict]:
    """Run Silero VAD and return speech segments with second-based boundaries.

    `sample_rate` is accepted for interface symmetry but not forwarded; the
    Silero helper uses its own default.
    """
    log_progress("Detecting speech segments...", enabled=log)
    model, utils, _, _, _ = get_models()
    # The first entry of the Silero utils tuple is the timestamp extractor.
    detect = utils[0]
    segments = detect(
        waveform,
        model,
        threshold=threshold,
        return_seconds=True,
    )
    log_progress(f"Found {len(segments)} speech segments", 2, enabled=log)
    return segments
def merge_close_segments(segments: List[dict], max_gap: float = 4.0) -> List[dict]:
    """Merge speech segments separated by at most `max_gap` seconds.

    Args:
        segments: dicts with "start"/"end" keys in seconds, in start order.
        max_gap: maximum silence (seconds) to bridge between segments.

    Returns:
        New list of merged segment dicts; the input list is not mutated.
    """
    if not segments:
        return segments
    merged: List[dict] = []
    current = segments[0].copy()
    for segment in segments[1:]:
        if segment["start"] - current["end"] <= max_gap:
            # Extend the open segment. max() guards against a later segment
            # that ends earlier (overlap/containment) shrinking the merged
            # end and silently dropping audio.
            current["end"] = max(current["end"], segment["end"])
        else:
            merged.append(current)
            current = segment.copy()
    merged.append(current)
    return merged
def extract_speech_waveform(
    waveform: torch.Tensor, sample_rate: int, segments: List[dict]
) -> Optional[torch.Tensor]:
    """Concatenate the sample ranges covered by `segments`.

    Segment boundaries (seconds) are converted to sample indices and clipped
    to the waveform length. Returns None when no non-empty slice remains.
    """
    if not segments:
        return None
    n_samples = waveform.size(1)
    pieces = []
    for segment in segments:
        lo = max(0, int(segment["start"] * sample_rate))
        hi = min(n_samples, int(segment["end"] * sample_rate))
        if hi > lo:
            pieces.append(waveform[:, lo:hi])
    return torch.cat(pieces, dim=1) if pieces else None
def denoise_audio_chunk(
    audio_tensor: torch.Tensor,
    mpnet_model: torch.nn.Module,
    mpnet_config: AttrDict,
    chunk_size: int = 5 * DEFAULT_SAMPLE_RATE,
) -> torch.Tensor:
    """Denoise audio with MP-SENet in fixed-size chunks to bound memory use.

    Each chunk is RMS-normalized before the model (and un-normalized after),
    matching MP-SENet's expected input scale. Chunks default to 5 seconds
    at DEFAULT_SAMPLE_RATE.
    """
    chunks = []
    for i in range(0, audio_tensor.size(1), chunk_size):
        chunk = audio_tensor[:, i : min(i + chunk_size, audio_tensor.size(1))]
        # Per-chunk RMS normalization; epsilon avoids divide-by-zero on silence.
        energy = torch.sum(chunk**2.0, dim=1)
        norm_factor = torch.sqrt(chunk.size(1) / (energy + 1e-8))
        chunk = chunk * norm_factor.unsqueeze(1)
        with torch.no_grad():
            noisy_amp, noisy_pha, _ = mag_pha_stft(
                chunk,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
            amp_g, pha_g, _ = mpnet_model(noisy_amp, noisy_pha)
            audio_g = mag_pha_istft(
                amp_g,
                pha_g,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
            # Undo the per-chunk normalization applied above.
            audio_g = audio_g / norm_factor.unsqueeze(1)
        chunks.append(audio_g)
        # Free intermediates eagerly to keep peak (GPU) memory low.
        del chunk, noisy_amp, noisy_pha, amp_g, pha_g
    return torch.cat(chunks, dim=1)
def process_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    log: bool = True,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str, bool]:
    """Run the two-stage pipeline (VAD, then MP-SENet denoise) in memory.

    Returns:
        (vad_waveform, denoised_waveform, details_text, has_speech). Both
        tensors are None and has_speech is False when no speech survives
        detection/merging.
    """
    vad_model, vad_utils, mpnet_model, mpnet_config, device = get_models()
    if waveform.device != device:
        waveform = waveform.to(device)
    log_progress("Stage 1: Voice Activity Detection", 2, enabled=log)
    speech_timestamps = get_speech_timestamps(
        waveform, sample_rate, threshold=threshold, log=log
    )
    merged_timestamps = merge_close_segments(speech_timestamps, max_gap)
    details = ["Processing details:"]
    if not merged_timestamps:
        details.append("No speech detected in the audio.")
        return None, None, "\n".join(details), False
    for i, segment in enumerate(merged_timestamps, 1):
        duration = segment["end"] - segment["start"]
        details.append(
            f"Segment {i}/{len(merged_timestamps)}: "
            f"{segment['start']:.1f}s to {segment['end']:.1f}s "
            f"(duration: {duration:.1f}s)"
        )
    vad_waveform = extract_speech_waveform(
        waveform, sample_rate, merged_timestamps
    )
    if vad_waveform is None or vad_waveform.numel() == 0:
        details.append("No speech detected after merging segments.")
        return None, None, "\n".join(details), False
    vad_duration = vad_waveform.size(1) / sample_rate
    original_duration = waveform.size(1) / sample_rate
    # Guard against zero-length input when reporting the reduction ratio.
    if original_duration > 0:
        reduction = (1 - vad_duration / original_duration) * 100
    else:
        reduction = 0.0
    details.append(f"VAD output duration: {vad_duration:.1f}s")
    details.append(f"Reduced by: {reduction:.1f}%")
    log_progress("Stage 2: MP-SENet denoising", 2, enabled=log)
    with torch.no_grad():
        denoised_waveform = denoise_audio_chunk(
            vad_waveform, mpnet_model, mpnet_config
        )
    return vad_waveform, denoised_waveform, "\n".join(details), True
def process_audio_file(
    audio_path: str,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    normalize_audio: bool = True,
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Tuple[str, str, str, str]:
    """Run the full pipeline on one file, writing stage outputs to temp WAVs.

    Returns:
        (original_path, vad_path, denoised_path, details). When no speech is
        detected, all three paths are the original input path.

    Note: the temp files are created with delete=False and are not cleaned
    up here — Gradio serves them to the user after this function returns.
    """
    log_progress(f"Processing: {Path(audio_path).name}")
    waveform, sample_rate = load_audio_file(audio_path)
    # Only the MP-SENet config is needed here (for its output sampling rate).
    _, _, _, mpnet_config, _ = get_models()
    vad_waveform, denoised_waveform, details, has_speech = process_waveform(
        waveform, sample_rate, threshold=threshold, max_gap=max_gap, log=True
    )
    if not has_speech or vad_waveform is None or denoised_waveform is None:
        return audio_path, audio_path, audio_path, details
    if normalize_audio:
        vad_waveform = normalize_waveform(
            vad_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
        denoised_waveform = normalize_waveform(
            denoised_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
    # Reserve two temp paths; the writes below happen after the handles close.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vad_file, \
            tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as denoised_file:
        vad_path = vad_file.name
        denoised_path = denoised_file.name
    sf.write(
        vad_path,
        vad_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    sf.write(
        denoised_path,
        denoised_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    return audio_path, vad_path, denoised_path, details
def load_audio_bytes(audio_bytes: bytes, log: bool = False) -> Tuple[torch.Tensor, int]:
    """Decode in-memory audio bytes to a mono tensor at DEFAULT_SAMPLE_RATE."""
    decoded, sample_rate = sf.read(
        io.BytesIO(audio_bytes), always_2d=True, dtype="float32"
    )
    # soundfile yields (frames, channels); transpose to (channels, frames).
    waveform = ensure_mono(torch.from_numpy(decoded.T))
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate
def normalize_waveform(
    waveform: Optional[torch.Tensor],
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Optional[torch.Tensor]:
    """RMS-normalize toward `target_dbfs` with bounded gain, clamped to [-1, 1].

    Returns the input untouched when it is None/empty, or when its RMS is
    non-finite or effectively silent (gain would be meaningless or huge).
    """
    if waveform is None or waveform.numel() == 0:
        return waveform
    rms = torch.sqrt(torch.mean(waveform**2))
    if not torch.isfinite(rms) or rms <= 1e-8:
        return waveform
    # Gain in dB, limited to [-max_atten_db, +max_boost_db].
    gain_db = torch.clamp(
        target_dbfs - 20.0 * torch.log10(rms), -max_atten_db, max_boost_db
    )
    gain = torch.pow(torch.tensor(10.0, device=waveform.device), gain_db / 20.0)
    return torch.clamp(waveform * gain, -1.0, 1.0)
| def _is_http_url(value: str) -> bool: | |
| return value.startswith("http://") or value.startswith("https://") | |
| def _parse_hf_dataset_uri(uri: str) -> Optional[Tuple[str, str, Optional[str]]]: | |
| prefix = "hf://datasets/" | |
| if not uri.startswith(prefix): | |
| return None | |
| rest = uri[len(prefix) :] | |
| if "/" not in rest: | |
| return None | |
| repo_id, file_path = rest.split("/", 1) | |
| revision = None | |
| if "@" in repo_id: | |
| repo_id, revision = repo_id.split("@", 1) | |
| return repo_id, file_path, revision | |
def load_audio_url(url: str, token: Optional[str], log: bool = False) -> Tuple[torch.Tensor, int]:
    """Fetch audio bytes over HTTP and decode them to a mono tensor.

    Attaches a Bearer token only for huggingface.co URLs so the credential
    never leaks to third-party hosts.
    """
    headers = (
        {"Authorization": f"Bearer {token}"}
        if token and "huggingface.co" in url
        else {}
    )
    request = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(request) as response:
        payload = response.read()
    return load_audio_bytes(payload, log=log)
def resolve_audio_path(
    path: str, dataset_id: Optional[str], token: Optional[str]
) -> str:
    """Best-effort resolution of an audio reference to a local file path.

    Resolution order: existing local path -> explicit hf://datasets/ URI ->
    relative path inside `dataset_id` on the Hub. On any download failure
    the original `path` is returned unchanged (the caller may still treat
    it as a URL).
    """
    if os.path.exists(path):
        return path
    parsed = _parse_hf_dataset_uri(path)
    if parsed:
        repo_id, filename, revision = parsed
        try:
            return hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=filename,
                revision=revision,
                token=token,
            )
        except Exception:
            # Best-effort: hand the unresolved reference back to the caller.
            return path
    if dataset_id and not os.path.isabs(path):
        try:
            return hf_hub_download(
                repo_id=dataset_id,
                repo_type="dataset",
                filename=path,
                token=token,
            )
        except Exception:
            return path
    return path
def prepare_waveform_from_entry(
    entry,
    log: bool = False,
    dataset_id: Optional[str] = None,
    token: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Decode a datasets audio entry into a mono tensor at DEFAULT_SAMPLE_RATE.

    Handles, in order: decoder objects exposing get_all_samples() (newer
    `datasets` audio backends — presumably torchcodec; verify against the
    installed version), dicts with "array"/"bytes"/"path", and bare path
    strings (resolved via the Hub when not local).

    Raises:
        ValueError: entry is None or in an unrecognized format.
    """
    if entry is None:
        raise ValueError("Empty audio entry.")
    if hasattr(entry, "get_all_samples"):
        samples = entry.get_all_samples()
        waveform = samples.data
        sample_rate = samples.sample_rate
        waveform = ensure_mono(waveform)
        if sample_rate != DEFAULT_SAMPLE_RATE:
            waveform, sample_rate = resample_waveform(
                waveform, sample_rate, DEFAULT_SAMPLE_RATE
            )
        return waveform, sample_rate
    if isinstance(entry, dict):
        # Decoded form: raw samples plus their sampling rate.
        if entry.get("array") is not None:
            sample_rate = entry.get("sampling_rate", DEFAULT_SAMPLE_RATE)
            waveform = torch.tensor(entry["array"], dtype=torch.float32)
            waveform = ensure_mono(waveform)
            if sample_rate != DEFAULT_SAMPLE_RATE:
                waveform, sample_rate = resample_waveform(
                    waveform, sample_rate, DEFAULT_SAMPLE_RATE
                )
            return waveform, sample_rate
        # Encoded form: container bytes (wav/mp3/flac/...) to be decoded.
        if entry.get("bytes"):
            audio_bytes = entry["bytes"]
            if not isinstance(audio_bytes, (bytes, bytearray)):
                audio_bytes = bytes(audio_bytes)
            return load_audio_bytes(audio_bytes, log=log)
        # Reference form: a path that may be local, hf://, or an HTTP URL.
        if entry.get("path"):
            path = resolve_audio_path(entry["path"], dataset_id, token)
            if _is_http_url(path):
                return load_audio_url(path, token, log=log)
            return load_audio_file(path, log=log)
    if isinstance(entry, str):
        path = resolve_audio_path(entry, dataset_id, token)
        if _is_http_url(path):
            return load_audio_url(path, token, log=log)
        return load_audio_file(path, log=log)
    raise ValueError("Unsupported audio entry format.")
def get_dataset_cache_dir(dataset_id: str, config: Optional[str]) -> Path:
    """Return the local cache directory for a dataset/config pair.

    Delegates slug construction to get_cache_slug() so the local directory
    name and the Hub cache prefix (get_hub_cache_prefix) can never drift
    apart — previously the "/" -> "__" logic was duplicated here.
    """
    return CACHE_DIR / get_cache_slug(dataset_id, config)
def get_cache_slug(dataset_id: str, config: Optional[str]) -> str:
    """Build a filesystem-safe slug for a dataset/config pair.

    Slashes become double underscores; a truthy config is appended with
    the same separator.
    """
    base = dataset_id.replace("/", "__")
    return f"{base}__{config}" if config else base
def get_hub_cache_prefix(
    dataset_id: str, config: Optional[str], split_name: str
) -> str:
    """Path prefix under which a split's shards are cached in the Hub repo."""
    slug = get_cache_slug(dataset_id, config)
    return "/".join(["chizzler_cache", slug, split_name])
def load_split_meta(
    split_cache_dir: Path,
    hub_cache_prefix: str,
    cache_on_hub: bool,
    repo_id: str,
    token: Optional[str],
) -> Optional[dict]:
    """Load a split's cached _meta.json, preferring local over Hub copies.

    Returns:
        Parsed metadata dict, or None when no copy exists (a Hub download
        failure is treated as "not cached").
    """
    meta_file = split_cache_dir / "_meta.json"
    if meta_file.exists():
        return json.loads(meta_file.read_text())
    if cache_on_hub:
        try:
            meta_path = hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=f"{hub_cache_prefix}/_meta.json",
                token=token,
            )
            return json.loads(Path(meta_path).read_text())
        except Exception:
            # Missing on the Hub (or network error): behave as uncached.
            return None
    return None
def infer_audio_column(dataset_obj) -> Optional[str]:
    """Guess which column of a dataset (or dataset dict) holds audio.

    First looks for an `Audio` feature; failing that, inspects the first row
    of a materialized Dataset for audio-shaped dicts or audio-file paths.

    Returns:
        The column name, or None when nothing audio-like is found.
    """
    sample_ds = dataset_obj
    if isinstance(dataset_obj, (DatasetDict, IterableDatasetDict)):
        # Any split will do: all splits share the same features.
        sample_ds = next(iter(dataset_obj.values()))
    if hasattr(sample_ds, "features"):
        for column, feature in sample_ds.features.items():
            if isinstance(feature, Audio):
                return column
    # Fallback: sniff the first example's values (non-streaming only).
    if isinstance(sample_ds, Dataset) and len(sample_ds) > 0:
        sample = sample_ds[0]
        for column, value in sample.items():
            if isinstance(value, dict) and (
                "array" in value or "path" in value or "bytes" in value
            ):
                return column
            if isinstance(value, str) and value.lower().endswith(AUDIO_EXTENSIONS):
                return column
    return None
def default_output_repo(source_id: str, username: str) -> str:
    """Derive the default output repo id from the source dataset name.

    Appends the "-representation-chizzler" suffix unless already present and
    places the repo under `username`.
    """
    suffix = "representation-chizzler"
    base = source_id.rsplit("/", 1)[-1]
    if not base.endswith(suffix):
        base = f"{base}-{suffix}"
    return f"{username}/{base}"
| def _apply_zero_gpu_limits( | |
| shard_size: int, max_shards: Optional[int] | |
| ) -> Tuple[int, Optional[int]]: | |
| if not os.getenv("SPACE_ID"): | |
| return shard_size, max_shards | |
| adjusted_shard_size = min(shard_size, DEFAULT_ZERO_GPU_SHARD_SIZE) | |
| if max_shards is None: | |
| adjusted_max_shards = DEFAULT_ZERO_GPU_MAX_SHARDS | |
| else: | |
| adjusted_max_shards = min(max_shards, DEFAULT_ZERO_GPU_MAX_SHARDS) | |
| return adjusted_shard_size, adjusted_max_shards | |
| def _process_dataset_and_push_gpu( | |
| dataset_id: str, | |
| config: str, | |
| split: str, | |
| audio_column: str, | |
| output_repo: str, | |
| private_repo: bool, | |
| vad_threshold: float, | |
| max_silence_gap: float, | |
| normalize_audio: bool, | |
| target_dbfs: float, | |
| max_boost_db: float, | |
| max_atten_db: float, | |
| max_examples: Optional[float], | |
| resume_processing: bool, | |
| shard_size: Optional[float], | |
| cache_on_hub: bool, | |
| max_shards_per_run: Optional[float], | |
| request: gr.Request | None = None, | |
| progress=gr.Progress(), | |
| ) -> str: | |
| token = get_hf_token() | |
| if not token: | |
| return "Missing HF token. Set HF_TOKEN as a secret or env var." | |
| dataset_id = normalize_dataset_id(dataset_id) | |
| if not dataset_id: | |
| return "Provide a dataset ID or URL." | |
| # Ensure models are loaded on the correct device before heavy processing. | |
| get_models() | |
| config = config.strip() or None | |
| split = split.strip() | |
| audio_column = audio_column.strip() | |
| output_repo = normalize_dataset_id(output_repo) if output_repo else "" | |
| cache_on_hub = bool(cache_on_hub) | |
| normalize_audio = bool(normalize_audio) | |
| max_examples_int = int(max_examples) if max_examples and max_examples > 0 else None | |
| shard_size_int = int(shard_size) if shard_size and shard_size > 0 else 1000 | |
| max_shards_int = ( | |
| int(max_shards_per_run) | |
| if max_shards_per_run and max_shards_per_run > 0 | |
| else None | |
| ) | |
| if os.getenv("SPACE_ID"): | |
| adjusted_shard_size, adjusted_max_shards = _apply_zero_gpu_limits( | |
| shard_size_int, max_shards_int | |
| ) | |
| if adjusted_shard_size != shard_size_int: | |
| log_progress( | |
| f"ZeroGPU safe mode: shard size capped at {adjusted_shard_size}", | |
| 2, | |
| ) | |
| shard_size_int = adjusted_shard_size | |
| if adjusted_max_shards != max_shards_int: | |
| log_progress( | |
| f"ZeroGPU safe mode: max shards per run capped at {adjusted_max_shards}", | |
| 2, | |
| ) | |
| max_shards_int = adjusted_max_shards | |
| api = HfApi(token=token) | |
| username = api.whoami()["name"] | |
| repo_id = output_repo or default_output_repo(dataset_id, username) | |
| if cache_on_hub: | |
| api.create_repo( | |
| repo_id, repo_type="dataset", private=private_repo, exist_ok=True | |
| ) | |
| log_progress( | |
| f"Caching shards to Hub repo: {repo_id}", 2 | |
| ) | |
| log_progress(f"Loading dataset: {dataset_id}") | |
| progress(0, desc="Downloading dataset...") | |
| if split and split.lower() != "all": | |
| dataset_obj = load_dataset( | |
| dataset_id, name=config, split=split, token=token | |
| ) | |
| dataset_dict = DatasetDict({split: dataset_obj}) | |
| else: | |
| dataset_obj = load_dataset(dataset_id, name=config, token=token) | |
| dataset_dict = ( | |
| DatasetDict({"train": dataset_obj}) | |
| if isinstance(dataset_obj, Dataset) | |
| else dataset_obj | |
| ) | |
| progress(0.01, desc="Preparing splits...") | |
| if not audio_column: | |
| audio_column = infer_audio_column(dataset_dict) or "" | |
| if not audio_column: | |
| return ( | |
| "Could not infer audio column. Please specify the audio column " | |
| "name manually." | |
| ) | |
| processed_splits = {} | |
| shards_processed = 0 | |
| cached_shards = 0 | |
| total_shards = 0 | |
| incomplete = False | |
| repo_files = set() | |
| if resume_processing and cache_on_hub: | |
| try: | |
| repo_files = set( | |
| api.list_repo_files(repo_id, repo_type="dataset") | |
| ) | |
| except Exception: | |
| repo_files = set() | |
| cache_root = get_dataset_cache_dir(dataset_id, config) | |
| cache_root.mkdir(parents=True, exist_ok=True) | |
| for split_name, split_ds in dataset_dict.items(): | |
| if ( | |
| hasattr(split_ds, "column_names") | |
| and audio_column not in split_ds.column_names | |
| ): | |
| return f"Audio column '{audio_column}' not found in split '{split_name}'." | |
| try: | |
| split_ds = split_ds.cast_column( | |
| audio_column, Audio(sampling_rate=DEFAULT_SAMPLE_RATE) | |
| ) | |
| except Exception: | |
| split_ds = split_ds.cast_column(audio_column, Audio()) | |
| total = len(split_ds) if isinstance(split_ds, Dataset) else None | |
| if max_examples_int and total is not None: | |
| total = min(total, max_examples_int) | |
| update_every = max(1, (total or max_examples_int or 100) // 100) | |
| split_cache_dir = cache_root / split_name | |
| if not resume_processing and split_cache_dir.exists(): | |
| shutil.rmtree(split_cache_dir) | |
| split_cache_dir.mkdir(parents=True, exist_ok=True) | |
| hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name) | |
| features = split_ds.features.copy() | |
| features[audio_column] = Audio( | |
| sampling_rate=DEFAULT_SAMPLE_RATE | |
| ) | |
| features["chizzler_ok"] = Value("bool") | |
| features["chizzler_error"] = Value("string") | |
| def make_map_fn(offset: int = 0): | |
| def map_fn(example, idx): | |
| entry = example.get(audio_column) | |
| ok = True | |
| error_message = "" | |
| try: | |
| waveform, sample_rate = prepare_waveform_from_entry( | |
| entry, log=False, dataset_id=dataset_id, token=token | |
| ) | |
| vad_waveform, denoised_waveform, _, has_speech = process_waveform( | |
| waveform, | |
| sample_rate, | |
| threshold=vad_threshold, | |
| max_gap=max_silence_gap, | |
| log=False, | |
| ) | |
| output_waveform = ( | |
| denoised_waveform | |
| if has_speech and denoised_waveform is not None | |
| else waveform | |
| ) | |
| if normalize_audio: | |
| output_waveform = normalize_waveform( | |
| output_waveform, | |
| target_dbfs=target_dbfs, | |
| max_boost_db=max_boost_db, | |
| max_atten_db=max_atten_db, | |
| ) | |
| output_np = ( | |
| output_waveform.squeeze() | |
| .detach() | |
| .cpu() | |
| .numpy() | |
| .astype(np.float32) | |
| ) | |
| if output_np.size == 0: | |
| ok = False | |
| error_message = ( | |
| "Empty output waveform; using original audio." | |
| ) | |
| output_np = ( | |
| waveform.squeeze() | |
| .detach() | |
| .cpu() | |
| .numpy() | |
| .astype(np.float32) | |
| ) | |
| output_entry = { | |
| "array": output_np, | |
| "sampling_rate": DEFAULT_SAMPLE_RATE, | |
| } | |
| except Exception as exc: | |
| ok = False | |
| error_message = str(exc) | |
| output_entry = entry if entry is not None else { | |
| "array": np.zeros(1, dtype=np.float32), | |
| "sampling_rate": DEFAULT_SAMPLE_RATE, | |
| } | |
| example[audio_column] = output_entry | |
| example["chizzler_ok"] = ok | |
| example["chizzler_error"] = error_message | |
| global_idx = offset + idx + 1 | |
| if total: | |
| if global_idx % update_every == 0 or global_idx == total: | |
| progress( | |
| global_idx / total, | |
| desc=( | |
| f"Processing {split_name}: {global_idx}/{total}" | |
| ), | |
| ) | |
| else: | |
| if global_idx % update_every == 0: | |
| progress( | |
| 0, | |
| desc=f"Processing {split_name}: {global_idx} examples", | |
| ) | |
| return example | |
| return map_fn | |
| if total: | |
| num_shards = math.ceil(total / shard_size_int) | |
| total_shards += num_shards | |
| meta = { | |
| "dataset_id": dataset_id, | |
| "config": config or "", | |
| "split": split_name, | |
| "audio_column": audio_column, | |
| "total": total, | |
| "shard_size": shard_size_int, | |
| "num_shards": num_shards, | |
| } | |
| meta_file = split_cache_dir / "_meta.json" | |
| meta_file.write_text(json.dumps(meta, indent=2)) | |
| if cache_on_hub: | |
| api.upload_file( | |
| path_or_fileobj=str(meta_file), | |
| path_in_repo=f"{hub_cache_prefix}/_meta.json", | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| ) | |
| shards = [] | |
| for shard_idx in range(num_shards): | |
| start = shard_idx * shard_size_int | |
| end = min(total, start + shard_size_int) | |
| cache_file = split_cache_dir / ( | |
| f"{split_name}-{start:07d}-{end:07d}.arrow" | |
| ) | |
| hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}" | |
| if resume_processing and cache_file.exists(): | |
| processed_shard = Dataset.from_file(str(cache_file)) | |
| progress( | |
| end / total, | |
| desc=f"Processing {split_name}: {end}/{total}", | |
| ) | |
| cached_shards += 1 | |
| elif resume_processing and cache_on_hub and hub_cache_path in repo_files: | |
| cache_path = hf_hub_download( | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| filename=hub_cache_path, | |
| token=token, | |
| ) | |
| processed_shard = Dataset.from_file(cache_path) | |
| progress( | |
| end / total, | |
| desc=f"Processing {split_name}: {end}/{total}", | |
| ) | |
| cached_shards += 1 | |
| else: | |
| if max_shards_int and shards_processed >= max_shards_int: | |
| incomplete = True | |
| break | |
| indices = range(start, end) | |
| try: | |
| shard_ds = split_ds.select(indices) | |
| except Exception: | |
| shard_ds = split_ds.select(list(indices)) | |
| processed_shard = shard_ds.map( | |
| make_map_fn(offset=start), | |
| with_indices=True, | |
| load_from_cache_file=False, | |
| cache_file_name=str(cache_file), | |
| writer_batch_size=50, | |
| num_proc=None, | |
| features=features, | |
| desc=( | |
| f"Chizzling {split_name} " | |
| f"({shard_idx + 1}/{num_shards})" | |
| ), | |
| ) | |
| shards_processed += 1 | |
| if cache_on_hub: | |
| api.upload_file( | |
| path_or_fileobj=str(cache_file), | |
| path_in_repo=hub_cache_path, | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| ) | |
| repo_files.add(hub_cache_path) | |
| shards.append(processed_shard) | |
| if incomplete: | |
| break | |
| processed_split = ( | |
| concatenate_datasets(shards) | |
| if len(shards) > 1 | |
| else shards[0] | |
| ) | |
| else: | |
| if max_shards_int and shards_processed >= max_shards_int: | |
| incomplete = True | |
| break | |
| processed_split = split_ds.map( | |
| make_map_fn(offset=0), | |
| with_indices=True, | |
| load_from_cache_file=False, | |
| writer_batch_size=50, | |
| num_proc=None, | |
| features=features, | |
| desc=f"Chizzling {split_name}", | |
| ) | |
| shards_processed += 1 | |
| processed_splits[split_name] = processed_split | |
| if incomplete: | |
| total_done = cached_shards + shards_processed | |
| progress_note = ( | |
| f" ({total_done}/{total_shards} shards ready)" | |
| if total_shards | |
| else "" | |
| ) | |
| return ( | |
| f"Processed {shards_processed} new shard(s)" | |
| f"{f', cached {cached_shards}' if cached_shards else ''}" | |
| f"{progress_note}." | |
| " Resume with cached shards to continue." | |
| ) | |
| processed_dataset = ( | |
| DatasetDict(processed_splits) | |
| if len(processed_splits) > 1 | |
| else next(iter(processed_splits.values())) | |
| ) | |
| progress(0, desc="Uploading to the Hub...") | |
| processed_dataset.push_to_hub(repo_id, private=private_repo, token=token) | |
| progress(1.0, desc="Upload complete.") | |
| return ( | |
| f"Uploaded cleaned dataset to {repo_id} " | |
| f"(audio column: {audio_column})." | |
| ) | |
def process_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    vad_threshold: float,
    max_silence_gap: float,
    normalize_audio: bool,
    target_dbfs: float,
    max_boost_db: float,
    max_atten_db: float,
    max_examples: Optional[float],
    resume_processing: bool,
    auto_resume: bool,
    shard_size: Optional[float],
    cache_on_hub: bool,
    max_shards_per_run: Optional[float],
    request: gr.Request | None = None,
    progress=gr.Progress(),
) -> str:
    """Drive the GPU processing pass, retrying automatically when allowed.

    Wraps ``_process_dataset_and_push_gpu`` in a loop that, when
    ``auto_resume`` is set, (a) retries after ZeroGPU preemption
    ("GPU task aborted") and (b) keeps launching passes while the worker
    reports that cached shards remain to be processed.  Returns the
    worker's final status string for display in the UI.
    """
    # On a Space, warn when the request lacks the ZeroGPU auth header —
    # without it, GPU quota is not attached to the visitor's account.
    if SPACE_ID and request is not None:
        request_headers = getattr(request, "headers", None)
        ip_token = (
            request_headers.get("x-ip-token")
            if request_headers and hasattr(request_headers, "get")
            else None
        )
        if not ip_token:
            log_progress(
                "ZeroGPU auth header missing. Use the Space on huggingface.co "
                "to attach your login to ZeroGPU quota.",
                2,
            )
    attempts = 0
    while True:
        try:
            outcome = _process_dataset_and_push_gpu(
                dataset_id,
                config,
                split,
                audio_column,
                output_repo,
                private_repo,
                vad_threshold,
                max_silence_gap,
                normalize_audio,
                target_dbfs,
                max_boost_db,
                max_atten_db,
                max_examples,
                resume_processing,
                shard_size,
                cache_on_hub,
                max_shards_per_run,
                request=request,
                progress=progress,
            )
        except Exception as exc:
            reason = str(exc)
            if "ZeroGPU proxy token expired" in reason:
                return (
                    "ZeroGPU login token expired. Click Process/Resume again "
                    "to refresh your session."
                )
            if not (auto_resume and "GPU task aborted" in reason):
                raise
            # Preempted mid-run: pause briefly, then start another pass,
            # which picks up from the cached shards.
            attempts += 1
            log_progress(
                f"ZeroGPU preempted. Retrying (attempt {attempts})...",
                2,
            )
            time.sleep(2)
            continue
        if not auto_resume or "Resume with cached shards" not in outcome:
            return outcome
        # Partial completion: the worker capped shards this run; go again.
        attempts += 1
        log_progress(
            f"Auto-resume: continuing (attempt {attempts})...",
            2,
        )
        time.sleep(2)
def assemble_cached_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    cache_on_hub: bool,
    progress=gr.Progress(),
) -> str:
    """Reassemble previously processed shards into a dataset and push it.

    Reads the cached ``.arrow`` shards written by the processing pass —
    preferring local files, falling back to the Hub cache when
    ``cache_on_hub`` is set — concatenates them per split, and uploads the
    result to ``output_repo`` (or a default repo derived from
    ``dataset_id``).

    Returns a human-readable status string for the UI; the common
    "nothing cached yet" cases produce a message instead of raising.
    """
    token = get_hf_token()
    if not token:
        return "Missing HF token. Set HF_TOKEN as a secret or env var."
    dataset_id = normalize_dataset_id(dataset_id)
    if not dataset_id:
        return "Provide a dataset ID or URL."
    config = config.strip() or None
    split = split.strip()
    audio_column = audio_column.strip()
    output_repo = normalize_dataset_id(output_repo) if output_repo else ""
    cache_on_hub = bool(cache_on_hub)
    api = HfApi(token=token)
    username = api.whoami()["name"]
    repo_id = output_repo or default_output_repo(dataset_id, username)
    cache_root = get_dataset_cache_dir(dataset_id, config)
    cache_slug = get_cache_slug(dataset_id, config)
    # Resolve which splits to assemble: an explicit split name, or every
    # split found in the cache when the field is empty / "all".
    if split and split.lower() != "all":
        split_names = [split]
    else:
        if cache_on_hub:
            # Fix: the output repo may not exist yet (no shards uploaded);
            # fall through to the "no cached shards" message instead of
            # letting list_repo_files raise.
            try:
                hub_paths = api.list_repo_files(repo_id, repo_type="dataset")
            except Exception:
                hub_paths = []
            prefix = f"chizzler_cache/{cache_slug}/"
            # Cache layout is chizzler_cache/<slug>/<split>/<file>, so the
            # third path segment is the split name.
            split_names = sorted(
                {
                    path.split("/")[2]
                    for path in hub_paths
                    if path.startswith(prefix) and len(path.split("/")) >= 3
                }
            )
        else:
            # Fix: the local cache dir may not exist if processing never
            # ran; iterdir() on a missing path raises FileNotFoundError.
            if cache_root.exists():
                split_names = sorted(
                    [
                        path.name
                        for path in cache_root.iterdir()
                        if path.is_dir()
                    ]
                )
            else:
                split_names = []
    if not split_names:
        return "No cached shards found. Run processing first."
    repo_files = set()
    if cache_on_hub:
        try:
            repo_files = set(
                api.list_repo_files(repo_id, repo_type="dataset")
            )
        except Exception:
            repo_files = set()
    processed_splits = {}
    for split_name in split_names:
        split_cache_dir = cache_root / split_name
        hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name)
        meta = load_split_meta(
            split_cache_dir, hub_cache_prefix, cache_on_hub, repo_id, token
        )
        if not meta:
            return (
                f"Missing cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        total = int(meta.get("total", 0))
        shard_size = int(meta.get("shard_size", 0))
        num_shards = int(meta.get("num_shards", 0))
        if not total or not shard_size or not num_shards:
            return (
                f"Incomplete cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        shards = []
        missing = []
        for shard_idx in range(num_shards):
            start = shard_idx * shard_size
            end = min(total, start + shard_size)
            cache_file = split_cache_dir / (
                f"{split_name}-{start:07d}-{end:07d}.arrow"
            )
            hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}"
            # Prefer the local shard; fall back to downloading from the Hub.
            if cache_file.exists():
                cache_path = str(cache_file)
            elif cache_on_hub and hub_cache_path in repo_files:
                cache_path = hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=hub_cache_path,
                    token=token,
                )
            else:
                missing.append(cache_file.name)
                continue
            shards.append(Dataset.from_file(cache_path))
        if missing:
            return (
                f"Missing {len(missing)} shard(s) for split '{split_name}'. "
                "Run processing with resume enabled."
            )
        processed_splits[split_name] = (
            concatenate_datasets(shards)
            if len(shards) > 1
            else shards[0]
        )
    # A single split is pushed as a bare Dataset; multiple splits as a
    # DatasetDict so the Hub preserves the split structure.
    processed_dataset = (
        DatasetDict(processed_splits)
        if len(processed_splits) > 1
        else next(iter(processed_splits.values()))
    )
    progress(0, desc="Uploading to the Hub...")
    processed_dataset.push_to_hub(repo_id, private=private_repo, token=token)
    progress(1.0, desc="Upload complete.")
    inferred_audio_column = (
        audio_column or infer_audio_column(processed_dataset) or "audio"
    )
    return (
        f"Uploaded cleaned dataset to {repo_id} "
        f"(audio column: {inferred_audio_column})."
    )
def _gradio_single_file_gpu(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """Run the single-file cleaning pipeline on one uploaded audio file.

    Returns the 4-tuple expected by the UI outputs: original audio,
    VAD-only audio, denoised audio, and a details string.  When nothing
    was uploaded, the three audio slots are None and the last element
    carries the error message.
    """
    # Guard clause: nothing uploaded, nothing to process.
    if audio_file is None:
        return None, None, None, "Please upload an audio file."
    pipeline_options = {
        "threshold": vad_threshold,
        "max_gap": max_silence_gap,
        "normalize_audio": normalize_audio,
        "target_dbfs": target_dbfs,
        "max_boost_db": max_boost_db,
        "max_atten_db": max_atten_db,
    }
    return process_audio_file(audio_file, **pipeline_options)
def gradio_single_file(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """UI entry point: forward every control value to the worker function.

    Kept as a thin wrapper so the Gradio event wiring stays separate from
    the implementation in ``_gradio_single_file_gpu``.
    """
    forwarded = (
        audio_file,
        vad_threshold,
        max_silence_gap,
        normalize_audio,
        target_dbfs,
        max_boost_db,
        max_atten_db,
    )
    return _gradio_single_file_gpu(*forwarded, request=request)
# ---------------------------------------------------------------------------
# Gradio UI: one tab for ad-hoc single-file cleanup, one for full-dataset
# processing that publishes the cleaned result to the Hugging Face Hub.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Representation Chizzler") as demo:
    gr.Markdown(
        "# Representation Chizzler\n"
        "Two-stage audio processing: VAD-based speech extraction followed by MP-SENet "
        "denoising. Use the Single File tab for ad-hoc processing or the Dataset tab "
        "to clean and publish a dataset to the Hugging Face Hub."
    )
    with gr.Column():
        with gr.Tabs():
            # --- Tab 1: process one uploaded file ------------------------
            with gr.Tab("Single File"):
                audio_input = gr.Audio(label="Upload Audio File", type="filepath")
                vad_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="VAD Threshold (higher = stricter voice detection)",
                )
                gap_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.5,
                    label="Max Silence Gap (seconds)",
                )
                # Loudness-normalization controls; defaults come from the
                # module-level constants so UI and env config stay in sync.
                normalize_checkbox = gr.Checkbox(
                    label="Normalize volume", value=True
                )
                target_db_slider = gr.Slider(
                    minimum=-35.0,
                    maximum=-10.0,
                    value=DEFAULT_TARGET_DBFS,
                    step=1.0,
                    label="Target loudness (dBFS)",
                )
                max_boost_slider = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    value=DEFAULT_MAX_BOOST_DB,
                    step=1.0,
                    label="Max boost (dB)",
                )
                max_atten_slider = gr.Slider(
                    minimum=0.0,
                    maximum=20.0,
                    value=DEFAULT_MAX_ATTEN_DB,
                    step=1.0,
                    label="Max attenuation (dB)",
                )
                run_button = gr.Button("Process Audio")
                # Three players let the user A/B each pipeline stage.
                original_audio = gr.Audio(label="Original Audio")
                vad_audio = gr.Audio(label="VAD Processed (Speech Only)")
                denoised_audio = gr.Audio(label="Final Denoised")
                details_box = gr.Textbox(label="Processing Details", lines=10)
                run_button.click(
                    fn=gradio_single_file,
                    inputs=[
                        audio_input,
                        vad_slider,
                        gap_slider,
                        normalize_checkbox,
                        target_db_slider,
                        max_boost_slider,
                        max_atten_slider,
                    ],
                    outputs=[original_audio, vad_audio, denoised_audio, details_box],
                    # Only one processing job runs at a time.
                    concurrency_limit=1,
                )
            # --- Tab 2: clean a whole dataset and push it to the Hub -----
            with gr.Tab("Dataset to Hub"):
                with gr.Row():
                    gr.LoginButton()
                dataset_id_input = gr.Textbox(
                    label="Dataset ID or URL",
                    value="https://huggingface.co/datasets/MohammadGholizadeh/fleurs-farsi",
                )
                config_input = gr.Textbox(label="Config (optional)", value="")
                split_input = gr.Textbox(label="Split (optional, or 'all')", value="dev")
                audio_column_input = gr.Textbox(
                    label="Audio column (optional, auto-detect if empty)", value=""
                )
                output_repo_input = gr.Textbox(
                    label="Output dataset repo (optional)", value=""
                )
                private_checkbox = gr.Checkbox(label="Create private repo", value=False)
                max_examples_input = gr.Number(
                    label="Max examples per split (optional)", value=None
                )
                # Resume/caching controls for shard-based processing on
                # ZeroGPU, where runs can be preempted mid-dataset.
                resume_checkbox = gr.Checkbox(
                    label="Resume from cached shards", value=True
                )
                auto_resume_checkbox = gr.Checkbox(
                    label="Auto-resume on ZeroGPU preemption",
                    value=DEFAULT_AUTO_RESUME,
                )
                cache_to_hub_checkbox = gr.Checkbox(
                    label="Cache shards on Hub (recommended for ZeroGPU)",
                    value=DEFAULT_CACHE_TO_HUB,
                )
                shard_size_input = gr.Number(
                    label="Shard size (examples)", value=25
                )
                max_shards_input = gr.Number(
                    label="Max shards per run (ZeroGPU: 1-5, 0 = no limit)",
                    value=DEFAULT_MAX_SHARDS_PER_RUN,
                )
                # Same processing knobs as the Single File tab, duplicated
                # because Gradio components cannot be shared across tabs.
                vad_slider_ds = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="VAD Threshold",
                )
                gap_slider_ds = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.5,
                    label="Max Silence Gap (seconds)",
                )
                normalize_checkbox_ds = gr.Checkbox(
                    label="Normalize volume", value=True
                )
                target_db_slider_ds = gr.Slider(
                    minimum=-35.0,
                    maximum=-10.0,
                    value=DEFAULT_TARGET_DBFS,
                    step=1.0,
                    label="Target loudness (dBFS)",
                )
                max_boost_slider_ds = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    value=DEFAULT_MAX_BOOST_DB,
                    step=1.0,
                    label="Max boost (dB)",
                )
                max_atten_slider_ds = gr.Slider(
                    minimum=0.0,
                    maximum=20.0,
                    value=DEFAULT_MAX_ATTEN_DB,
                    step=1.0,
                    label="Max attenuation (dB)",
                )
                process_button = gr.Button(
                    "Process/Resume Dataset (cache & push when complete)"
                )
                assemble_button = gr.Button(
                    "Assemble & Push Cached Dataset"
                )
                status_box = gr.Textbox(label="Status", lines=6)
                process_button.click(
                    fn=process_dataset_and_push,
                    inputs=[
                        dataset_id_input,
                        config_input,
                        split_input,
                        audio_column_input,
                        output_repo_input,
                        private_checkbox,
                        vad_slider_ds,
                        gap_slider_ds,
                        normalize_checkbox_ds,
                        target_db_slider_ds,
                        max_boost_slider_ds,
                        max_atten_slider_ds,
                        max_examples_input,
                        resume_checkbox,
                        auto_resume_checkbox,
                        shard_size_input,
                        cache_to_hub_checkbox,
                        max_shards_input,
                    ],
                    outputs=[status_box],
                    concurrency_limit=1,
                )
                # Assembly reuses the same identity fields but no processing
                # knobs: it only concatenates already-cached shards.
                assemble_button.click(
                    fn=assemble_cached_dataset_and_push,
                    inputs=[
                        dataset_id_input,
                        config_input,
                        split_input,
                        audio_column_input,
                        output_repo_input,
                        private_checkbox,
                        cache_to_hub_checkbox,
                    ],
                    outputs=[status_box],
                    concurrency_limit=1,
                )
# Enable Gradio's request queue (needed for progress tracking and the
# concurrency limits above).
demo.queue()

if __name__ == "__main__":
    demo.launch()