"""
Load a user-submitted custom evaluator script and wrap it as a transcriber callable.

The script must define::

    from pathlib import Path

    def evaluate(file: Path) -> str:
        ...

Each packed benchmark sample is written to a temporary WAV file and passed to
``evaluate`` once per sample inside the existing eval loop.
"""

from __future__ import annotations

import importlib.util
import os
import re
import shutil
import sys
import tempfile
from collections.abc import Callable
from pathlib import Path

# Name under which the user script is registered in ``sys.modules`` while loaded.
_USER_MODULE_NAME = "ffasr_user_eval"

# Recipe default clone path for Mega-ASR; harmless if it doesn't exist.
_DEFAULT_MEGA_ASR_ROOT = Path("/tmp/Mega-ASR")

# Deprecated Hugging Face / pyannote auth kwargs. Matched only in keyword-argument
# position: right after ``(`` or ``,``, optionally across whitespace and full-line
# ``#`` comments. Never matched as part of ``==``. Plain assignments such as
# ``use_auth_token = ...`` are therefore left untouched. The comment alternative is
# anchored to end-of-line so it can match only one way (no exponential backtracking).
_DEPRECATED_TOKEN_KWARG_RE = re.compile(
    r"([(,](?:\s|#[^\n]*(?=\n))*)(?:use_auth_token|authentication_token)(\s*=)(?!=)"
)


def _is_dir(path: Path) -> bool:
    """``path.is_dir()`` that treats unreadable paths (e.g. ``PermissionError``) as absent."""
    try:
        return path.is_dir()
    except OSError:
        return False


def _is_file(path: Path) -> bool:
    """``path.is_file()`` that treats unreadable paths (e.g. ``PermissionError``) as absent."""
    try:
        return path.is_file()
    except OSError:
        return False


