File size: 10,196 Bytes
d56b9d9
 
0fea237
 
d56b9d9
0fea237
 
d56b9d9
dc382c8
d56b9d9
0fea237
dc382c8
 
d56b9d9
 
02c9b64
d56b9d9
 
 
5bebd85
d56b9d9
 
 
5bebd85
 
 
 
 
 
 
 
 
 
 
d56b9d9
 
 
 
 
5c395b2
 
 
 
 
 
 
 
d56b9d9
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d56b9d9
bb6107f
d56b9d9
 
 
 
 
 
 
 
 
bb6107f
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6107f
 
 
 
 
 
 
 
 
 
 
 
 
 
dc382c8
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc382c8
 
 
 
 
 
 
02c9b64
0fea237
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fea237
 
 
 
 
 
02c9b64
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02c9b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fea237
 
5bebd85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fea237
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import gradio as gr
from io import BytesIO
import librosa
import numpy as np
from os import getenv
from PIL.Image import Image, open as open_image
import soundfile as sf
import requests
from tempfile import NamedTemporaryFile
import torch
from transformers import AutoProcessor


# Try to import the ``spaces`` GPU decorator (for Hugging Face Spaces); otherwise use a no-op decorator.
try:
    from spaces import GPU as spaces_gpu
except ImportError:
    # For local development, use a no-op decorator because spaces is not available.
    def spaces_gpu(func=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` module is unavailable.

        Supports both decorator forms:
        - ``@spaces_gpu`` returns the decorated function unchanged.
        - ``@spaces_gpu(...)`` / ``@spaces_gpu()`` (e.g. ``duration=120``) returns a
          decorator that also leaves the function unchanged. Keyword arguments are ignored.

        Args:
            func: The function being decorated, or None when called with arguments.

        Returns:
            ``func`` itself, or an identity decorator when ``func`` is None.
        """
        if func is None:
            # Called with configuration arguments: hand back an identity decorator.
            return lambda decorated: decorated
        return func

def get_pytorch_device() -> str:
    """Determine the best available PyTorch device for computation.

    Checks for available hardware accelerators in priority order:
    1. CUDA (Nvidia GPUs and AMD ROCm)
    2. XPU (Intel GPUs)
    3. MPS (Apple Silicon/Metal Performance Shaders)
    4. CPU (fallback)

    Backends are probed defensively. ``torch.xpu`` does not exist in older
    PyTorch builds, and MPS availability is queried through
    ``torch.backends.mps``. A missing backend is treated as unavailable
    instead of raising AttributeError.

    Returns:
        String device name: "cuda", "xpu", "mps", or "cpu"
    """
    if torch.cuda.is_available():  # Nvidia CUDA and AMD ROCm
        return "cuda"

    xpu = getattr(torch, "xpu", None)  # Intel XPU (absent on older PyTorch)
    if xpu is not None and xpu.is_available():
        return "xpu"

    mps_backend = getattr(torch.backends, "mps", None)  # Apple Silicon
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"  # gl bro 🫠

def get_torch_dtype():
    """Choose the torch dtype used when loading models.

    The REDUCED_MEMORY environment variable is compared case-insensitively
    against "true" (it defaults to "False" when unset).

    Returns:
        torch.float16 when REDUCED_MEMORY is "true", otherwise None so that
        the model's default precision is kept.
    """
    reduced_memory_flag = getenv("REDUCED_MEMORY", "False").lower()
    if reduced_memory_flag != "true":
        return None
    return torch.float16

def request_image(url: str) -> Image:
    """Fetch an image from a URL and return it as a PIL Image.

    Downloads the image at ``url`` and decodes it into a PIL Image object.
    The image data is loaded eagerly, so corrupt or non-image payloads are
    reported here instead of surfacing later when the image is used.

    Args:
        url: HTTP/HTTPS URL pointing to an image file.

    Returns:
        PIL Image object loaded from the URL (``image.format`` reflects the
        detected file format).

    Raises:
        gr.Error: If the image cannot be fetched or decoded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Payloads PIL cannot identify or decode
        ValueError: If the REQUEST_TIMEOUT environment variable is not an integer.

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Decodable formats are those supported by the installed Pillow build
    """
    timeout = int(getenv("REQUEST_TIMEOUT", "45"))

    # Download step: only network/HTTP failures are handled here.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch image from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch image from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch image from URL: {str(e)}") from e

    # Decode step: PIL opens lazily, so force the load while errors are still
    # being translated into user-facing gr.Error messages.
    try:
        image = open_image(BytesIO(response.content))
        image.load()
    except Exception as e:
        raise gr.Error(f"Failed to load image file: {str(e)}") from e
    return image

def request_audio(url: str) -> tuple[int, np.ndarray]:
    """Fetch an audio file from a URL and return it as audio data.

    Downloads the file at ``url`` and decodes it from an in-memory buffer with
    ``librosa.load``. The result uses the ``(sample_rate, data)`` tuple layout
    of Gradio's Audio component.

    Args:
        url: HTTP/HTTPS URL pointing to an audio file.

    Returns:
        Tuple containing:
            - int: Native sample rate of the audio in Hz (``sr=None`` disables resampling)
            - np.ndarray: Waveform as returned by ``librosa.load`` with its defaults,
              i.e. float32 and downmixed to mono (1-D)

    Raises:
        gr.Error: If the audio cannot be fetched or decoded due to:
            - HTTP errors (4xx, 5xx status codes)
            - Network timeouts
            - Other request exceptions
            - Formats or data that librosa cannot decode
        ValueError: If the REQUEST_TIMEOUT environment variable is not an integer.

    Note:
        - Timeout is configurable via REQUEST_TIMEOUT environment variable (default: 45 seconds)
        - Decodable formats depend on the audio backends available to librosa for
          file-like input. NOTE(review): compressed formats such as MP3/M4A may not
          decode from memory on every installation — confirm for the deployment.
    """
    timeout = int(getenv("REQUEST_TIMEOUT", "45"))

    # Download step: only network/HTTP failures are handled here.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.HTTPError as e:
        raise gr.Error(f"Failed to fetch audio from URL because of HTTP error: {e.response.status_code} {e.response.text}") from e
    except requests.Timeout as e:
        raise gr.Error("Failed to fetch audio from URL because the request timed out.") from e
    except requests.RequestException as e:
        raise gr.Error(f"Failed to fetch audio from URL: {str(e)}") from e

    # Decode step: any decoder failure is reported as a load error.
    try:
        audio_array, sample_rate = librosa.load(BytesIO(response.content), sr=None)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {str(e)}") from e
    return (sample_rate, audio_array)

def save_image_to_temp_file(image: Image) -> str:
    """Save a PIL Image to a temporary file on disk.

    Creates a temporary file whose extension matches the image's format and
    saves the image to it. This is needed because some APIs (like Hugging Face
    InferenceClient) require file paths rather than PIL Image objects.

    Args:
        image: PIL Image object to save.

    Returns:
        String path to the temporary file where the image was saved.

    Raises:
        Exception: Whatever ``image.save`` raises if even the PNG fallback
            cannot be written (e.g. a mode PNG cannot store, or a disk error).

    Note:
        - Preserves the original image format when PIL can write it
        - Uses PNG when image.format is None, or when saving in the original
          format fails (e.g. a read-only format, or a mode that format cannot store)
        - File extension is the lower-cased format name
        - The temporary file is not automatically deleted (caller is responsible
          for cleanup), but a file whose save failed is removed
    """
    import os
    from contextlib import suppress

    def _save_as(image_format: str) -> str:
        """Write ``image`` in ``image_format`` to a new temp file and return its path."""
        # Close our handle before PIL reopens the path: an open NamedTemporaryFile
        # cannot be reopened by name on some platforms (Windows).
        with NamedTemporaryFile(delete=False, suffix=f".{image_format.lower()}") as temp_file:
            temp_path = temp_file.name
        try:
            image.save(temp_path, format=image_format)
        except BaseException:
            # Don't leak an empty/partial file when encoding fails.
            with suppress(OSError):
                os.remove(temp_path)
            raise
        return temp_path

    image_format = image.format or 'PNG'
    try:
        return _save_as(image_format)
    except (KeyError, OSError, ValueError):
        # KeyError: no writer for this format; OSError/ValueError: format can't
        # store this image. PNG is the broadly compatible fallback.
        if image_format.upper() == 'PNG':
            raise
        return _save_as('PNG')

def get_model_sample_rate(model_id: str) -> int:
    """Get the expected sample rate for an audio processing model.

    Retrieves the sample rate configuration from a Hugging Face model's
    feature extractor. This is useful for ensuring audio is resampled to
    match the model's expected input format.

    Args:
        model_id: Hugging Face model identifier (e.g., "openai/whisper-large-v3").

    Returns:
        Integer sample rate in Hz that the model expects (e.g., 16000).
        Defaults to 16000 Hz if the sample rate cannot be determined.

    Note:
        - Most ASR models use 16kHz sample rate
        - AutoProcessor may return either a combined processor exposing a
          ``feature_extractor`` attribute or the feature extractor itself; both are handled
        - Returns a sensible default (16kHz) if the model config cannot be loaded
          or does not declare a positive integer ``sampling_rate``
    """
    fallback_sample_rate = 16000  # Fallback value as most ASR models use 16kHz
    try:
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception:
        return fallback_sample_rate

    # Without this getattr, a bare feature extractor (which has no
    # ``.feature_extractor``) would silently fall back to 16 kHz even when its
    # real rate differs.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    sampling_rate = getattr(feature_extractor, "sampling_rate", None)
    if not isinstance(sampling_rate, int) or sampling_rate <= 0:
        return fallback_sample_rate
    return sampling_rate

def resample_audio(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> np.ndarray:
    """Resample audio data to a target sample rate.
    
    Converts audio data to the target sample rate using librosa's resampling.
    Handles both bytes and numpy array input formats, converting bytes to
    float32 numpy arrays as needed.
    
    Args:
        target_sample_rate: Desired output sample rate in Hz (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the audio
            - bytes | np.ndarray: Audio data (can be raw bytes or numpy array)
    
    Returns:
        Numpy array (float32) containing the resampled audio waveform.
        If sample rates match, returns the audio data unchanged.
    
    Raises:
        ValueError: If audio_data is neither bytes nor np.ndarray.
    
    Note:
        - Converts bytes to float32 by assuming int16 PCM format
        - Normalizes int16 values to [-1.0, 1.0] range
        - Only resamples if source and target sample rates differ
        - Uses librosa's high-quality resampling algorithm
    """
    sample_rate, audio_data = audio
    
    # Convert audio data to a numpy array if it’s bytes
    if isinstance(audio_data, bytes):
        audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
    elif isinstance(audio_data, np.ndarray):
        audio_array = audio_data.astype(np.float32)
    else:
        raise ValueError(f"Unsupported audio_data type: {type(audio_data)}")
    
    # Resample if sample rates don’t match.
    if sample_rate != target_sample_rate:
        audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=target_sample_rate)
    
    return audio_array

def save_audio_to_temp_file(target_sample_rate: int, audio: tuple[int, bytes | np.ndarray]) -> str:
    """Resample audio to target sample rate and save to a temporary WAV file.

    This function resamples audio data to match a target sample rate and saves
    it as a WAV file. This is useful for preparing audio for APIs that require
    specific sample rates and file formats.

    Args:
        target_sample_rate: Target sample rate in Hz for the output file (e.g., 16000).
        audio: Tuple containing:
            - int: Current sample rate of the input audio
            - bytes | np.ndarray: Audio data to process

    Returns:
        String path to the temporary WAV file where the audio was saved.

    Raises:
        Exception: Whatever ``resample_audio`` or ``soundfile.write`` raises;
            no file is left behind when writing fails.

    Note:
        - Conversion to float32 and resampling are delegated to resample_audio
        - Saves audio as WAV format with an explicit 16-bit PCM subtype
        - Temporary file is not automatically deleted (caller is responsible for cleanup)
        - Useful for preparing audio for Hugging Face InferenceClient APIs
    """
    import os
    from contextlib import suppress

    audio_array = resample_audio(target_sample_rate, audio)
    # Close our handle before soundfile reopens the path: an open
    # NamedTemporaryFile cannot be reopened by name on some platforms (Windows).
    with NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        temp_path = temp_file.name
    try:
        sf.write(temp_path, audio_array, target_sample_rate, format='WAV', subtype='PCM_16')
    except BaseException:
        # Don't leak an empty/partial file when encoding fails.
        with suppress(OSError):
            os.remove(temp_path)
        raise
    return temp_path