# Fix zero-length audio by decoding inputs + fallback
# (Hugging Face commit 7772bb7 by Reza2kn, verified)
import io
import json
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
import urllib.request
import warnings
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
from datasets import (
Audio,
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
Value,
concatenate_datasets,
load_dataset,
)
from dotenv import load_dotenv
from huggingface_hub import HfApi, hf_hub_download
from rich.console import Console
try:
    import spaces
except Exception:  # pragma: no cover - spaces is only available on HF Spaces
    class _SpacesFallback:
        """Stand-in for the `spaces` module when running outside HF Spaces.

        Mirrors both supported usages of the real decorator:
        bare ``@spaces.GPU`` and configured ``@spaces.GPU(duration=...)``.
        The previous fallback only handled the configured form, so the bare
        form (used by gpu_decorator when duration is 0) returned the inner
        decorator instead of the wrapped function.
        """

        def GPU(self, *args, **kwargs):
            # Bare usage: @spaces.GPU hands us the function directly.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            # Configured usage: @spaces.GPU(duration=...) expects a decorator.
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
# Initialize console for pretty printing
console = Console()
# Sample rate every waveform is converted to before VAD / denoising.
DEFAULT_SAMPLE_RATE = 16000
# File suffixes treated as audio when inferring a dataset's audio column.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")
# Loudness-normalization target (RMS dBFS) and gain limits in dB.
DEFAULT_TARGET_DBFS = -20.0
DEFAULT_MAX_BOOST_DB = 20.0
DEFAULT_MAX_ATTEN_DB = 10.0
# Auto-resume defaults on when running inside a HF Space (SPACE_ID is set).
DEFAULT_AUTO_RESUME = bool(os.getenv("SPACE_ID"))
# Conservative per-run shard limits used while on ZeroGPU hardware.
DEFAULT_ZERO_GPU_SHARD_SIZE = int(
    os.getenv("CHIZZLER_ZERO_GPU_SHARD_SIZE", "25")
)
DEFAULT_ZERO_GPU_MAX_SHARDS = int(
    os.getenv("CHIZZLER_ZERO_GPU_MAX_SHARDS", "1")
)
SPACE_ID = os.getenv("SPACE_ID")
# NOTE(review): this warning text suggests a gr.LoginButton is created
# outside a Blocks context elsewhere in this app; the warning is expected,
# so silence it.
warnings.filterwarnings(
    "ignore",
    message=(
        "LoginButton created outside of a Blocks context\\. "
        "May not work unless you call its `activate\\(\\)` method manually\\."
    ),
)
def log_progress(message: str, level: int = 1, enabled: bool = True) -> None:
    """Print a timestamped, indentation-leveled progress line via rich."""
    if not enabled:
        return
    stamp = datetime.now().strftime("%H:%M:%S")
    pad = " " * (level - 1)
    console.print(f"[dim]{stamp}[/dim] {pad}[bold blue]>[/bold blue] {message}")
    sys.stdout.flush()
# Load environment variables
load_dotenv()
# Optional GPU-duration override (seconds) for the spaces.GPU decorator,
# plus an optional hard cap on that value.
_GPU_DURATION = os.getenv("CHIZZLER_GPU_DURATION")
_GPU_DURATION_MAX = os.getenv("CHIZZLER_GPU_DURATION_MAX")
if _GPU_DURATION is not None:
    try:
        DEFAULT_GPU_DURATION = int(_GPU_DURATION)
    except ValueError:
        # Malformed value: fall back to 0 (gpu_decorator then uses bare GPU).
        DEFAULT_GPU_DURATION = 0
else:
    DEFAULT_GPU_DURATION = 0
if _GPU_DURATION_MAX is not None:
    try:
        max_duration = int(_GPU_DURATION_MAX)
        if max_duration > 0:
            # Clamp the requested duration to the configured maximum.
            DEFAULT_GPU_DURATION = min(DEFAULT_GPU_DURATION, max_duration)
    except ValueError:
        # Malformed max: ignore it and keep the current value.
        pass
def gpu_decorator(duration: int):
    """Return a spaces.GPU decorator, passing duration when supported."""
    if not duration or duration <= 0:
        return spaces.GPU
    try:
        return spaces.GPU(duration=duration)
    except TypeError:
        # Older spaces versions do not accept a duration kwarg.
        return spaces.GPU
def get_hf_token() -> Optional[str]:
    """Return the first Hugging Face token found in the environment, if any."""
    token = None
    for var in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN", "HF_API_TOKEN"):
        token = os.getenv(var)
        if token:
            return token
    return token
def normalize_dataset_id(value: str) -> str:
    """Reduce a dataset URL or ID to the canonical 'owner/name' form."""
    if not value:
        return ""
    value = value.strip()
    if value.startswith("http"):
        # Strip the URL prefix; prefer the /datasets/ marker when present.
        for marker in ("datasets/", "huggingface.co/"):
            if marker in value:
                value = value.split(marker, 1)[1]
                break
    # Drop query string, fragment, and surrounding slashes.
    value = value.split("?")[0].split("#")[0].strip("/")
    parts = [piece for piece in value.split("/") if piece]
    return "/".join(parts[:2]) if len(parts) >= 2 else value
# Directory containing this script; relative assets live next to it.
CURRENT_DIR = Path(__file__).parent.resolve()
# Local MP-SENet checkout (cloned on demand by ensure_mpsenet_repo).
DEFAULT_MP_SENET_DIR = Path(os.getenv("MPSENET_DIR", CURRENT_DIR / "MP-SENet"))
MPSENET_GIT_REPO = os.getenv(
    "MPSENET_GIT_REPO", "https://github.com/yxlu-0102/MP-SENet.git"
)
# Root directory for locally cached processed shards.
CACHE_DIR = Path(os.getenv("CHIZZLER_CACHE_DIR", CURRENT_DIR / "chizzler_cache"))
_ENV_MAX_SHARDS = os.getenv("CHIZZLER_MAX_SHARDS_PER_RUN")
if _ENV_MAX_SHARDS is not None:
    # NOTE(review): a non-numeric value raises ValueError at import time.
    DEFAULT_MAX_SHARDS_PER_RUN = int(_ENV_MAX_SHARDS)
else:
    # On a Space default to one shard per run (ZeroGPU limits); 0 = unlimited.
    DEFAULT_MAX_SHARDS_PER_RUN = 1 if os.getenv("SPACE_ID") else 0
_ENV_CACHE_TO_HUB = os.getenv("CHIZZLER_CACHE_TO_HUB")
if _ENV_CACHE_TO_HUB is None:
    # Cache shards to the Hub by default only when running on a Space.
    DEFAULT_CACHE_TO_HUB = bool(os.getenv("SPACE_ID"))
else:
    DEFAULT_CACHE_TO_HUB = _ENV_CACHE_TO_HUB.strip().lower() in ("1", "true", "yes")
def ensure_mpsenet_repo() -> Path:
    """Return the MP-SENet checkout path, cloning it on demand when allowed.

    Cloning happens only when MPSENET_AUTO_DOWNLOAD=1 or on a Space;
    otherwise a missing checkout raises FileNotFoundError.
    """
    if DEFAULT_MP_SENET_DIR.exists():
        return DEFAULT_MP_SENET_DIR
    allowed = os.getenv("MPSENET_AUTO_DOWNLOAD") == "1" or os.getenv("SPACE_ID")
    if not allowed:
        raise FileNotFoundError(
            "MP-SENet repo not found. Clone it into MP-SENet/ or set MPSENET_DIR."
        )
    log_progress("MP-SENet repo not found. Cloning...", 2)
    clone_cmd = [
        "git",
        "clone",
        "--depth=1",
        MPSENET_GIT_REPO,
        str(DEFAULT_MP_SENET_DIR),
    ]
    try:
        subprocess.run(clone_cmd, check=True)
    except Exception as exc:
        raise RuntimeError(
            "Failed to clone MP-SENet. Clone it manually or set MPSENET_DIR."
        ) from exc
    return DEFAULT_MP_SENET_DIR
