ai-building-blocks / utils.py
LiKenun's picture
Switch the automatic speech recognition (ASR) implementation to use the inference client instead
0fea237
raw
history blame
3.19 kB
import gradio as gr
from io import BytesIO
import librosa
import numpy as np
from os import getenv
from PIL.Image import Image, open as open_image
import soundfile as sf
import requests
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoProcessor
# Try to import the GPU decorator from `spaces` (available on Hugging Face Spaces);
# otherwise fall back to a no-op decorator.
try:
    from spaces import GPU as spaces_gpu
except ImportError:
    # For local development `spaces` is not installed, so GPU allocation is a no-op.
    def spaces_gpu(func=None, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both the bare form (``@spaces_gpu``) and the parametrized form
        (``@spaces_gpu(duration=...)``) so decorated code imports and runs the
        same way locally as on Spaces. Keyword arguments are accepted and ignored.

        Args:
            func: The function being decorated, when used in the bare form.
            **kwargs: Options of the real decorator (e.g. ``duration``); ignored.

        Returns:
            ``func`` unchanged in the bare form, or an identity decorator when
            called with arguments only.
        """
        if func is None:
            # Parametrized usage: @spaces_gpu(...) must return a decorator.
            return lambda decorated: decorated
        return func
def get_pytorch_device() -> str:
    """Return the name of the best available PyTorch device.

    Backends are probed in order of preference: CUDA (which also covers AMD
    ROCm builds of PyTorch), Intel XPU, Apple Silicon MPS, and finally CPU.
    Each optional probe is guarded, so a PyTorch build that lacks a backend
    module falls through to the next candidate instead of raising
    ``AttributeError``.

    Returns:
        One of ``"cuda"``, ``"xpu"``, ``"mps"`` or ``"cpu"``.
    """
    def _probe(backend) -> bool:
        # True only if the backend module exists and reports availability.
        is_available = getattr(backend, "is_available", None)
        return bool(is_available()) if callable(is_available) else False

    if torch.cuda.is_available():  # Nvidia CUDA and AMD ROCm
        return "cuda"
    if _probe(getattr(torch, "xpu", None)):  # Intel XPU (absent in older builds)
        return "xpu"
    if _probe(getattr(torch.backends, "mps", None)):  # Apple Silicon
        return "mps"
    return "cpu"  # No accelerator found.
def request_image(url: str) -> Image:
    """Download an image over HTTP(S) and open it with Pillow.

    The request timeout in seconds is read from the ``REQUEST_TIMEOUT``
    environment variable. Fractional values such as ``"2.5"`` are accepted.
    When the variable is unset or empty, a 30-second default is used instead
    of crashing.

    Args:
        url: Address of the image to fetch.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        gr.Error: If the request fails (HTTP error status, timeout, or any
            other ``requests`` error) or the response body is not an image
            format Pillow can identify.
    """
    from PIL import UnidentifiedImageError

    timeout = float(getenv("REQUEST_TIMEOUT") or 30)
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch image from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        # Must precede RequestException, of which Timeout is a subclass.
        raise gr.Error("Failed to fetch image from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch image from URL: {e}") from e
    try:
        return open_image(BytesIO(response.content))
    except UnidentifiedImageError as e:
        # Surface non-image responses (e.g. an HTML page) as a user-facing error.
        raise gr.Error(f"The content at the URL is not a recognized image format: {e}") from e
def save_image_to_temp_file(image: Image) -> str:
    """Write *image* to a new temporary file and return its path.

    The image is saved in its original format (``image.format``) when known,
    otherwise as PNG, and the file extension is derived from that format. The
    file is created with ``delete=False`` and is not removed here, so it
    outlives this call; removing it is up to the caller. If saving fails, the
    temporary file is deleted before the exception propagates.

    Args:
        image: The Pillow image to persist.

    Returns:
        Filesystem path of the written image file.
    """
    import os
    from contextlib import suppress

    image_format = image.format or "PNG"
    with NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}") as temp_file:
        temp_path = temp_file.name
    try:
        image.save(temp_path, format=image_format)
    except BaseException:
        # Best-effort cleanup so a failed save doesn't leak an empty or partial
        # file. The original exception is always re-raised.
        with suppress(OSError):
            os.remove(temp_path)
        raise
    return temp_path
def get_model_sample_rate(model_id: str) -> int:
    """Return the audio sampling rate (Hz) expected by a Hugging Face model.

    The rate is read from the model's preprocessing configuration via
    ``AutoProcessor.from_pretrained``. Depending on the model, that call may
    return either a composite processor that exposes a ``feature_extractor``
    attribute or a bare feature extractor; both shapes are handled. If the
    configuration cannot be loaded or has no usable positive integer
    ``sampling_rate``, a warning is emitted and 16000 Hz is returned.

    Args:
        model_id: Model identifier or local path passed to
            ``AutoProcessor.from_pretrained``.

    Returns:
        The sampling rate in Hz.
    """
    import warnings

    fallback_rate = 16000  # Fallback value as most ASR models use 16 kHz.
    try:
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception as e:
        # Deliberate best-effort: loading can fail for many reasons (network,
        # auth, missing config), but callers still need a usable rate.
        warnings.warn(f"Could not load processor for {model_id!r} ({e}); assuming {fallback_rate} Hz.")
        return fallback_rate
    # A bare feature extractor has no `.feature_extractor`; use it directly.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    sampling_rate = getattr(feature_extractor, "sampling_rate", None)
    if not isinstance(sampling_rate, int) or isinstance(sampling_rate, bool) or sampling_rate <= 0:
        warnings.warn(f"No valid sampling_rate found for {model_id!r}; assuming {fallback_rate} Hz.")
        return fallback_rate
    return sampling_rate
def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray:
sample_rate, audio_data = audio
# Convert audio data to a numpy array if it’s bytes
if isinstance(audio_data, bytes):
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
elif isinstance(audio_data, np.ndarray):
audio_array = audio_data.astype(np.float32)
else:
raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")
# Resample if sample rates don’t match.
if sample_rate != target_sample_rate:
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate)
return audio_array
def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample *audio* via ``resample_audio`` and write it to a temporary WAV file.

    The file is created with ``delete=False`` and is not removed here, so it
    outlives this call; removing it is up to the caller. If writing fails, the
    temporary file is deleted before the exception propagates.

    Args:
        target_sample_rate: Sampling rate in Hz to resample to and to record
            in the WAV header.
        audio: ``(sample_rate, data)`` pair, as accepted by ``resample_audio``.

    Returns:
        Filesystem path of the written ``.wav`` file.
    """
    import os
    from contextlib import suppress

    audio_array = resample_audio(target_sample_rate, audio)
    with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_path = temp_file.name
    try:
        sf.write(temp_path, audio_array, target_sample_rate, format="WAV")
    except BaseException:
        # Best-effort cleanup so a failed write doesn't leak an empty or partial
        # file. The original exception is always re-raised.
        with suppress(OSError):
            os.remove(temp_path)
        raise
    return temp_path