#!/usr/bin/env python3
"""Default model adapter for benchmark inference.
Abstract interface and canonical batch input are defined in ``model_abc.py``.
This module provides the default adapter that wraps the official CMI baseline
(``baselines/inference.py``).
"""
from __future__ import annotations
from dataclasses import dataclass
import importlib
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional
import numpy as np
# Make the repository root (the directory two levels above this file)
# importable, so the project-local ``core`` package imported below resolves
# even when this file is not run from the repository root.
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    # Prepend rather than append: entries earlier in sys.path are searched
    # first, so the local checkout takes precedence.
    sys.path.insert(0, str(_repo_root))

from core.model_abc import BenchmarkBatchInput, BenchmarkModelABC

# Second name bound to the same ABC object.
# NOTE(review): presumably kept for callers importing the older name -- verify
# before removing.
BenchmarkModelInterface = BenchmarkModelABC
@dataclass
class RewardModelAdapterConfig:
    """Settings used to build a :class:`RewardModelAdapter`.

    The first four fields are passed by keyword to the baseline's
    ``RewardModelInference`` constructor. ``init_kwargs`` (if given) is
    splatted into that same call as additional keyword arguments.
    """

    # Checkpoint identifier/path handed to the baseline (required).
    checkpoint: str
    # Path to config.yaml; auto-detected if None.
    config: Optional[str] = None
    # Device string handed to the baseline.
    device: str = "cuda:0"
    # Inference strategy: "final" or "standard".
    mode: str = "final"
    # Extra keyword arguments for the baseline constructor; None means none.
    init_kwargs: Optional[Dict[str, Any]] = None
class RewardModelAdapter(BenchmarkModelABC):
    """Benchmark adapter that delegates to the official CMI baseline.

    All real work is done by ``baselines.inference.RewardModelInference``;
    this class only translates between the benchmark interface and it.

    ``mode="final"`` → chunk-encode (training-consistent, recommended)
    ``mode="standard"`` → encode full segment / sliding window
    """

    def __init__(self, cfg: RewardModelAdapterConfig) -> None:
        """Build the wrapped baseline from ``cfg``.

        ``cfg.init_kwargs`` (copied into a fresh dict, or empty when None) is
        forwarded as extra keyword arguments to the baseline constructor.
        """
        # Imported here rather than at module level, so the baseline is only
        # loaded once an adapter is actually constructed.
        inference_cls = importlib.import_module("baselines.inference").RewardModelInference
        extra_kwargs = dict(cfg.init_kwargs or {})
        self._impl = inference_cls(
            checkpoint=cfg.checkpoint,
            config=cfg.config,
            device=cfg.device,
            mode=cfg.mode,
            **extra_kwargs,
        )

    @property
    def sr(self) -> int:
        """The wrapped baseline's ``sr`` value (presumably a sample rate), as ``int``."""
        rate = self._impl.sr
        return int(rate)

    def score_batch(
        self,
        inputs: List[BenchmarkBatchInput],
        batch_size: int,
        max_dur: float,
        **kwargs: Any,
    ) -> np.ndarray:
        """Score ``inputs`` with the wrapped baseline.

        Each input is serialized with its ``to_dict()`` method. ``batch_size``,
        ``max_dur`` and any extra keyword arguments are forwarded unchanged to
        the baseline's ``score_batch``.
        """
        serialized = [item.to_dict() for item in inputs]
        scores = self._impl.score_batch(
            inputs=serialized,
            batch_size=batch_size,
            max_dur=max_dur,
            **kwargs,
        )
        # NOTE(review): returned without conversion; np.ndarray is assumed
        # from the annotation -- confirm the baseline actually returns one.
        return scores