def resolve_mpsenet_files(mp_senet_dir: Path) -> Tuple[Path, Path]:
    """Locate the MP-SENet config and checkpoint, falling back to the Hub.

    Prefers best_ckpt/ inside the local checkout; otherwise downloads the
    files from the repo named by MPSENET_REPO. Raises FileNotFoundError
    when neither source is available.
    """
    local_config = mp_senet_dir / "best_ckpt" / "config.json"
    local_ckpt = mp_senet_dir / "best_ckpt" / "g_best_dns"
    if local_config.exists() and local_ckpt.exists():
        return local_config, local_ckpt
    repo_id = os.getenv("MPSENET_REPO")
    if repo_id:
        cfg_name = os.getenv("MPSENET_CONFIG_FILENAME", "config.json")
        ckpt_name = os.getenv("MPSENET_CKPT_FILENAME", "g_best_dns")
        cfg = Path(hf_hub_download(repo_id=repo_id, filename=cfg_name))
        ckpt = Path(hf_hub_download(repo_id=repo_id, filename=ckpt_name))
        return cfg, ckpt
    raise FileNotFoundError(
        "MP-SENet checkpoint files not found. Place best_ckpt/config.json and "
        "best_ckpt/g_best_dns under MP-SENet/ or set MPSENET_REPO."
    )
# Make the MP-SENet checkout importable, then pull in its modules
# (dataset/env/models live at the repo root, hence the sys.path append).
mp_senet_dir = ensure_mpsenet_repo()
sys.path.append(str(mp_senet_dir))
from dataset import mag_pha_istft, mag_pha_stft  # noqa: E402
from env import AttrDict  # noqa: E402
from models.model import MPNet  # noqa: E402
def select_device() -> torch.device:
    """Pick the compute device: env override, then CUDA, then MPS, then CPU."""
    forced = os.getenv("CHIZZLER_DEVICE", "").strip().lower()
    if forced:
        return torch.device(forced)
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def initialize_models(device_override: Optional[torch.device] = None):
    """Load the Silero VAD and MP-SENet models onto a single device.

    Returns (vad_model, vad_utils, mpnet_model, mpnet_config, device).
    vad_utils is the helper tuple returned by the silero-vad hub entry.
    """
    log_progress("Initializing models...")
    device = device_override or select_device()
    log_progress(f"Using {device.type.upper()} for all operations", 2)
    log_progress("Loading Silero VAD model...", 2)
    model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        trust_repo=True,
    )
    vad_model = model.to(device)
    log_progress("Loading MP-SENet model...", 2)
    config_path, ckpt_path = resolve_mpsenet_files(mp_senet_dir)
    with open(config_path, "r") as f:
        config = AttrDict(json.loads(f.read()))
    mpnet_model = MPNet(config).to(device)
    state = torch.load(ckpt_path, map_location=device)
    # Checkpoints may wrap the weights under "generator" or "state_dict";
    # unwrap before loading.
    if isinstance(state, dict):
        if "generator" in state:
            state = state["generator"]
        elif "state_dict" in state:
            state = state["state_dict"]
    mpnet_model.load_state_dict(state)
    mpnet_model.eval()
    return vad_model, utils, mpnet_model, config, device
# Lazily-initialized shared model state; populated by get_models().
vad_model = None
vad_utils = None
mpnet_model = None
mpnet_config = None
device = None
def get_models():
    """Return the shared models, loading them (or switching device) on demand.

    Returns the same 5-tuple as initialize_models. The device preference is
    re-evaluated on every call so models follow e.g. ZeroGPU attach/detach.
    """
    global vad_model, vad_utils, mpnet_model, mpnet_config, device
    desired_device = select_device()
    if vad_model is None or mpnet_model is None or mpnet_config is None:
        vad_model, vad_utils, mpnet_model, mpnet_config, device = (
            initialize_models(desired_device)
        )
        return vad_model, vad_utils, mpnet_model, mpnet_config, device
    # Already loaded: move to the currently desired device if it changed.
    if device is None or str(device) != str(desired_device):
        log_progress(f"Moving models to {desired_device}...", 2)
        vad_model = vad_model.to(desired_device)
        mpnet_model = mpnet_model.to(desired_device)
        device = desired_device
    return vad_model, vad_utils, mpnet_model, mpnet_config, device
