|
|
import base64 |
|
|
import datetime |
|
|
import io |
|
|
import logging |
|
|
import os |
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
from faster_whisper import WhisperModel |
|
|
from typing import Dict, Any, Union |
|
|
|
|
|
# Emit INFO-level records; logging.basicConfig is a no-op when the root
# logger already has handlers configured by the hosting runtime.
logging.basicConfig(level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler that transcribes audio with faster-whisper."""

    def __init__(self, path=""):
        """
        Initialize the endpoint handler and load the Whisper model.

        A GPU load (float16) is attempted when ``CUDA_VISIBLE_DEVICES`` is set;
        otherwise, or if that load fails, the model is loaded on the CPU (int8).

        Args:
            path: Path to the model directory. In Hugging Face Inference Endpoints,
                this will be the directory containing the model files. An empty
                value means the current directory.

        Raises:
            Exception: Re-raised from ``WhisperModel`` when the CPU load fails.
        """
        logger.info("Initializing EndpointHandler")

        if os.environ.get("LOG_DIAGNOSTICS") == "true":
            # Diagnostics are best-effort debugging output; they must never
            # prevent the model from loading.
            try:
                self.__log_diagnostics__()
            except Exception as diag_err:
                logger.warning(f"Diagnostics logging failed: {diag_err}")

        # Rate the model is fed at; inputs at other rates are resampled in transcribe().
        self.sampling_rate = 16000
        model_path = path or "."

        # CUDA_VISIBLE_DEVICES is only a hint; whether CUDA really works is
        # decided by the load attempt below, which falls back to the CPU.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        self.model = None

        if cuda_visible_devices:
            logger.info("CUDA is available")
            logger.info("Attempting to load model with CUDA support")
            try:
                self.model = WhisperModel(
                    model_path,
                    compute_type="float16",
                    device="cuda",
                )
                logger.info("Model loaded successfully with CUDA")
            except Exception as e:
                logger.warning(f"Error loading model with CUDA: {e}")
                logger.info("Falling back to CPU model")
        else:
            logger.info("CUDA is not available, using CPU")

        if self.model is None:
            try:
                self.model = WhisperModel(
                    model_path,
                    compute_type="int8",
                    device="cpu",
                )
            except Exception as cpu_err:
                logger.error(f"Error loading CPU model: {cpu_err}")
                raise
            logger.info("Model loaded successfully with CPU")

        logger.info("EndpointHandler initialized")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a request.

        Args:
            data: Request payload of the form
                ``{"inputs": audio, "parameters": {...}}``.
                ``inputs`` may be a base64 data URI or bare base64 string, raw
                audio-file bytes, a 1-D numpy array (assumed to be at 16 kHz), or
                ``{"audio": [samples...], "sampling_rate": int}``.
                ``language`` (default ``"da"``) and ``beam_size`` (default ``5``)
                may be given inside ``parameters`` or at the top level of
                ``data``; ``parameters`` wins if both are present. All remaining
                ``parameters`` are forwarded to ``WhisperModel.transcribe``.
                ``data`` itself is not modified.

        Returns:
            Transcription result (see ``transcribe``), or ``{"error": message}``.
        """
        logger.info("Processing request")

        try:
            if "inputs" not in data:
                return {"error": "No inputs provided"}

            # Copy so neither the caller's payload nor its nested "parameters"
            # dict is mutated; "or {}" also tolerates an explicit null.
            parameters = dict(data.get("parameters") or {})

            audio, sampling_rate = self._process_audio_input(data["inputs"])

            # Pull language/beam_size out of parameters so they are not passed
            # to transcribe() a second time through **parameters.
            language = parameters.pop("language", data.get("language", "da"))
            beam_size = parameters.pop("beam_size", data.get("beam_size", 5))

            return self.transcribe(
                audio,
                sampling_rate,
                language,
                beam_size,
                **parameters,
            )

        except Exception as e:
            logger.exception(f"Error processing request: {e}")
            return {"error": str(e)}

    def _process_audio_input(
        self, inputs: Union[str, bytes, Dict[str, Any], np.ndarray]
    ) -> tuple:
        """
        Convert the request's ``inputs`` into a mono float32 waveform.

        Args:
            inputs: One of
                * ``str``: a base64 data URI (``data:audio/wav;base64,...``,
                  optionally percent-encoded as a whole) or a bare base64 string
                  containing an audio file. URLs and file paths are rejected.
                * ``bytes``/``bytearray``: raw audio-file contents.
                * ``dict`` with an ``"audio"`` key holding 1-D samples and an
                  optional ``"sampling_rate"`` (defaults to ``self.sampling_rate``).
                * ``numpy.ndarray``: 1-D samples at ``self.sampling_rate``.

        Returns:
            Tuple of (audio_array, sampling_rate) where audio_array is 1-D float32.

        Raises:
            ValueError: For URLs, invalid base64, non-1-D sample arrays or
                unsupported input types.
        """
        if isinstance(inputs, str):
            logger.info("Received audio input as base64 encoded string")
            return self._read_audio_file(self._decode_base64_string(inputs))

        if isinstance(inputs, (bytes, bytearray)):
            logger.info("Received raw bytes input")
            return self._read_audio_file(bytes(inputs))

        if isinstance(inputs, dict) and "audio" in inputs:
            logger.info("Received audio input as dictionary")
            # Always coerce to a numeric array: a string passed through here
            # would be treated as a path/URL by WhisperModel.transcribe.
            audio = self._as_mono_samples(inputs["audio"])
            sampling_rate = int(inputs.get("sampling_rate") or self.sampling_rate)
            return audio, sampling_rate

        if isinstance(inputs, np.ndarray):
            logger.info("Received audio input as numpy array")
            return self._as_mono_samples(inputs), self.sampling_rate

        raise ValueError(f"Unsupported input format: {type(inputs)}")

    @staticmethod
    def _decode_base64_string(inputs: str) -> bytes:
        """Return the audio-file bytes carried by a base64 data URI or bare base64 string."""
        from urllib.parse import unquote

        if inputs[:7].lower() == "data%3a":
            # The whole URI was percent-encoded (e.g. "data%3Aaudio%2Fwav%3Bbase64%2C...");
            # decode it so the "base64," marker and payload become readable.
            inputs = unquote(inputs)

        if inputs.startswith("data:"):
            if "base64," not in inputs:
                raise ValueError("Only base64-encoded data URIs are supported")
            payload = inputs.split("base64,", 1)[1]
        elif "://" in inputs:
            # ":" is not part of the base64 alphabet, so this can only be a URL.
            raise ValueError("URL or file path inputs are not supported")
        else:
            payload = inputs

        # Drop line breaks/whitespace (wrapped base64), then decode strictly so
        # arbitrary text such as a file path is rejected instead of producing garbage.
        payload = "".join(payload.split())
        try:
            return base64.b64decode(payload, validate=True)
        except ValueError as err:  # binascii.Error subclasses ValueError
            raise ValueError(f"Invalid base64 audio input: {err}") from err

    @staticmethod
    def _read_audio_file(audio_bytes: bytes) -> tuple:
        """Decode an in-memory audio file into (1-D float32 samples, sampling_rate)."""
        with io.BytesIO(audio_bytes) as audio_io:
            # float32 instead of soundfile's float64 default.
            audio, sampling_rate = sf.read(audio_io, dtype="float32")
        if audio.ndim > 1:
            # soundfile returns multi-channel audio as (frames, channels);
            # average the channels to get the mono signal Whisper needs.
            audio = audio.mean(axis=1, dtype=np.float32)
        return audio, sampling_rate

    @staticmethod
    def _as_mono_samples(samples: Any) -> np.ndarray:
        """Coerce raw sample data to a 1-D float32 array, rejecting other shapes."""
        audio = np.asarray(samples, dtype=np.float32)
        if audio.ndim != 1:
            raise ValueError(
                f"Expected 1-D (mono) audio samples, got shape {audio.shape}"
            )
        return audio

    def transcribe(
        self,
        audio: np.ndarray,
        sampling_rate: int,
        language: str,
        beam_size: int,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Perform the transcription.

        Args:
            audio: 1-D audio samples.
            sampling_rate: Sampling rate of ``audio``; resampled to
                ``self.sampling_rate`` when different.
            language: Language code passed to the model (``None`` lets the model detect it).
            beam_size: Beam size for transcription.
            **kwargs: Additional parameters forwarded to ``WhisperModel.transcribe``.

        Returns:
            Dict with ``text`` (segment texts joined by single spaces), ``segments``
            (per-segment details, plus ``words`` when ``word_timestamps`` is set),
            ``language`` and ``language_probability``.
        """
        logger.info(f"Batch transcription: {len(audio)} samples, {sampling_rate} Hz")

        if sampling_rate != self.sampling_rate:
            logger.warning(
                f"Sampling rate mismatch: {sampling_rate} Hz vs {self.sampling_rate} Hz"
            )
            # The model assumes array input is already at self.sampling_rate;
            # without resampling, speech would be transcribed at the wrong speed.
            audio = self._resample(audio, sampling_rate, self.sampling_rate)

        logger.info(f"Parameters: {kwargs}")

        now = datetime.datetime.now()
        segments, info = self.model.transcribe(
            audio,
            language=language,
            beam_size=beam_size,
            **kwargs,
        )

        logger.info(f"Transcription info: {info}")

        result = {
            "text": "",
            "segments": [],
            "language": info.language,
            "language_probability": info.language_probability,
        }

        include_words = kwargs.get("word_timestamps", False)
        all_text = []
        # Segments are consumed lazily; the timing below covers this loop.
        for segment in segments:
            segment_info = {
                "id": segment.id,
                "text": segment.text,
                "start": segment.start,
                "end": segment.end,
                "temperature": segment.temperature,
                "avg_logprob": segment.avg_logprob,
                "compression_ratio": segment.compression_ratio,
                "no_speech_prob": segment.no_speech_prob,
            }
            if include_words:
                segment_info["words"] = [
                    {
                        "word": word.word,
                        "start": word.start,
                        "end": word.end,
                    }
                    for word in (segment.words or [])
                ]

            all_text.append(segment.text)
            result["segments"].append(segment_info)

        elapsed_time = datetime.datetime.now() - now
        logger.info(f"Transcription time: {elapsed_time}")
        logger.info(f"Segments: {len(result['segments'])}")

        # Segment texts may carry surrounding whitespace; strip each one so the
        # combined text has single separating spaces and no leading blank.
        stripped = (text.strip() for text in all_text)
        result["text"] = " ".join(text for text in stripped if text)

        return result

    @staticmethod
    def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """
        Linearly resample a 1-D waveform from ``orig_sr`` to ``target_sr`` Hz.

        Linear interpolation applies no anti-aliasing filter, so it is lower
        quality than a polyphase resampler, but it keeps speech at the correct
        speed using only numpy.
        """
        samples = np.asarray(audio, dtype=np.float32)
        if orig_sr <= 0:
            raise ValueError(f"Invalid sampling rate: {orig_sr}")
        if orig_sr == target_sr or samples.size == 0:
            return samples
        n_out = max(1, int(round(samples.shape[0] * target_sr / orig_sr)))
        # Position of every output sample, expressed in input-sample indices.
        positions = np.arange(n_out, dtype=np.float64) * (orig_sr / target_sr)
        source_index = np.arange(samples.shape[0], dtype=np.float64)
        return np.interp(positions, source_index, samples).astype(np.float32)

    def __log_diagnostics__(self):
        """
        Log diagnostics information for debugging.

        This includes library paths and their contents, installed packages,
        environment variables (values of secret-looking variables are masked)
        and the output of ``nvidia-smi``.

        Very useful as the HF endpoint runtime is rather secretive about its environment.
        """
        import importlib.metadata
        import subprocess

        logger.info("Logging environment diagnostics")

        ld_library_path = os.environ.get("LD_LIBRARY_PATH")

        logger.info("LD_LIBRARY_PATH:")
        if ld_library_path is not None:
            logger.info(ld_library_path)
        else:
            logger.info("LD_LIBRARY_PATH not set")

        logger.info("LD_LIBRARY_PATH files:")
        if ld_library_path is not None:
            for ld_path in ld_library_path.split(":"):
                if not os.path.exists(ld_path):
                    logger.info(f" {ld_path} does not exist")
                    continue
                if not os.path.isdir(ld_path):
                    logger.info(f" {ld_path} is not a directory")
                    continue
                logger.info(f" {ld_path}:")
                try:
                    entries = sorted(os.listdir(ld_path))
                except OSError as err:
                    logger.warning(f" Could not list {ld_path}: {err}")
                    continue
                for entry in entries:
                    logger.info(f" {entry}")

        # importlib.metadata (stdlib) replaces the deprecated pkg_resources,
        # which is not guaranteed to be installed.
        logger.info("Installed Python packages:")
        for dist in importlib.metadata.distributions():
            logger.info(
                f" {dist.metadata['Name']}=={dist.version}, {dist.locate_file('')}"
            )

        # Environment variables may hold credentials (API tokens etc.);
        # mask their values rather than writing them to the logs.
        secret_markers = ("TOKEN", "SECRET", "PASSWORD", "PASSWD", "KEY", "CREDENTIAL", "AUTH")
        logger.info("Environment variables:")
        for key, value in sorted(os.environ.items()):
            if any(marker in key.upper() for marker in secret_markers):
                value = "***"
            logger.info(f" {key}: {value}")

        logger.info("NVIDIA environment:")
        try:
            # Bounded runtime: nvidia-smi can hang when the driver is unhealthy.
            completed = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, timeout=60
            )
            logger.info(completed.stdout)
            if completed.returncode != 0:
                logger.warning(
                    f"nvidia-smi exited with code {completed.returncode}: {completed.stderr}"
                )
        except Exception as e:
            logger.warning(f"Could not run nvidia-smi: {e}")
|
|
|