ffasr / backends /custom_eval.py
Shivam
model param fetch, text edit
65f8690
Raw
History Blame Contribute Delete
10.3 kB
"""
Load a user-submitted custom evaluator script and wrap it as a transcriber callable.

The script must define::

    from pathlib import Path

    def evaluate(file: Path) -> str:
        ...

Each packed benchmark sample is written to a temporary WAV file and passed to
``evaluate`` once per sample inside the existing eval loop.
"""
from __future__ import annotations
import importlib.util
import os
import shutil
import sys
import tempfile
from collections.abc import Callable
from pathlib import Path
def _prepend_mega_asr_import_path() -> None:
    """Make the ``MegaASR`` package importable for custom top-level imports.

    Mega-ASR is laid out as ``<root>/src/MegaASR``. Custom scripts often import it
    at module load (``from MegaASR.model.megaASR import MegaASR``), which runs before
    the script can touch ``sys.path``. The recipe normally injects ``MEGA_ASR_ROOT``
    via the worker env, but that propagation can be missing (e.g. recipe env not
    forwarded), so we also fall back to the recipe's canonical clone path and accept
    either the ``src/`` layout or a package sitting at the repo root.

    Roots are tried in order: ``MEGA_ASR_ROOT`` (when set and non-blank), then
    ``/tmp/Mega-ASR``. For the first root that contains a ``MegaASR`` directory the
    containing directory is inserted at the front of ``sys.path`` (unless it is
    already listed) and the search stops. Nothing happens when no root matches.
    """
    roots: list[Path] = []
    env_root = os.environ.get("MEGA_ASR_ROOT", "").strip()
    if env_root:
        roots.append(Path(env_root))
    # Recipe default clone path; harmless if it doesn't exist.
    roots.append(Path("/tmp/Mega-ASR"))

    seen: set[str] = set()
    for root in roots:
        try:
            base = root.resolve()
        except (OSError, RuntimeError):
            # Before Python 3.13 a symlink loop surfaces as RuntimeError rather
            # than OSError; either way the root is unusable, so skip it.
            continue
        key = str(base)
        if key in seen:
            continue
        seen.add(key)
        # Prefer the documented ``src/`` layout, but tolerate a root-level package.
        for candidate in (base / "src", base):
            if not (candidate / "MegaASR").is_dir():
                continue
            cand_s = str(candidate)
            if cand_s not in sys.path:
                sys.path.insert(0, cand_s)
            return
def _candidate_import_roots() -> list[Path]:
    """List directories that may hold importable packages from a fresh clone.

    Submitters usually ``git clone`` into ``/tmp`` from their setup script, so the
    result is built, in priority order, from:

    1. every non-blank entry of ``FFASR_IMPORT_PATHS`` (``os.pathsep``-separated;
       can be exported from the setup script via ``$FFASR_ENV_FILE``);
    2. ``MEGA_ASR_ROOT`` when set and non-blank;
    3. the fixed path ``/tmp/Mega-ASR``;
    4. each immediate sub-directory of ``/tmp`` (sorted) whose name does not start
       with ``ffasr_``.

    Paths are returned as given (not resolved, not de-duplicated, not checked for
    existence). An ``OSError`` while scanning ``/tmp`` ends the scan silently,
    keeping whatever was collected up to that point.
    """
    declared = os.environ.get("FFASR_IMPORT_PATHS", "")
    stripped_entries = (piece.strip() for piece in declared.split(os.pathsep))
    candidates: list[Path] = [Path(entry) for entry in stripped_entries if entry]

    mega_root = os.environ.get("MEGA_ASR_ROOT", "").strip()
    if mega_root:
        candidates.append(Path(mega_root))
    candidates.append(Path("/tmp/Mega-ASR"))

    # Conventional clone target: immediate children of /tmp.
    try:
        for entry in sorted(Path("/tmp").iterdir()):
            if not entry.is_dir():
                continue
            if entry.name.startswith("ffasr_"):
                continue
            candidates.append(entry)
    except OSError:
        pass
    return candidates