def ensure_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Collapse a waveform to shape (1, samples), averaging channels.

    A 1-D input gains a channel axis. For 2-D input the longer axis is
    assumed to be time (more samples than channels), so a (samples,
    channels) layout is transposed before channel averaging.
    """
    if waveform.dim() == 1:
        return waveform.unsqueeze(0)
    if waveform.dim() == 2 and waveform.size(0) > waveform.size(1):
        # Heuristic: rows outnumber columns => rows are samples, not channels.
        waveform = waveform.transpose(0, 1)
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    return waveform
def resample_waveform(
    waveform: torch.Tensor, sample_rate: int, target_rate: int = DEFAULT_SAMPLE_RATE
) -> Tuple[torch.Tensor, int]:
    """Resample to target_rate via torchaudio; no-op when already there."""
    if sample_rate == target_rate:
        return waveform, sample_rate
    resampled = torchaudio.transforms.Resample(sample_rate, target_rate)(waveform)
    return resampled, target_rate
def load_audio_file(file_path: str, log: bool = True) -> Tuple[torch.Tensor, int]:
    """Load an audio file as a mono waveform at DEFAULT_SAMPLE_RATE.

    Decoder fallback chain: torchaudio, then soundfile, then librosa, so
    unusual codecs still decode. Returns (waveform (1, samples), sample_rate).
    """
    log_progress(f"Loading audio: {Path(file_path).name}", enabled=log)
    waveform = None
    sample_rate = None
    try:
        waveform, sample_rate = torchaudio.load(file_path)
        waveform = ensure_mono(waveform)
    except Exception as exc:
        log_progress(f"torchaudio load failed: {exc}", 2, enabled=log)
    if waveform is None or sample_rate is None:
        try:
            data, sample_rate = sf.read(
                file_path, always_2d=True, dtype="float32"
            )
            # soundfile returns (frames, channels); transpose to (channels, frames).
            waveform = torch.from_numpy(data.T)
        except Exception as exc:
            log_progress(f"soundfile load failed: {exc}", 2, enabled=log)
            # Last resort: librosa, which handles more exotic formats.
            data, sample_rate = librosa.load(
                file_path, sr=None, mono=False, dtype=np.float32
            )
            if data.ndim == 1:
                data = data[None, :]
            waveform = torch.from_numpy(data)
    waveform = ensure_mono(waveform)
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate
def get_speech_timestamps(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    log: bool = True,
) -> List[dict]:
    """Run Silero VAD and return speech segments with second-based bounds."""
    log_progress("Detecting speech segments...", enabled=log)
    model, utils, _, _, _ = get_models()
    # The first element of the silero utils tuple is the timestamp function.
    (timestamps_fn, _, _, _, _) = utils
    segments = timestamps_fn(
        waveform,
        model,
        threshold=threshold,
        return_seconds=True,
    )
    log_progress(f"Found {len(segments)} speech segments", 2, enabled=log)
    return segments
def merge_close_segments(segments: List[dict], max_gap: float = 4.0) -> List[dict]:
    """Merge adjacent speech segments separated by at most max_gap seconds.

    Segments are assumed sorted by "start". Input dicts are never mutated;
    merged spans are copies. Fixes the previous behavior where a segment
    fully contained in the current merged span could *shrink* its end.
    """
    if not segments:
        return segments
    merged: List[dict] = []
    current = segments[0].copy()
    for segment in segments[1:]:
        gap = segment["start"] - current["end"]
        if gap <= max_gap:
            # max() keeps the span monotone even for contained/overlapping
            # segments; plain assignment could shorten it.
            current["end"] = max(current["end"], segment["end"])
        else:
            merged.append(current)
            current = segment.copy()
    merged.append(current)
    return merged
def extract_speech_waveform(
    waveform: torch.Tensor, sample_rate: int, segments: List[dict]
) -> Optional[torch.Tensor]:
    """Concatenate the waveform slices covered by the given speech segments.

    Segment bounds are in seconds and clamped to the waveform; returns None
    when no segment yields a non-empty slice.
    """
    if not segments:
        return None
    n_samples = waveform.size(1)
    slices = []
    for seg in segments:
        lo = max(0, int(seg["start"] * sample_rate))
        hi = min(n_samples, int(seg["end"] * sample_rate))
        if hi > lo:
            slices.append(waveform[:, lo:hi])
    return torch.cat(slices, dim=1) if slices else None
def denoise_audio_chunk(
    audio_tensor: torch.Tensor,
    mpnet_model: torch.nn.Module,
    mpnet_config: AttrDict,
    chunk_size: int = 5 * DEFAULT_SAMPLE_RATE,
) -> torch.Tensor:
    """Denoise a waveform with MP-SENet in fixed-size chunks.

    Processing in ~5 s chunks bounds peak memory. Each chunk is RMS-scaled
    before the model and the inverse gain is applied to the output, so the
    original level is preserved.
    """
    chunks = []
    for i in range(0, audio_tensor.size(1), chunk_size):
        chunk = audio_tensor[:, i : min(i + chunk_size, audio_tensor.size(1))]
        # Per-chunk RMS normalization; epsilon avoids division by zero on silence.
        energy = torch.sum(chunk**2.0, dim=1)
        norm_factor = torch.sqrt(chunk.size(1) / (energy + 1e-8))
        chunk = chunk * norm_factor.unsqueeze(1)
        with torch.no_grad():
            noisy_amp, noisy_pha, _ = mag_pha_stft(
                chunk,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
            amp_g, pha_g, _ = mpnet_model(noisy_amp, noisy_pha)
            audio_g = mag_pha_istft(
                amp_g,
                pha_g,
                mpnet_config.n_fft,
                mpnet_config.hop_size,
                mpnet_config.win_size,
                mpnet_config.compress_factor,
            )
        # Undo the input scaling so the chunk keeps its original level.
        audio_g = audio_g / norm_factor.unsqueeze(1)
        chunks.append(audio_g)
        # Release intermediates promptly to keep GPU memory flat.
        del chunk, noisy_amp, noisy_pha, amp_g, pha_g
    return torch.cat(chunks, dim=1)
def process_waveform(
    waveform: torch.Tensor,
    sample_rate: int,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    log: bool = True,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str, bool]:
    """Run the VAD -> merge -> denoise pipeline on one waveform.

    Returns (vad_waveform, denoised_waveform, details_text, has_speech);
    the tensors are None when no speech is detected.
    """
    vad_model, vad_utils, mpnet_model, mpnet_config, device = get_models()
    if waveform.device != device:
        waveform = waveform.to(device)
    log_progress("Stage 1: Voice Activity Detection", 2, enabled=log)
    speech_timestamps = get_speech_timestamps(
        waveform, sample_rate, threshold=threshold, log=log
    )
    merged_timestamps = merge_close_segments(speech_timestamps, max_gap)
    details = ["Processing details:"]
    if not merged_timestamps:
        details.append("No speech detected in the audio.")
        return None, None, "\n".join(details), False
    for i, segment in enumerate(merged_timestamps, 1):
        duration = segment["end"] - segment["start"]
        details.append(
            f"Segment {i}/{len(merged_timestamps)}: "
            f"{segment['start']:.1f}s to {segment['end']:.1f}s "
            f"(duration: {duration:.1f}s)"
        )
    vad_waveform = extract_speech_waveform(
        waveform, sample_rate, merged_timestamps
    )
    if vad_waveform is None or vad_waveform.numel() == 0:
        details.append("No speech detected after merging segments.")
        return None, None, "\n".join(details), False
    vad_duration = vad_waveform.size(1) / sample_rate
    original_duration = waveform.size(1) / sample_rate
    # Guard against zero-length input when computing the reduction ratio.
    if original_duration > 0:
        reduction = (1 - vad_duration / original_duration) * 100
    else:
        reduction = 0.0
    details.append(f"VAD output duration: {vad_duration:.1f}s")
    details.append(f"Reduced by: {reduction:.1f}%")
    log_progress("Stage 2: MP-SENet denoising", 2, enabled=log)
    with torch.no_grad():
        denoised_waveform = denoise_audio_chunk(
            vad_waveform, mpnet_model, mpnet_config
        )
    return vad_waveform, denoised_waveform, "\n".join(details), True
def process_audio_file(
    audio_path: str,
    threshold: float = 0.5,
    max_gap: float = 4.0,
    normalize_audio: bool = True,
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Tuple[str, str, str, str]:
    """Clean a single audio file; returns (original, vad, denoised, details).

    The first three values are file paths for the UI players. When no speech
    is detected, the original path is returned for all three slots.
    """
    log_progress(f"Processing: {Path(audio_path).name}")
    waveform, sample_rate = load_audio_file(audio_path)
    _, _, _, mpnet_config, _ = get_models()
    vad_waveform, denoised_waveform, details, has_speech = process_waveform(
        waveform, sample_rate, threshold=threshold, max_gap=max_gap, log=True
    )
    if not has_speech or vad_waveform is None or denoised_waveform is None:
        return audio_path, audio_path, audio_path, details
    if normalize_audio:
        vad_waveform = normalize_waveform(
            vad_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
        denoised_waveform = normalize_waveform(
            denoised_waveform,
            target_dbfs=target_dbfs,
            max_boost_db=max_boost_db,
            max_atten_db=max_atten_db,
        )
    # Persist results as temp WAVs (delete=False) for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as vad_file, \
            tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as denoised_file:
        vad_path = vad_file.name
        denoised_path = denoised_file.name
    # NOTE(review): files are written at mpnet_config.sampling_rate while the
    # tensors are at DEFAULT_SAMPLE_RATE — assumes the MP-SENet config is
    # also 16 kHz; confirm.
    sf.write(
        vad_path,
        vad_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    sf.write(
        denoised_path,
        denoised_waveform.squeeze().cpu().numpy(),
        mpnet_config.sampling_rate,
    )
    return audio_path, vad_path, denoised_path, details
def load_audio_bytes(audio_bytes: bytes, log: bool = False) -> Tuple[torch.Tensor, int]:
    """Decode in-memory audio bytes to a mono waveform at DEFAULT_SAMPLE_RATE."""
    samples, sample_rate = sf.read(
        io.BytesIO(audio_bytes), always_2d=True, dtype="float32"
    )
    # soundfile returns (frames, channels); transpose to (channels, frames).
    waveform = ensure_mono(torch.from_numpy(samples.T))
    if sample_rate != DEFAULT_SAMPLE_RATE:
        log_progress(
            f"Resampling from {sample_rate}Hz to {DEFAULT_SAMPLE_RATE}Hz...",
            2,
            enabled=log,
        )
        waveform, sample_rate = resample_waveform(
            waveform, sample_rate, DEFAULT_SAMPLE_RATE
        )
    return waveform, sample_rate
def normalize_waveform(
    waveform: Optional[torch.Tensor],
    target_dbfs: float = DEFAULT_TARGET_DBFS,
    max_boost_db: float = DEFAULT_MAX_BOOST_DB,
    max_atten_db: float = DEFAULT_MAX_ATTEN_DB,
) -> Optional[torch.Tensor]:
    """RMS-normalize toward target_dbfs with bounded gain; clip to [-1, 1].

    None, empty, silent, or non-finite-RMS inputs are returned unchanged.
    """
    if waveform is None or waveform.numel() == 0:
        return waveform
    rms = torch.sqrt(torch.mean(torch.square(waveform)))
    if not torch.isfinite(rms) or rms <= 1e-8:
        return waveform
    # Gain (dB) needed to reach the target, limited to the allowed range.
    gain_db = torch.clamp(
        target_dbfs - 20.0 * torch.log10(rms), -max_atten_db, max_boost_db
    )
    scale = torch.pow(torch.tensor(10.0, device=waveform.device), gain_db / 20.0)
    return torch.clamp(waveform * scale, -1.0, 1.0)
def _is_http_url(value: str) -> bool:
return value.startswith("http://") or value.startswith("https://")
def _parse_hf_dataset_uri(uri: str) -> Optional[Tuple[str, str, Optional[str]]]:
prefix = "hf://datasets/"
if not uri.startswith(prefix):
return None
rest = uri[len(prefix) :]
if "/" not in rest:
return None
repo_id, file_path = rest.split("/", 1)
revision = None
if "@" in repo_id:
repo_id, revision = repo_id.split("@", 1)
return repo_id, file_path, revision
def load_audio_url(url: str, token: Optional[str], log: bool = False) -> Tuple[torch.Tensor, int]:
    """Fetch a remote audio file and decode it (auth header for HF URLs only)."""
    needs_auth = bool(token) and "huggingface.co" in url
    headers = {"Authorization": f"Bearer {token}"} if needs_auth else {}
    request = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(request) as response:
        payload = response.read()
    return load_audio_bytes(payload, log=log)
def resolve_audio_path(
    path: str, dataset_id: Optional[str], token: Optional[str]
) -> str:
    """Map a dataset audio path to a local file, downloading from the Hub when needed.

    Resolution order: existing local path, hf:// dataset URI, then a
    dataset-relative filename. Falls back to returning `path` unchanged so
    downstream loaders surface the real error.
    """
    if os.path.exists(path):
        return path
    hf_ref = _parse_hf_dataset_uri(path)
    if hf_ref is not None:
        repo_id, filename, revision = hf_ref
        try:
            return hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=filename,
                revision=revision,
                token=token,
            )
        except Exception:
            return path
    if dataset_id and not os.path.isabs(path):
        try:
            return hf_hub_download(
                repo_id=dataset_id,
                repo_type="dataset",
                filename=path,
                token=token,
            )
        except Exception:
            return path
    return path
def prepare_waveform_from_entry(
    entry,
    log: bool = False,
    dataset_id: Optional[str] = None,
    token: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Decode one datasets audio entry into (mono waveform, DEFAULT_SAMPLE_RATE).

    Accepts, in priority order: objects exposing get_all_samples(), dicts
    carrying a decoded "array", raw "bytes", or a "path", and bare path/URL
    strings. Raises ValueError for None or unsupported entry shapes.
    """
    if entry is None:
        raise ValueError("Empty audio entry.")
    # Decoder-style object exposing get_all_samples() — presumably the newer
    # `datasets` audio decoding interface; confirm against the datasets version.
    if hasattr(entry, "get_all_samples"):
        samples = entry.get_all_samples()
        waveform = samples.data
        sample_rate = samples.sample_rate
        waveform = ensure_mono(waveform)
        if sample_rate != DEFAULT_SAMPLE_RATE:
            waveform, sample_rate = resample_waveform(
                waveform, sample_rate, DEFAULT_SAMPLE_RATE
            )
        return waveform, sample_rate
    if isinstance(entry, dict):
        # Already-decoded samples.
        if entry.get("array") is not None:
            sample_rate = entry.get("sampling_rate", DEFAULT_SAMPLE_RATE)
            waveform = torch.tensor(entry["array"], dtype=torch.float32)
            waveform = ensure_mono(waveform)
            if sample_rate != DEFAULT_SAMPLE_RATE:
                waveform, sample_rate = resample_waveform(
                    waveform, sample_rate, DEFAULT_SAMPLE_RATE
                )
            return waveform, sample_rate
        # Undecoded compressed bytes.
        if entry.get("bytes"):
            audio_bytes = entry["bytes"]
            if not isinstance(audio_bytes, (bytes, bytearray)):
                audio_bytes = bytes(audio_bytes)
            return load_audio_bytes(audio_bytes, log=log)
        # Path: possibly local, hub-relative, an hf:// URI, or a URL.
        if entry.get("path"):
            path = resolve_audio_path(entry["path"], dataset_id, token)
            if _is_http_url(path):
                return load_audio_url(path, token, log=log)
            return load_audio_file(path, log=log)
    if isinstance(entry, str):
        path = resolve_audio_path(entry, dataset_id, token)
        if _is_http_url(path):
            return load_audio_url(path, token, log=log)
        return load_audio_file(path, log=log)
    raise ValueError("Unsupported audio entry format.")
