# tiny-audio / handler.py
"""Custom inference handler for HuggingFace Inference Endpoints."""
import os
from typing import Any, Dict, List, Union

import torch

try:
    # For remote execution, imports are relative
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    # For local execution, imports are not relative
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from asr_pipeline import ASRPipeline  # type: ignore[no-redef]

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Allocator config must be set before the first CUDA allocation;
        # expandable segments reduce memory fragmentation.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Enable TF32 for faster matmul on A40/A100
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # Set device and dtype
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        # Enable cuDNN autotuning of convolution kernels
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True

        # Prepare model kwargs for the pipeline
        model_kwargs = {
            "dtype": self.dtype,
            "low_cpu_mem_usage": True,
        }
        if torch.cuda.is_available():
            model_kwargs["attn_implementation"] = (
                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
            )

        # Load the model (this also loads the tokenizer and feature extractor)
        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        # Instantiate the custom pipeline; it reuses the model's
        # feature extractor and tokenizer
        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

        # Apply torch.compile if enabled (on by default for a significant
        # 20-40% speedup); must happen after the pipeline holds the model
        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
            compile_mode = os.getenv("TORCH_COMPILE_MODE", "reduce-overhead")
            self.model = torch.compile(self.model, mode=compile_mode)
            # Point the pipeline at the compiled model
            self.pipe.model = self.model

        # Warm up the model
        if torch.cuda.is_available():
            self._warmup()
    def _is_flash_attn_available(self) -> bool:
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None
    def _warmup(self):
        """Warmup to trigger model compilation and allocate GPU memory."""
        try:
            # Create dummy audio (1 second at the config sample rate)
            sample_rate = self.pipe.model.config.audio_sample_rate
            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)

            # The pipeline handles GPU placement internally
            with torch.inference_mode():
                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
                _ = self.pipe(
                    {"raw": dummy_audio, "sampling_rate": sample_rate},
                    max_new_tokens=warmup_tokens,
                )

            if torch.cuda.is_available():
                # Force CUDA synchronization so kernels are compiled before
                # serving traffic, then clear the cache to free warmup memory
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Warmup skipped due to: {e}")
    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Handle a request of the form {"inputs": ..., "parameters": {...}}."""
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        # Generation parameters, with greedy-decoding defaults
        params = data.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", 200)
        num_beams = params.get("num_beams", 1)
        do_sample = params.get("do_sample", False)
        length_penalty = params.get("length_penalty", 1.0)
        repetition_penalty = params.get("repetition_penalty", 1.0)
        no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0)
        early_stopping = params.get("early_stopping", True)
        default_diversity = self.pipe.model.config.inference_diversity_penalty
        diversity_penalty = params.get("diversity_penalty", default_diversity)

        return self.pipe(
            inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            diversity_penalty=diversity_penalty,
        )
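
# Minimal local smoke test: a sketch only, not part of the Inference Endpoints
# contract (the platform instantiates EndpointHandler and invokes it per
# request). The checkpoint path "./checkpoint" is hypothetical; point it at a
# directory containing this repo's custom model files. The input dict mirrors
# the {"raw": ..., "sampling_rate": ...} shape used in _warmup above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./checkpoint")
    sr = handler.pipe.model.config.audio_sample_rate
    request = {
        "inputs": {"raw": torch.zeros(sr, dtype=torch.float32), "sampling_rate": sr},
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(request))  # transcription output for 1 s of silence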