audio processing pipeline
src/app.py  +56 -21

src/app.py CHANGED
@@ -15,7 +15,9 @@ import io
 import base64
 import numpy as np
 from transformers.pipelines import pipeline  # Changed from transformers import pipeline
-from transformers import WhisperProcessor
+from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
+import torchaudio
+import torchaudio.transforms as T
 
 # Model options mapped to their requirements
 MODEL_OPTIONS = {
@@ -40,25 +42,24 @@ MODEL_OPTIONS = {
 }
 
 # Initialize Whisper with proper configuration
+# Create components separately
+feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
+tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
+processor = WhisperProcessor(feature_extractor, tokenizer)
+
 transcriber = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-base.en",
     chunk_length_s=30,
     stride_length_s=5,
-
-    device="cpu",  # Explicitly set to CPU since we're seeing GPU warnings
+    device="cpu",
     torch_dtype=torch.float32,
+    # Remove feature_extractor and tokenizer parameters as they're included in the model
     generate_kwargs={
-        "
-        "language": "en",
-        "use_cache": True,
-        "return_timestamps": True
+        "use_cache": True
     }
 )
 
-# Create processor for proper attention mask
-processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
-
 def get_system_specs() -> Dict[str, float]:
     """Get system specifications."""
     # Get RAM
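For reference, a minimal sketch of how a transcriber configured this way is called: the transformers ASR pipeline accepts a dict with "raw" (a 1-D float32 NumPy array) and "sampling_rate" keys, the same input format the hunks below rely on. The one-second silent array is a made-up stand-in, not part of this commit.

    import numpy as np

    # Stand-in input: one second of silence at 16 kHz (the rate Whisper expects)
    dummy_audio = np.zeros(16000, dtype=np.float32)

    result = transcriber({"raw": dummy_audio, "sampling_rate": 16000})
    print(result["text"])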
@@ -207,12 +208,23 @@ def process_speech(audio_data, history):
     audio_array = audio_array.astype(np.float32)
     audio_array /= np.max(np.abs(audio_array))
 
+    # Ensure correct sampling rate
+    if sample_rate != 16000:
+        resampler = T.Resample(sample_rate, 16000)
+        audio_tensor = torch.FloatTensor(audio_array)
+        audio_tensor = resampler(audio_tensor)
+        audio_array = audio_tensor.numpy()
+        sample_rate = 16000
+
     # Transcribe with error handling
     try:
-        result = transcriber(
-            {"sampling_rate": sample_rate, "raw": audio_array}
-        )
-
+        # Format dictionary correctly with required keys
+        inputs = {
+            "raw": audio_array,
+            "sampling_rate": sample_rate
+        }
+
+        result = transcriber(inputs)
 
         # Handle different result types
         if isinstance(result, dict) and "text" in result:
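The hard-coded 16000 above is Whisper's training sample rate. As an alternative sketch (not what this commit does), the expected rate can be read off the feature extractor created earlier in this file instead of being hard-coded; sampling_rate is a standard attribute of WhisperFeatureExtractor.

    # Alternative resampling sketch, assuming feature_extractor from this file
    target_rate = feature_extractor.sampling_rate  # 16000 for Whisper checkpoints
    if sample_rate != target_rate:
        audio_tensor = T.Resample(sample_rate, target_rate)(torch.FloatTensor(audio_array))
        audio_array = audio_tensor.numpy()
        sample_rate = target_rate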
@@ -422,6 +434,34 @@ with gr.Blocks(
         queue=True  # Enable queuing for better stream handling
     )
 
+    def process_audio(audio_array, sample_rate):
+        """Pre-process audio for Whisper."""
+        if audio_array.ndim > 1:
+            audio_array = audio_array.mean(axis=1)
+
+        # Convert to tensor for resampling
+        audio_tensor = torch.FloatTensor(audio_array)
+
+        # Resample to 16kHz if needed
+        if sample_rate != 16000:
+            resampler = T.Resample(sample_rate, 16000)
+            audio_tensor = resampler(audio_tensor)
+
+        # Normalize
+        audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
+
+        # Use feature extractor with correct sampling rate
+        features = feature_extractor(
+            audio_tensor.numpy(),
+            sampling_rate=16000,  # Always use 16kHz
+            return_tensors="pt"
+        )
+
+        return {
+            "input_features": features.input_features,
+            "sampling_rate": 16000  # Return resampled rate
+        }
+
     # Update transcription handler
     def update_live_transcription(audio):
         """Real-time transcription updates."""
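A quick smoke test for the process_audio helper added in this hunk; the 440 Hz tone and the 44.1 kHz input rate are invented for illustration.

    import numpy as np

    sr = 44100
    t = np.linspace(0.0, 1.0, sr, dtype=np.float32)
    tone = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

    out = process_audio(tone, sr)
    print(out["sampling_rate"])         # 16000 after resampling
    print(out["input_features"].shape)  # log-mel features, e.g. torch.Size([1, 80, 3000])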
@@ -430,14 +470,9 @@ with gr.Blocks(
 
         try:
             sample_rate, audio_array = audio
-
-            audio_array = audio_array.mean(axis=1)
-            audio_array = audio_array.astype(np.float32)
-            audio_array /= np.max(np.abs(audio_array))
+            input_features = process_audio(audio_array, sample_rate)
 
-            result = transcriber(
-                {"sampling_rate": sample_rate, "raw": audio_array}
-            )
+            result = transcriber(input_features)
 
             # Handle different result types
             if isinstance(result, dict):
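For context, the dict handling these hunks feed into typically looks like the sketch below; the "chunks" key appears when the pipeline is asked for timestamps (the "return_timestamps": True option this commit removes from generate_kwargs). This is a hedged reconstruction, not code from this commit.

    if isinstance(result, dict):
        text = result.get("text", "")
        # "chunks" is present when the pipeline returns timestamps
        for chunk in result.get("chunks", []):
            print(chunk["timestamp"], chunk["text"])
    else:
        text = str(result)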