def get_dataset_cache_dir(dataset_id: str, config: Optional[str]) -> Path:
    """Local cache directory for a dataset/config pair."""
    slug = dataset_id.replace("/", "__")
    return CACHE_DIR / (f"{slug}__{config}" if config else slug)
def get_cache_slug(dataset_id: str, config: Optional[str]) -> str:
    """Filesystem-safe slug for a dataset/config pair."""
    slug = dataset_id.replace("/", "__")
    return f"{slug}__{config}" if config else slug
def get_hub_cache_prefix(
    dataset_id: str, config: Optional[str], split_name: str
) -> str:
    """Hub path prefix under which a split's processed shards are cached."""
    slug = dataset_id.replace("/", "__")
    if config:
        slug = f"{slug}__{config}"
    return f"chizzler_cache/{slug}/{split_name}"
def load_split_meta(
    split_cache_dir: Path,
    hub_cache_prefix: str,
    cache_on_hub: bool,
    repo_id: str,
    token: Optional[str],
) -> Optional[dict]:
    """Read a split's shard metadata from local cache, then the Hub, else None."""
    local_meta = split_cache_dir / "_meta.json"
    if local_meta.exists():
        return json.loads(local_meta.read_text())
    if not cache_on_hub:
        return None
    try:
        remote_meta = hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=f"{hub_cache_prefix}/_meta.json",
            token=token,
        )
    except Exception:
        # Best effort: no remote cache means no metadata.
        return None
    return json.loads(Path(remote_meta).read_text())