def _remove_quietly(path: str | os.PathLike) -> None:
    """Delete ``path``, ignoring it if it no longer exists."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


def _mega_asr_roots() -> list[Path]:
    """Return candidate Mega-ASR checkouts: ``$MEGA_ASR_ROOT`` (if set), then the recipe default."""
    roots: list[Path] = []
    env_root = os.environ.get("MEGA_ASR_ROOT", "").strip()
    if env_root:
        roots.append(Path(env_root))
    roots.append(_DEFAULT_MEGA_ASR_ROOT)
    return roots


def _prepend_mega_asr_import_path() -> None:
    """Make the ``MegaASR`` package importable for custom top-level imports.

    Mega-ASR is laid out as ``<root>/src/MegaASR``. Custom scripts often import it
    at module load (``from MegaASR.model.megaASR import MegaASR``). That import runs
    before the script can touch ``sys.path``.

    The recipe normally injects ``MEGA_ASR_ROOT`` via the worker env, but that
    propagation can be missing (e.g. recipe env not forwarded). So we also fall back
    to the recipe's canonical clone path, and accept either the ``src/`` layout or a
    package sitting at the repo root.

    The first matching directory is prepended to ``sys.path`` (if not already
    present), and the search stops there.
    """
    seen: set[str] = set()
    for root in _mega_asr_roots():
        try:
            base = root.resolve()
        except (OSError, RuntimeError):  # RuntimeError: symlink loop on older Pythons
            continue
        key = str(base)
        if key in seen:
            continue
        seen.add(key)
        # Prefer the documented ``src/`` layout, but tolerate a root-level package.
        for candidate in (base / "src", base):
            if _is_dir(candidate / "MegaASR"):
                cand_s = str(candidate)
                if cand_s not in sys.path:
                    sys.path.insert(0, cand_s)
                return


def _candidate_import_roots() -> list[Path]:
    """Directories that may contain a freshly cloned repo's importable packages.

    Submitters typically ``git clone`` into ``/tmp`` from their setup script. So we
    look there (plus any explicitly declared paths) instead of requiring a hardcoded
    recipe in this repo.

    Order of priority:

    1. ``FFASR_IMPORT_PATHS`` entries (``os.pathsep``-separated). These can be
       exported from the setup script via ``$FFASR_ENV_FILE``.
    2. The Mega-ASR roots (``$MEGA_ASR_ROOT``, then ``/tmp/Mega-ASR``).
    3. Immediate subdirectories of ``/tmp``, sorted. Our own ``ffasr_*`` temp dirs
       are skipped.

    The list may contain duplicates; callers dedupe.
    """
    roots: list[Path] = [
        Path(raw.strip())
        for raw in os.environ.get("FFASR_IMPORT_PATHS", "").split(os.pathsep)
        if raw.strip()
    ]
    roots.extend(_mega_asr_roots())

    # Conventional clone target: immediate children of /tmp.
    try:
        children = sorted(Path("/tmp").iterdir())
    except OSError:
        children = []
    # Per-child guard: one unreadable entry must not abort the whole scan.
    roots.extend(
        child
        for child in children
        if not child.name.startswith("ffasr_") and _is_dir(child)
    )
    return roots


def _add_path_for_missing_module(missing: str) -> bool:
    """Make top-level module ``missing`` importable by scanning setup-created clones.

    For every candidate root, this looks for any of:

    * ``<root>/<name>/`` (package dir)
    * ``<root>/<name>.py`` (module)
    * the same two under ``<root>/src/``

    The containing directory is prepended to ``sys.path``. This lets submitter
    custom scripts ``import`` a cloned repo at top level without a maintainer-side
    recipe.

    Returns ``True`` only if a new path was added to ``sys.path``.
    """
    top = (missing or "").split(".")[0]
    if not top:
        return False
    seen: set[str] = set()
    for root in _candidate_import_roots():
        try:
            base = root.resolve()
        except (OSError, RuntimeError):
            continue
        for parent in (base, base / "src"):
            key = str(parent)
            if key in seen:
                continue
            seen.add(key)
            if _is_dir(parent / top) or _is_file(parent / f"{top}.py"):
                if key not in sys.path:
                    sys.path.insert(0, key)
                    return True
    return False


def _exec_user_module(spec, mod) -> None:
    """Exec the user module, auto-adding clone import paths on ``ModuleNotFoundError``.

    Custom scripts commonly import a cloned repo at module load. If a top-level
    package isn't importable yet, we locate it among setup-created directories and
    retry. This lets submitters provide their own setup + ``evaluate()`` without us
    registering a recipe in this repo.

    Each missing top-level name is retried at most once. The original
    ``ModuleNotFoundError`` is re-raised when no fix is possible.
    """
    attempted: set[str] = set()
    while True:
        try:
            spec.loader.exec_module(mod)
            return
        except ModuleNotFoundError as exc:
            name = getattr(exc, "name", "") or ""
            top = name.split(".")[0]
            if not top or top in attempted:
                raise
            attempted.add(top)
            if not _add_path_for_missing_module(name):
                raise
            # A new sys.path entry was added: drop stale finder caches before retrying.
            importlib.invalidate_caches()


def discover_num_params(mod) -> int:
    """Best-effort model size (parameter count) for a custom evaluator module.

    Custom scripts load their own model, so we never hold the model object the way
    the built-in backends do. To still report a model size on the leaderboard,
    submitters can expose it in any of these ways (checked in order):

    * a module-level ``FFASR_NUM_PARAMS`` / ``NUM_PARAMS`` value: an ``int`` total
      parameter count, or a zero-arg callable returning one;
    * a module-level ``num_params`` / ``model_num_params`` value (``int`` or callable);
    * a module-level model object under a common name (``model``, ``MODEL``,
      ``asr_model``, ``pipe``, ``pipeline``) exposing ``.parameters()``. We first
      try ``obj.model`` (so a Transformers ``pipeline`` works), then ``obj`` itself.

    Returns 0 when nothing usable is found (model size simply stays blank).
    """

    def _as_int(value) -> int | None:
        # Accept a number or a zero-arg callable returning one. Anything that is not
        # a positive integer, or that raises while being produced, counts as unknown.
        if callable(value):
            try:
                value = value()
            except Exception:
                return None
        try:
            n = int(value)
        except (TypeError, ValueError, OverflowError):  # OverflowError: int(float("inf"))
            return None
        return n if n > 0 else None

    for name in ("FFASR_NUM_PARAMS", "NUM_PARAMS", "num_params", "model_num_params"):
        n = _as_int(getattr(mod, name, None))
        if n:
            return n

    model_objects = [
        obj
        for obj in (
            getattr(mod, name, None)
            for name in ("model", "MODEL", "asr_model", "pipe", "pipeline")
        )
        if obj is not None
    ]
    if not model_objects:
        return 0

    # Imported lazily: only needed once there is an actual model object to inspect.
    from ._model_utils import count_params

    for obj in model_objects:
        # A Transformers pipeline keeps the nn.Module at ``.model``; try that first,
        # then the object itself. Use ``is None`` (not truthiness) so containers with
        # ``__len__`` or array-like attributes are not misjudged.
        for candidate in (getattr(obj, "model", None), obj):
            if candidate is None:
                continue
            try:
                n = count_params(candidate)
            except Exception:
                # Best effort: an object we can't count must not hide later candidates.
                continue
            if n:
                return n
    return 0


def normalize_custom_script_compat(script_text: str) -> str:
    """
    Rewrite deprecated Hugging Face / pyannote kwargs so scripts run on current hub versions.

    * ``use_auth_token=`` → ``token=``
    * ``authentication_token=`` → ``token=``

    Only keyword-argument occurrences are rewritten: those directly after ``(`` or
    ``,``, possibly across whitespace and ``#`` comment lines. Assignments such as
    ``use_auth_token = os.environ[...]`` and comparisons (``==``) are left alone, so
    later references to such variables keep resolving.
    """
    return _DEPRECATED_TOKEN_KWARG_RE.sub(r"\1token\2", script_text)


def _forget_user_module(mod) -> None:
    """Drop ``mod`` from ``sys.modules`` if it is still the registered user module."""
    if sys.modules.get(_USER_MODULE_NAME) is mod:
        del sys.modules[_USER_MODULE_NAME]


def _load_user_module(script_path: str):
    """Import the script at ``script_path`` and check that it defines ``evaluate``.

    The module is registered in ``sys.modules`` before execution, as in the
    importlib "import a source file directly" recipe. This lets dataclasses and
    pickling resolve the module. On any failure it is unregistered again.

    Raises:
        RuntimeError: the file cannot be loaded as a module, or lacks a callable
            ``evaluate``.
        Exception: anything the script itself raises while executing.
    """
    spec = importlib.util.spec_from_file_location(_USER_MODULE_NAME, script_path)
    if spec is None or spec.loader is None:
        raise RuntimeError("Could not load custom evaluator script.")
    mod = importlib.util.module_from_spec(spec)
    sys.modules[_USER_MODULE_NAME] = mod
    try:
        _exec_user_module(spec, mod)
        if not callable(getattr(mod, "evaluate", None)):
            raise RuntimeError(
                "Custom script must define `evaluate(file: pathlib.Path) -> str`."
            )
    except BaseException:
        _forget_user_module(mod)
        raise
    return mod


def build_transcriber_from_custom_script(
    script_text: str,
) -> tuple[Callable[..., str], Callable[[], None]]:
    """
    Load ``script_text`` as a module and return ``(transcribe, cleanup)``.

    ``transcribe(audio_np, sampling_rate)`` writes a temp WAV and calls
    ``evaluate(path)`` on the loaded module.

    Return values:

    * ``transcribe`` returns the stripped transcript. A list/tuple result is reduced
      to its first element. It also carries a ``_num_params`` attribute from
      :func:`discover_num_params` (0 if unknown).
    * ``cleanup()`` removes the temp WAV directory and the temp script file, and
      unregisters the user module.

    Raises:
        RuntimeError: the script is empty, cannot be loaded, or does not define a
            callable ``evaluate``.
        Exception: anything the script raises at import time. In every failure case
            the temporary script file is removed before the error propagates.
    """
    import soundfile as sf

    script_text = normalize_custom_script_compat((script_text or "").strip())
    if not script_text:
        raise RuntimeError("Custom script is empty.")

    from evaluation.runtime import (
        disable_broken_torchcodec,
        patch_transformers_load_audio_for_paths,
    )

    disable_broken_torchcodec()
    patch_transformers_load_audio_for_paths()
    _prepend_mega_asr_import_path()

    # Mirror the token into both env names (never overwriting an existing value) so
    # libraries reading either one see it.
    hf_token = (
        os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or ""
    ).strip()
    if hf_token:
        os.environ.setdefault("HF_TOKEN", hf_token)
        os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", hf_token)

    fd, script_path = tempfile.mkstemp(prefix="ffasr_custom_eval_", suffix=".py")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(script_text)
        mod = _load_user_module(script_path)
    except BaseException:
        # Don't leak the temp script when writing, importing or validating fails.
        _remove_quietly(script_path)
        raise

    evaluate_fn = mod.evaluate
    wav_dir = tempfile.mkdtemp(prefix="ffasr_custom_wav_")
    counter = {"i": 0}

    def transcribe(audio_np, sampling_rate: int) -> str:
        counter["i"] += 1
        wav_path = Path(wav_dir) / f"sample_{counter['i']:08d}.wav"
        sf.write(str(wav_path), audio_np, int(sampling_rate), subtype="PCM_16")
        try:
            text = evaluate_fn(wav_path)
        finally:
            # One WAV per sample: delete it as soon as evaluate() is done with it.
            _remove_quietly(wav_path)
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        return str(text or "").strip()

    try:
        transcribe._num_params = discover_num_params(mod)  # type: ignore[attr-defined]
    except Exception:
        transcribe._num_params = 0  # type: ignore[attr-defined]

    def cleanup() -> None:
        shutil.rmtree(wav_dir, ignore_errors=True)
        _remove_quietly(script_path)
        _forget_user_module(mod)

    return transcribe, cleanup