"""Custom inference handler for HuggingFace Inference Endpoints."""

from typing import Any, Dict, List, Union

try:
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    from asr_modeling import ASRModel
    from asr_pipeline import ASRPipeline
| |
|
| |
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for ASR model.

    Loads the ASR model and its pipeline once at construction time, then
    serves transcription requests for deployment on HuggingFace Inference
    Endpoints or similar hosting services.
    """

    def __init__(self, path: str = ""):
        """Set up the model, device, and inference pipeline.

        Args:
            path: Path to model directory or HuggingFace model ID
        """
        import os

        import nltk

        # Sentence-tokenizer data; quiet so startup logs stay clean.
        nltk.download("punkt_tab", quiet=True)

        # Reduce CUDA fragmentation, but respect an operator-provided
        # allocator configuration if one is already set.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        load_options = {
            "device_map": "auto",
            "torch_dtype": "auto",
            "low_cpu_mem_usage": True,
        }
        # Opt into flash attention only when the package is importable.
        if self._is_flash_attn_available():
            load_options["attn_implementation"] = "flash_attention_2"

        self.model = ASRModel.from_pretrained(path, **load_options)

        # Device of the first model parameter — where inputs must be placed.
        self.device = next(self.model.parameters()).device

        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

    def _is_flash_attn_available(self):
        """Check if flash attention is available."""
        import importlib.util

        spec = importlib.util.find_spec("flash_attn")
        return spec is not None

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Process an inference request.

        Args:
            data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'

        Returns:
            Transcription result with 'text' key

        Raises:
            ValueError: If the request payload has no 'inputs' entry.
        """
        if (inputs := data.get("inputs")) is None:
            raise ValueError("Missing 'inputs' in request data")

        overrides = data.get("parameters", {})
        return self.pipe(inputs, **overrides)
| |
|