def infer_audio_column(dataset_obj) -> Optional[str]:
    """Guess which column holds audio: by feature type first, then by value shape."""
    sample_ds = dataset_obj
    if isinstance(dataset_obj, (DatasetDict, IterableDatasetDict)):
        # Inspect an arbitrary split; all splits share the schema.
        sample_ds = next(iter(dataset_obj.values()))
    if hasattr(sample_ds, "features"):
        for column, feature in sample_ds.features.items():
            if isinstance(feature, Audio):
                return column
    # Fallback: inspect the first row for audio-looking values.
    if isinstance(sample_ds, Dataset) and len(sample_ds) > 0:
        first_row = sample_ds[0]
        for column, value in first_row.items():
            if isinstance(value, dict) and (
                "array" in value or "path" in value or "bytes" in value
            ):
                return column
            if isinstance(value, str) and value.lower().endswith(AUDIO_EXTENSIONS):
                return column
    return None
def default_output_repo(source_id: str, username: str) -> str:
    """Derive '<username>/<name>-representation-chizzler' from the source ID."""
    suffix = "representation-chizzler"
    base = source_id.split("/")[-1]
    if base.endswith(suffix):
        return f"{username}/{base}"
    return f"{username}/{base}-{suffix}"
def _apply_zero_gpu_limits(
shard_size: int, max_shards: Optional[int]
) -> Tuple[int, Optional[int]]:
if not os.getenv("SPACE_ID"):
return shard_size, max_shards
adjusted_shard_size = min(shard_size, DEFAULT_ZERO_GPU_SHARD_SIZE)
if max_shards is None:
adjusted_max_shards = DEFAULT_ZERO_GPU_MAX_SHARDS
else:
adjusted_max_shards = min(max_shards, DEFAULT_ZERO_GPU_MAX_SHARDS)
return adjusted_shard_size, adjusted_max_shards
@gpu_decorator(DEFAULT_GPU_DURATION)
def _process_dataset_and_push_gpu(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    vad_threshold: float,
    max_silence_gap: float,
    normalize_audio: bool,
    target_dbfs: float,
    max_boost_db: float,
    max_atten_db: float,
    max_examples: Optional[float],
    resume_processing: bool,
    shard_size: Optional[float],
    cache_on_hub: bool,
    max_shards_per_run: Optional[float],
    request: gr.Request | None = None,
    progress=gr.Progress(),
) -> str:
    """Clean every audio row of a HF dataset and push the result to the Hub.

    Runs inside the (possibly time-limited) spaces.GPU context. Sized splits
    are processed in shards that are cached locally and optionally on the
    Hub, so a preempted ZeroGPU run can resume where it left off. Returns a
    human-readable status string for the UI.
    """
    token = get_hf_token()
    if not token:
        return "Missing HF token. Set HF_TOKEN as a secret or env var."
    dataset_id = normalize_dataset_id(dataset_id)
    if not dataset_id:
        return "Provide a dataset ID or URL."
    # Ensure models are loaded on the correct device before heavy processing.
    get_models()
    config = config.strip() or None
    split = split.strip()
    audio_column = audio_column.strip()
    output_repo = normalize_dataset_id(output_repo) if output_repo else ""
    cache_on_hub = bool(cache_on_hub)
    normalize_audio = bool(normalize_audio)
    # Gradio number inputs arrive as floats; coerce to ints (0/None = unset).
    max_examples_int = int(max_examples) if max_examples and max_examples > 0 else None
    shard_size_int = int(shard_size) if shard_size and shard_size > 0 else 1000
    max_shards_int = (
        int(max_shards_per_run)
        if max_shards_per_run and max_shards_per_run > 0
        else None
    )
    if os.getenv("SPACE_ID"):
        # Clamp per-run work so a single invocation fits ZeroGPU time limits.
        adjusted_shard_size, adjusted_max_shards = _apply_zero_gpu_limits(
            shard_size_int, max_shards_int
        )
        if adjusted_shard_size != shard_size_int:
            log_progress(
                f"ZeroGPU safe mode: shard size capped at {adjusted_shard_size}",
                2,
            )
            shard_size_int = adjusted_shard_size
        if adjusted_max_shards != max_shards_int:
            log_progress(
                f"ZeroGPU safe mode: max shards per run capped at {adjusted_max_shards}",
                2,
            )
            max_shards_int = adjusted_max_shards
    api = HfApi(token=token)
    username = api.whoami()["name"]
    repo_id = output_repo or default_output_repo(dataset_id, username)
    if cache_on_hub:
        # Create the output repo up front so shard uploads have a target.
        api.create_repo(
            repo_id, repo_type="dataset", private=private_repo, exist_ok=True
        )
        log_progress(
            f"Caching shards to Hub repo: {repo_id}", 2
        )
    log_progress(f"Loading dataset: {dataset_id}")
    progress(0, desc="Downloading dataset...")
    if split and split.lower() != "all":
        dataset_obj = load_dataset(
            dataset_id, name=config, split=split, token=token
        )
        dataset_dict = DatasetDict({split: dataset_obj})
    else:
        dataset_obj = load_dataset(dataset_id, name=config, token=token)
        dataset_dict = (
            DatasetDict({"train": dataset_obj})
            if isinstance(dataset_obj, Dataset)
            else dataset_obj
        )
    progress(0.01, desc="Preparing splits...")
    if not audio_column:
        audio_column = infer_audio_column(dataset_dict) or ""
    if not audio_column:
        return (
            "Could not infer audio column. Please specify the audio column "
            "name manually."
        )
    # Bookkeeping shared across splits: results, shard counters, resume state.
    processed_splits = {}
    shards_processed = 0
    cached_shards = 0
    total_shards = 0
    incomplete = False
    repo_files = set()
    if resume_processing and cache_on_hub:
        try:
            repo_files = set(
                api.list_repo_files(repo_id, repo_type="dataset")
            )
        except Exception:
            # No listing (e.g. brand-new repo): treat as no cached shards.
            repo_files = set()
    cache_root = get_dataset_cache_dir(dataset_id, config)
    cache_root.mkdir(parents=True, exist_ok=True)
    for split_name, split_ds in dataset_dict.items():
        if (
            hasattr(split_ds, "column_names")
            and audio_column not in split_ds.column_names
        ):
            return f"Audio column '{audio_column}' not found in split '{split_name}'."
        # Decode at the target rate; fall back to a plain Audio cast for
        # sources that reject the sampling_rate argument.
        try:
            split_ds = split_ds.cast_column(
                audio_column, Audio(sampling_rate=DEFAULT_SAMPLE_RATE)
            )
        except Exception:
            split_ds = split_ds.cast_column(audio_column, Audio())
        # Iterable splits have no len(); total=None switches to unsized mode.
        total = len(split_ds) if isinstance(split_ds, Dataset) else None
        if max_examples_int and total is not None:
            total = min(total, max_examples_int)
        # Throttle progress updates to roughly 100 per split.
        update_every = max(1, (total or max_examples_int or 100) // 100)
        split_cache_dir = cache_root / split_name
        if not resume_processing and split_cache_dir.exists():
            # Fresh run: discard any stale shard cache for this split.
            shutil.rmtree(split_cache_dir)
        split_cache_dir.mkdir(parents=True, exist_ok=True)
        hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name)
        features = split_ds.features.copy()
        features[audio_column] = Audio(
            sampling_rate=DEFAULT_SAMPLE_RATE
        )
        # Per-row status columns so failures are visible in the output dataset.
        features["chizzler_ok"] = Value("bool")
        features["chizzler_error"] = Value("string")

        def make_map_fn(offset: int = 0):
            # offset converts shard-local indices to global ones for progress.
            def map_fn(example, idx):
                entry = example.get(audio_column)
                ok = True
                error_message = ""
                try:
                    waveform, sample_rate = prepare_waveform_from_entry(
                        entry, log=False, dataset_id=dataset_id, token=token
                    )
                    vad_waveform, denoised_waveform, _, has_speech = process_waveform(
                        waveform,
                        sample_rate,
                        threshold=vad_threshold,
                        max_gap=max_silence_gap,
                        log=False,
                    )
                    # No speech: pass the untouched input through instead.
                    output_waveform = (
                        denoised_waveform
                        if has_speech and denoised_waveform is not None
                        else waveform
                    )
                    if normalize_audio:
                        output_waveform = normalize_waveform(
                            output_waveform,
                            target_dbfs=target_dbfs,
                            max_boost_db=max_boost_db,
                            max_atten_db=max_atten_db,
                        )
                    output_np = (
                        output_waveform.squeeze()
                        .detach()
                        .cpu()
                        .numpy()
                        .astype(np.float32)
                    )
                    if output_np.size == 0:
                        # Zero-length result: fall back to the original audio.
                        ok = False
                        error_message = (
                            "Empty output waveform; using original audio."
                        )
                        output_np = (
                            waveform.squeeze()
                            .detach()
                            .cpu()
                            .numpy()
                            .astype(np.float32)
                        )
                    output_entry = {
                        "array": output_np,
                        "sampling_rate": DEFAULT_SAMPLE_RATE,
                    }
                except Exception as exc:
                    # Never abort the whole map over one bad row; record the
                    # error and pass the original entry through (or 1-sample
                    # silence when the entry itself was None).
                    ok = False
                    error_message = str(exc)
                    output_entry = entry if entry is not None else {
                        "array": np.zeros(1, dtype=np.float32),
                        "sampling_rate": DEFAULT_SAMPLE_RATE,
                    }
                example[audio_column] = output_entry
                example["chizzler_ok"] = ok
                example["chizzler_error"] = error_message
                global_idx = offset + idx + 1
                if total:
                    if global_idx % update_every == 0 or global_idx == total:
                        progress(
                            global_idx / total,
                            desc=(
                                f"Processing {split_name}: {global_idx}/{total}"
                            ),
                        )
                else:
                    if global_idx % update_every == 0:
                        progress(
                            0,
                            desc=f"Processing {split_name}: {global_idx} examples",
                        )
                return example

            return map_fn

        if total:
            # Sized split: process in resumable shards of shard_size_int rows.
            num_shards = math.ceil(total / shard_size_int)
            total_shards += num_shards
            meta = {
                "dataset_id": dataset_id,
                "config": config or "",
                "split": split_name,
                "audio_column": audio_column,
                "total": total,
                "shard_size": shard_size_int,
                "num_shards": num_shards,
            }
            meta_file = split_cache_dir / "_meta.json"
            meta_file.write_text(json.dumps(meta, indent=2))
            if cache_on_hub:
                api.upload_file(
                    path_or_fileobj=str(meta_file),
                    path_in_repo=f"{hub_cache_prefix}/_meta.json",
                    repo_id=repo_id,
                    repo_type="dataset",
                )
            shards = []
            for shard_idx in range(num_shards):
                start = shard_idx * shard_size_int
                end = min(total, start + shard_size_int)
                cache_file = split_cache_dir / (
                    f"{split_name}-{start:07d}-{end:07d}.arrow"
                )
                hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}"
                if resume_processing and cache_file.exists():
                    # Shard already processed locally on a previous run.
                    processed_shard = Dataset.from_file(str(cache_file))
                    progress(
                        end / total,
                        desc=f"Processing {split_name}: {end}/{total}",
                    )
                    cached_shards += 1
                elif resume_processing and cache_on_hub and hub_cache_path in repo_files:
                    # Shard cached on the Hub by a previous run; download it.
                    cache_path = hf_hub_download(
                        repo_id=repo_id,
                        repo_type="dataset",
                        filename=hub_cache_path,
                        token=token,
                    )
                    processed_shard = Dataset.from_file(cache_path)
                    progress(
                        end / total,
                        desc=f"Processing {split_name}: {end}/{total}",
                    )
                    cached_shards += 1
                else:
                    if max_shards_int and shards_processed >= max_shards_int:
                        # Per-run shard budget exhausted; stop and report
                        # partial progress so the caller can resume later.
                        incomplete = True
                        break
                    indices = range(start, end)
                    try:
                        shard_ds = split_ds.select(indices)
                    except Exception:
                        # Some datasets versions require a materialized list.
                        shard_ds = split_ds.select(list(indices))
                    processed_shard = shard_ds.map(
                        make_map_fn(offset=start),
                        with_indices=True,
                        load_from_cache_file=False,
                        cache_file_name=str(cache_file),
                        writer_batch_size=50,
                        num_proc=None,
                        features=features,
                        desc=(
                            f"Chizzling {split_name} "
                            f"({shard_idx + 1}/{num_shards})"
                        ),
                    )
                    shards_processed += 1
                    if cache_on_hub:
                        api.upload_file(
                            path_or_fileobj=str(cache_file),
                            path_in_repo=hub_cache_path,
                            repo_id=repo_id,
                            repo_type="dataset",
                        )
                        repo_files.add(hub_cache_path)
                shards.append(processed_shard)
            if incomplete:
                break
            processed_split = (
                concatenate_datasets(shards)
                if len(shards) > 1
                else shards[0]
            )
        else:
            # Unsized (iterable) split: a single map over the whole stream.
            if max_shards_int and shards_processed >= max_shards_int:
                incomplete = True
                break
            processed_split = split_ds.map(
                make_map_fn(offset=0),
                with_indices=True,
                load_from_cache_file=False,
                writer_batch_size=50,
                num_proc=None,
                features=features,
                desc=f"Chizzling {split_name}",
            )
            shards_processed += 1
        processed_splits[split_name] = processed_split
    if incomplete:
        # Partial run: report counts; the wrapper looks for the exact phrase
        # "Resume with cached shards" to trigger auto-resume.
        total_done = cached_shards + shards_processed
        progress_note = (
            f" ({total_done}/{total_shards} shards ready)"
            if total_shards
            else ""
        )
        return (
            f"Processed {shards_processed} new shard(s)"
            f"{f', cached {cached_shards}' if cached_shards else ''}"
            f"{progress_note}."
            " Resume with cached shards to continue."
        )
    processed_dataset = (
        DatasetDict(processed_splits)
        if len(processed_splits) > 1
        else next(iter(processed_splits.values()))
    )
    progress(0, desc="Uploading to the Hub...")
    processed_dataset.push_to_hub(repo_id, private=private_repo, token=token)
    progress(1.0, desc="Upload complete.")
    return (
        f"Uploaded cleaned dataset to {repo_id} "
        f"(audio column: {audio_column})."
    )
