tiny-audio / handler.py
mazesmazes's picture
Update custom model files, README, and requirements
73cda73 verified
"""Custom inference handler for HuggingFace Inference Endpoints."""
from typing import Any, Dict, List, Union
try:
# For remote execution, imports are relative
from .asr_modeling import ASRModel
from .asr_pipeline import ASRPipeline
except ImportError:
# For local execution, imports are not relative
from asr_modeling import ASRModel # type: ignore[no-redef]
from asr_pipeline import ASRPipeline # type: ignore[no-redef]
class EndpointHandler:
"""HuggingFace Inference Endpoints handler for ASR model.
Handles model loading, warmup, and inference requests for deployment
on HuggingFace Inference Endpoints or similar services.
"""
def __init__(self, path: str = ""):
"""Initialize the endpoint handler.
Args:
path: Path to model directory or HuggingFace model ID
"""
import os
import nltk
nltk.download("punkt_tab", quiet=True)
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Prepare model kwargs - let transformers handle device placement
model_kwargs = {
"device_map": "auto",
"torch_dtype": "auto",
"low_cpu_mem_usage": True,
}
if self._is_flash_attn_available():
model_kwargs["attn_implementation"] = "flash_attention_2"
# Load model (this loads the model, tokenizer, and feature extractor)
self.model = ASRModel.from_pretrained(path, **model_kwargs)
# Get device from model for pipeline
self.device = next(self.model.parameters()).device
# Instantiate custom pipeline - it will get feature_extractor and tokenizer from model
self.pipe = ASRPipeline(
model=self.model,
feature_extractor=self.model.feature_extractor,
tokenizer=self.model.tokenizer,
device=self.device,
)
def _is_flash_attn_available(self):
"""Check if flash attention is available."""
import importlib.util
return importlib.util.find_spec("flash_attn") is not None
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Process an inference request.
Args:
data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'
Returns:
Transcription result with 'text' key
"""
inputs = data.get("inputs")
if inputs is None:
raise ValueError("Missing 'inputs' in request data")
# Pass through any parameters from request, let model config provide defaults
params = data.get("parameters", {})
return self.pipe(inputs, **params)