def _add_path_for_missing_module(missing: str) -> bool:
    """Make top-level module ``missing`` importable by scanning setup-created clones.

    Only the first dotted component of ``missing`` is considered. For every root
    from ``_candidate_import_roots()`` the function looks for ``<root>/<top>``
    (package dir), ``<root>/<top>.py`` (module), or the same two under
    ``<root>/src``. On the first hit the containing directory is prepended to
    ``sys.path`` (unless it is already listed) and the search stops. This lets
    submitter custom scripts ``import`` a cloned repo at top level without a
    maintainer-side recipe.

    Returns ``True`` when a directory containing the module was found (whether it
    was newly inserted or already on ``sys.path``), ``False`` when ``missing`` is
    empty or nothing matched.
    """
    top = (missing or "").split(".")[0]
    if not top:
        return False

    seen: set[str] = set()
    for root in _candidate_import_roots():
        try:
            base = root.resolve()
        except (OSError, RuntimeError):
            # Before Python 3.13 a symlink loop surfaces as RuntimeError rather
            # than OSError; either way the root is unusable, so skip it.
            continue
        for parent in (base, base / "src"):
            key = str(parent)
            if key in seen:
                continue
            seen.add(key)
            if (parent / top).is_dir() or (parent / f"{top}.py").is_file():
                if key not in sys.path:
                    sys.path.insert(0, key)
                return True
    return False
def _exec_user_module(spec, mod) -> None:
    """Run the user module, retrying after fixing up ``sys.path`` for missing imports.

    Custom scripts commonly import a cloned repo at module load. Each time
    execution fails with ``ModuleNotFoundError`` the missing top-level package is
    looked up among setup-created directories via
    ``_add_path_for_missing_module``; if a location is found the import caches are
    invalidated and the module is executed again. This lets submitters ship their
    own setup + ``evaluate()`` without a recipe being registered in this repo.

    The error is re-raised unchanged when it carries no module name, when the same
    top-level package has already been tried once, or when no location for it can
    be found. Any other exception propagates immediately.
    """
    import importlib

    tried: set[str] = set()
    while True:
        try:
            spec.loader.exec_module(mod)
        except ModuleNotFoundError as err:
            missing = getattr(err, "name", "") or ""
            package = missing.split(".")[0]
            # One attempt per top-level package keeps the retry loop finite.
            if not package or package in tried:
                raise
            tried.add(package)
            if not _add_path_for_missing_module(missing):
                raise
            importlib.invalidate_caches()
        else:
            return
def discover_num_params(mod) -> int:
    """Best-effort model size (parameter count) for a custom evaluator module.

    Custom scripts load their own model, so we never hold the model object the way
    the built-in backends do. To still report a model size on the leaderboard,
    submitters can expose it in any of these ways (checked in order):

    * a module-level ``FFASR_NUM_PARAMS`` / ``NUM_PARAMS`` value: an ``int`` total
      parameter count, or a zero-arg callable returning one;
    * a module-level ``num_params`` / ``model_num_params`` value (``int`` or callable);
    * a module-level model object under a common name (``model``, ``MODEL``,
      ``asr_model``, ``pipe``, ``pipeline``) exposing ``.parameters()`` (we also look
      one level down at ``obj.model`` so a Transformers ``pipeline`` works).

    Declared values that cannot be turned into a positive integer (callable raised,
    non-numeric, NaN/infinite, zero or negative) are skipped and the next name is
    tried. The parameter-counting helper is imported only once a model object is
    actually found, so explicitly declared counts never depend on it.

    Returns 0 when nothing usable is found (model size simply stays blank).
    """

    def _as_int(value) -> int | None:
        """Coerce a declared count (or zero-arg callable) to a positive int, else None."""
        if callable(value):
            try:
                value = value()
            except Exception:
                return None
        try:
            n = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: ``int(float("inf"))``; ValueError covers NaN and bad strings.
            return None
        return n if n > 0 else None

    for name in ("FFASR_NUM_PARAMS", "NUM_PARAMS", "num_params", "model_num_params"):
        if hasattr(mod, name):
            n = _as_int(getattr(mod, name))
            if n:
                return n

    for name in ("model", "MODEL", "asr_model", "pipe", "pipeline"):
        obj = getattr(mod, name, None)
        if obj is None:
            continue
        # Deferred until needed (cached by the import system after the first hit).
        from ._model_utils import count_params

        # A Transformers pipeline keeps the nn.Module at ``.model``; fall back to obj.
        candidate = getattr(obj, "model", None) or obj
        n = count_params(candidate)
        if n:
            return n
    return 0