def process_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    vad_threshold: float,
    max_silence_gap: float,
    normalize_audio: bool,
    target_dbfs: float,
    max_boost_db: float,
    max_atten_db: float,
    max_examples: Optional[float],
    resume_processing: bool,
    auto_resume: bool,
    shard_size: Optional[float],
    cache_on_hub: bool,
    max_shards_per_run: Optional[float],
    request: gr.Request | None = None,
    progress=gr.Progress(),
) -> str:
    """UI entry point: run the GPU job, auto-resuming across ZeroGPU limits.

    Wraps _process_dataset_and_push_gpu in a retry loop that, when
    auto_resume is on, re-invokes it after ZeroGPU preemptions and after
    partial runs that left cached shards behind. Returns the final status
    string from the GPU worker.
    """
    if SPACE_ID and request is not None:
        # ZeroGPU attributes user quota via the x-ip-token request header;
        # warn when it is absent (e.g. the Space is embedded elsewhere).
        headers = getattr(request, "headers", None)
        token_header = None
        if headers and hasattr(headers, "get"):
            token_header = headers.get("x-ip-token")
        if not token_header:
            log_progress(
                "ZeroGPU auth header missing. Use the Space on huggingface.co "
                "to attach your login to ZeroGPU quota.",
                2,
            )
    attempts = 0
    while True:
        try:
            result = _process_dataset_and_push_gpu(
                dataset_id,
                config,
                split,
                audio_column,
                output_repo,
                private_repo,
                vad_threshold,
                max_silence_gap,
                normalize_audio,
                target_dbfs,
                max_boost_db,
                max_atten_db,
                max_examples,
                resume_processing,
                shard_size,
                cache_on_hub,
                max_shards_per_run,
                request=request,
                progress=progress,
            )
        except Exception as exc:
            # ZeroGPU errors are only distinguishable by message text.
            message = str(exc)
            if "ZeroGPU proxy token expired" in message:
                return (
                    "ZeroGPU login token expired. Click Process/Resume again "
                    "to refresh your session."
                )
            if auto_resume and "GPU task aborted" in message:
                # Preempted by the ZeroGPU scheduler; cached shards make the
                # retry cheap.
                attempts += 1
                log_progress(
                    f"ZeroGPU preempted. Retrying (attempt {attempts})...",
                    2,
                )
                time.sleep(2)
                continue
            raise
        if not auto_resume:
            return result
        if "Resume with cached shards" in result:
            # Partial run (per-run shard budget hit): immediately continue.
            attempts += 1
            log_progress(
                f"Auto-resume: continuing (attempt {attempts})...",
                2,
            )
            time.sleep(2)
            continue
        return result
def assemble_cached_dataset_and_push(
    dataset_id: str,
    config: str,
    split: str,
    audio_column: str,
    output_repo: str,
    private_repo: bool,
    cache_on_hub: bool,
    progress=gr.Progress(),
) -> str:
    """Reassemble previously processed shard caches and push them to the Hub.

    Locates every cached ``.arrow`` shard for the requested split(s) —
    locally under the dataset cache dir, or downloaded from the output repo
    when ``cache_on_hub`` is set — concatenates them per split, and pushes
    the combined dataset to ``output_repo`` (or a default repo derived from
    ``dataset_id``). All failure modes are reported as returned status
    strings for the Gradio status box rather than raised exceptions.
    """
    token = get_hf_token()
    if not token:
        return "Missing HF token. Set HF_TOKEN as a secret or env var."
    # Normalize/trim user-supplied fields coming straight from textboxes.
    dataset_id = normalize_dataset_id(dataset_id)
    if not dataset_id:
        return "Provide a dataset ID or URL."
    config = config.strip() or None
    split = split.strip()
    audio_column = audio_column.strip()
    output_repo = normalize_dataset_id(output_repo) if output_repo else ""
    cache_on_hub = bool(cache_on_hub)
    api = HfApi(token=token)
    username = api.whoami()["name"]
    repo_id = output_repo or default_output_repo(dataset_id, username)
    cache_root = get_dataset_cache_dir(dataset_id, config)
    cache_slug = get_cache_slug(dataset_id, config)
    # Determine which splits to assemble: an explicit split name, or every
    # split discoverable in the cache (Hub repo or local directory).
    if split and split.lower() != "all":
        split_names = [split]
    else:
        if cache_on_hub:
            repo_files = api.list_repo_files(repo_id, repo_type="dataset")
            prefix = f"chizzler_cache/{cache_slug}/"
            # Cache paths look like chizzler_cache/<slug>/<split>/<shard>,
            # so index 2 of the path parts is the split name.
            split_names = sorted(
                {
                    path.split("/")[2]
                    for path in repo_files
                    if path.startswith(prefix) and len(path.split("/")) >= 3
                }
            )
        else:
            split_names = sorted(
                [
                    path.name
                    for path in cache_root.iterdir()
                    if path.is_dir()
                ]
            )
    if not split_names:
        return "No cached shards found. Run processing first."
    # Snapshot the repo file listing once so per-shard existence checks are
    # cheap; a listing failure degrades to "no hub shards available".
    # NOTE(review): in the 'all' + cache_on_hub path this repeats the
    # list_repo_files call made above — could be reused; verify error
    # semantics (the first call intentionally propagates) before merging.
    repo_files = set()
    if cache_on_hub:
        try:
            repo_files = set(
                api.list_repo_files(repo_id, repo_type="dataset")
            )
        except Exception:
            repo_files = set()
    processed_splits = {}
    for split_name in split_names:
        split_cache_dir = cache_root / split_name
        hub_cache_prefix = get_hub_cache_prefix(dataset_id, config, split_name)
        # Metadata records how many examples/shards the split was cut into.
        meta = load_split_meta(
            split_cache_dir, hub_cache_prefix, cache_on_hub, repo_id, token
        )
        if not meta:
            return (
                f"Missing cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        total = int(meta.get("total", 0))
        shard_size = int(meta.get("shard_size", 0))
        num_shards = int(meta.get("num_shards", 0))
        if not total or not shard_size or not num_shards:
            return (
                f"Incomplete cache metadata for split '{split_name}'. "
                "Re-run processing to rebuild shards."
            )
        shards = []
        missing = []
        for shard_idx in range(num_shards):
            # Shard filenames encode the [start, end) example range,
            # zero-padded to 7 digits.
            start = shard_idx * shard_size
            end = min(total, start + shard_size)
            cache_file = split_cache_dir / (
                f"{split_name}-{start:07d}-{end:07d}.arrow"
            )
            hub_cache_path = f"{hub_cache_prefix}/{cache_file.name}"
            # Prefer a local shard; fall back to downloading from the Hub.
            if cache_file.exists():
                cache_path = str(cache_file)
            elif cache_on_hub and hub_cache_path in repo_files:
                cache_path = hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=hub_cache_path,
                    token=token,
                )
            else:
                missing.append(cache_file.name)
                continue
            shards.append(Dataset.from_file(cache_path))
        if missing:
            return (
                f"Missing {len(missing)} shard(s) for split '{split_name}'. "
                "Run processing with resume enabled."
            )
        processed_splits[split_name] = (
            concatenate_datasets(shards)
            if len(shards) > 1
            else shards[0]
        )
    # A single split is pushed as a bare Dataset; multiple splits as a
    # DatasetDict so split names are preserved on the Hub.
    processed_dataset = (
        DatasetDict(processed_splits)
        if len(processed_splits) > 1
        else next(iter(processed_splits.values()))
    )
    progress(0, desc="Uploading to the Hub...")
    processed_dataset.push_to_hub(repo_id, private=private_repo, token=token)
    progress(1.0, desc="Upload complete.")
    # Report which audio column the dataset carries, auto-detecting when
    # the user left the field blank.
    inferred_audio_column = (
        audio_column or infer_audio_column(processed_dataset) or "audio"
    )
    return (
        f"Uploaded cleaned dataset to {repo_id} "
        f"(audio column: {inferred_audio_column})."
    )