def normalize_custom_script_compat(script_text: str) -> str:
    """
    Rewrite deprecated Hugging Face / pyannote kwargs so scripts run on current hub versions.

    * ``use_auth_token=`` → ``token=``
    * ``authentication_token=`` → ``token=``

    This is a purely textual substitution: the name, any whitespace after it and
    the ``=`` are replaced by ``token=`` wherever they occur as a whole word
    (including plain assignments and string contents). Equality comparisons
    (``use_auth_token == x``) are left untouched.

    Returns the rewritten script text.
    """
    import re

    # ``(?!=)`` stops the pattern from consuming the first half of ``==``, which
    # would otherwise turn a comparison into a reference to an undefined ``token``.
    deprecated_kwarg = r"\b(?:use_auth_token|authentication_token)\s*=(?!=)"
    return re.sub(deprecated_kwarg, "token=", script_text)
def build_transcriber_from_custom_script(
    script_text: str,
) -> tuple[Callable[..., str], Callable[[], None]]:
    """
    Load ``script_text`` as a module and return ``(transcribe, cleanup)``.

    ``transcribe(audio_np, sampling_rate)`` writes a temp WAV (16-bit PCM) and calls
    ``evaluate(path)`` on the loaded module; the WAV is removed again after every
    call. A list/tuple result is reduced to its first element (or ``""`` when
    empty) and the final value is returned as a stripped string.
    ``transcribe._num_params`` carries the best-effort parameter count from
    ``discover_num_params`` (0 when discovery fails).

    ``cleanup()`` removes the temporary WAV directory and the temporary script file.

    Raises ``RuntimeError`` when the script is empty, cannot be loaded, or does not
    define a callable ``evaluate``. Exceptions raised while executing the script
    propagate unchanged. In every failure case the temporary script file is
    removed before the exception leaves this function.
    """
    # Validate the cheap precondition before importing heavy dependencies.
    script_text = (script_text or "").strip()
    if not script_text:
        raise RuntimeError("Custom script is empty.")
    script_text = normalize_custom_script_compat(script_text)

    import soundfile as sf

    from evaluation.runtime import (
        disable_broken_torchcodec,
        patch_transformers_load_audio_for_paths,
    )

    disable_broken_torchcodec()
    patch_transformers_load_audio_for_paths()
    _prepend_mega_asr_import_path()

    # Mirror whichever Hugging Face token variable is populated into the other one.
    # Blank values count as unset, so an empty ``HF_TOKEN`` cannot hide a real
    # ``HUGGING_FACE_HUB_TOKEN`` (or the other way round).
    token_vars = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")
    hf_token = ""
    for var in token_vars:
        hf_token = (os.environ.get(var) or "").strip()
        if hf_token:
            break
    if hf_token:
        for var in token_vars:
            if not (os.environ.get(var) or "").strip():
                os.environ[var] = hf_token

    fd, script_path = tempfile.mkstemp(prefix="ffasr_custom_eval_", suffix=".py")

    def _remove_script() -> None:
        """Delete the temporary script file, tolerating an already-missing file."""
        try:
            os.unlink(script_path)
        except FileNotFoundError:
            pass

    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(script_text)

        spec = importlib.util.spec_from_file_location("ffasr_user_eval", script_path)
        if spec is None or spec.loader is None:
            raise RuntimeError("Could not load custom evaluator script.")
        mod = importlib.util.module_from_spec(spec)
        _exec_user_module(spec, mod)

        evaluate_fn = getattr(mod, "evaluate", None)
        if not callable(evaluate_fn):
            raise RuntimeError(
                "Custom script must define `evaluate(file: pathlib.Path) -> str`."
            )
        wav_dir = tempfile.mkdtemp(prefix="ffasr_custom_wav_")
    except BaseException:
        # Nothing is handed back to the caller on failure, so nobody else could
        # remove the script file later; this includes errors raised by user code.
        _remove_script()
        raise

    counter = {"i": 0}

    def transcribe(audio_np, sampling_rate: int) -> str:
        counter["i"] += 1
        wav_path = Path(wav_dir) / f"sample_{counter['i']:08d}.wav"
        try:
            # Writing happens inside the ``try`` so a partially written file is
            # removed as well when ``sf.write`` fails.
            sf.write(str(wav_path), audio_np, int(sampling_rate), subtype="PCM_16")
            text = evaluate_fn(wav_path)
        finally:
            try:
                wav_path.unlink()
            except FileNotFoundError:
                pass
        if isinstance(text, (list, tuple)):
            text = text[0] if text else ""
        return str(text or "").strip()

    try:
        transcribe._num_params = discover_num_params(mod)  # type: ignore[attr-defined]
    except Exception:
        # Model size is optional metadata; never fail the run because of it.
        transcribe._num_params = 0  # type: ignore[attr-defined]

    def cleanup() -> None:
        shutil.rmtree(wav_dir, ignore_errors=True)
        _remove_script()

    return transcribe, cleanup