@gpu_decorator(DEFAULT_GPU_DURATION)
def _gradio_single_file_gpu(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """GPU-scheduled worker behind the Single File tab.

    Returns the 4-tuple expected by the tab's outputs:
    (original audio, VAD-only audio, denoised audio, details text).
    """
    # Guard: nothing to process without an uploaded file.
    if audio_file is None:
        return None, None, None, "Please upload an audio file."
    cleanup_options = dict(
        threshold=vad_threshold,
        max_gap=max_silence_gap,
        normalize_audio=normalize_audio,
        target_dbfs=target_dbfs,
        max_boost_db=max_boost_db,
        max_atten_db=max_atten_db,
    )
    return process_audio_file(audio_file, **cleanup_options)
def gradio_single_file(
    audio_file,
    vad_threshold,
    max_silence_gap,
    normalize_audio,
    target_dbfs,
    max_boost_db,
    max_atten_db,
    request: gr.Request | None = None,
):
    """UI entry point for the Single File tab.

    Thin wrapper that defers all work to the GPU-decorated
    ``_gradio_single_file_gpu`` so the Gradio click handler itself stays
    free of the ZeroGPU decorator.
    """
    return _gradio_single_file_gpu(
        audio_file=audio_file,
        vad_threshold=vad_threshold,
        max_silence_gap=max_silence_gap,
        normalize_audio=normalize_audio,
        target_dbfs=target_dbfs,
        max_boost_db=max_boost_db,
        max_atten_db=max_atten_db,
        request=request,
    )
# Top-level Gradio UI: one Blocks app with two tabs — ad-hoc single-file
# cleanup, and the shard-cached dataset pipeline that publishes to the Hub.
with gr.Blocks(title="Representation Chizzler") as demo:
    gr.Markdown(
        "# Representation Chizzler\n"
        "Two-stage audio processing: VAD-based speech extraction followed by MP-SENet "
        "denoising. Use the Single File tab for ad-hoc processing or the Dataset tab "
        "to clean and publish a dataset to the Hugging Face Hub."
    )
    with gr.Column():
        with gr.Tabs():
            # --- Tab 1: process a single uploaded file in place ---
            with gr.Tab("Single File"):
                audio_input = gr.Audio(label="Upload Audio File", type="filepath")
                vad_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="VAD Threshold (higher = stricter voice detection)",
                )
                gap_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.5,
                    label="Max Silence Gap (seconds)",
                )
                # Loudness-normalization controls mirror the dataset tab's.
                normalize_checkbox = gr.Checkbox(
                    label="Normalize volume", value=True
                )
                target_db_slider = gr.Slider(
                    minimum=-35.0,
                    maximum=-10.0,
                    value=DEFAULT_TARGET_DBFS,
                    step=1.0,
                    label="Target loudness (dBFS)",
                )
                max_boost_slider = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    value=DEFAULT_MAX_BOOST_DB,
                    step=1.0,
                    label="Max boost (dB)",
                )
                max_atten_slider = gr.Slider(
                    minimum=0.0,
                    maximum=20.0,
                    value=DEFAULT_MAX_ATTEN_DB,
                    step=1.0,
                    label="Max attenuation (dB)",
                )
                run_button = gr.Button("Process Audio")
                # Three playback widgets show each stage of the pipeline.
                original_audio = gr.Audio(label="Original Audio")
                vad_audio = gr.Audio(label="VAD Processed (Speech Only)")
                denoised_audio = gr.Audio(label="Final Denoised")
                details_box = gr.Textbox(label="Processing Details", lines=10)
                run_button.click(
                    fn=gradio_single_file,
                    inputs=[
                        audio_input,
                        vad_slider,
                        gap_slider,
                        normalize_checkbox,
                        target_db_slider,
                        max_boost_slider,
                        max_atten_slider,
                    ],
                    outputs=[original_audio, vad_audio, denoised_audio, details_box],
                    # One GPU job at a time.
                    concurrency_limit=1,
                )
            # --- Tab 2: clean a whole Hub dataset and re-publish it ---
            with gr.Tab("Dataset to Hub"):
                with gr.Row():
                    gr.LoginButton()
                dataset_id_input = gr.Textbox(
                    label="Dataset ID or URL",
                    value="https://huggingface.co/datasets/MohammadGholizadeh/fleurs-farsi",
                )
                config_input = gr.Textbox(label="Config (optional)", value="")
                split_input = gr.Textbox(label="Split (optional, or 'all')", value="dev")
                audio_column_input = gr.Textbox(
                    label="Audio column (optional, auto-detect if empty)", value=""
                )
                output_repo_input = gr.Textbox(
                    label="Output dataset repo (optional)", value=""
                )
                private_checkbox = gr.Checkbox(label="Create private repo", value=False)
                max_examples_input = gr.Number(
                    label="Max examples per split (optional)", value=None
                )
                # Resume/caching controls for surviving ZeroGPU preemption.
                resume_checkbox = gr.Checkbox(
                    label="Resume from cached shards", value=True
                )
                auto_resume_checkbox = gr.Checkbox(
                    label="Auto-resume on ZeroGPU preemption",
                    value=DEFAULT_AUTO_RESUME,
                )
                cache_to_hub_checkbox = gr.Checkbox(
                    label="Cache shards on Hub (recommended for ZeroGPU)",
                    value=DEFAULT_CACHE_TO_HUB,
                )
                shard_size_input = gr.Number(
                    label="Shard size (examples)", value=25
                )
                max_shards_input = gr.Number(
                    label="Max shards per run (ZeroGPU: 1-5, 0 = no limit)",
                    value=DEFAULT_MAX_SHARDS_PER_RUN,
                )
                # Processing knobs (duplicated from the Single File tab so
                # each tab is self-contained).
                vad_slider_ds = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="VAD Threshold",
                )
                gap_slider_ds = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=4.0,
                    step=0.5,
                    label="Max Silence Gap (seconds)",
                )
                normalize_checkbox_ds = gr.Checkbox(
                    label="Normalize volume", value=True
                )
                target_db_slider_ds = gr.Slider(
                    minimum=-35.0,
                    maximum=-10.0,
                    value=DEFAULT_TARGET_DBFS,
                    step=1.0,
                    label="Target loudness (dBFS)",
                )
                max_boost_slider_ds = gr.Slider(
                    minimum=0.0,
                    maximum=30.0,
                    value=DEFAULT_MAX_BOOST_DB,
                    step=1.0,
                    label="Max boost (dB)",
                )
                max_atten_slider_ds = gr.Slider(
                    minimum=0.0,
                    maximum=20.0,
                    value=DEFAULT_MAX_ATTEN_DB,
                    step=1.0,
                    label="Max attenuation (dB)",
                )
                process_button = gr.Button(
                    "Process/Resume Dataset (cache & push when complete)"
                )
                assemble_button = gr.Button(
                    "Assemble & Push Cached Dataset"
                )
                status_box = gr.Textbox(label="Status", lines=6)
                # Input order must match process_dataset_and_push's
                # positional parameters.
                process_button.click(
                    fn=process_dataset_and_push,
                    inputs=[
                        dataset_id_input,
                        config_input,
                        split_input,
                        audio_column_input,
                        output_repo_input,
                        private_checkbox,
                        vad_slider_ds,
                        gap_slider_ds,
                        normalize_checkbox_ds,
                        target_db_slider_ds,
                        max_boost_slider_ds,
                        max_atten_slider_ds,
                        max_examples_input,
                        resume_checkbox,
                        auto_resume_checkbox,
                        shard_size_input,
                        cache_to_hub_checkbox,
                        max_shards_input,
                    ],
                    outputs=[status_box],
                    concurrency_limit=1,
                )
                # Assembly needs no processing knobs — it only stitches and
                # pushes already-cached shards.
                assemble_button.click(
                    fn=assemble_cached_dataset_and_push,
                    inputs=[
                        dataset_id_input,
                        config_input,
                        split_input,
                        audio_column_input,
                        output_repo_input,
                        private_checkbox,
                        cache_to_hub_checkbox,
                    ],
                    outputs=[status_box],
                    concurrency_limit=1,
                )
# Enable the request queue so the per-event concurrency_limit settings
# above are honored.
demo.queue()
# Standard script entry point.
if __name__ == "__main__":
    demo.launch()