asr / app.py
michaeltangz
fix app.py to remove redundant generate_kwargs in transcription calls and enable demo sharing during launch
3e64dd3
import spaces
import torch
import gradio as gr
import os
import uuid
import scipy.io.wavfile
import time
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperTokenizer, pipeline
# Pick the execution device and the matching tensor precision together:
# half precision on CUDA, full precision on the CPU fallback.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# Checkpoint used for the model, processor and tokenizer, and shown in the UI.
MODEL_NAME = "openai/whisper-large-v3-turbo"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)

# Shared speech-recognition pipeline used by both transcription handlers.
# NOTE(review): `ignore_warning=True` presumably silences the pipeline's
# warning about chunked long-form decoding -- confirm against the installed
# transformers version.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=10,
    device=device,
    ignore_warning=True,
)
@spaces.GPU
def stream_transcribe(stream, new_chunk):
    """Append a new audio chunk to the running stream and transcribe it all.

    Args:
        stream: Audio accumulated by previous calls, or ``None`` when no
            audio has been collected yet.
        new_chunk: A ``(sampling_rate, samples)`` pair; ``samples`` is used
            as a NumPy array.

    Returns:
        A 3-tuple ``(stream, text, latency)``. On success ``stream`` is the
        updated audio buffer, ``text`` the transcription of the whole buffer
        and ``latency`` the elapsed seconds formatted with two decimals. If
        anything raises, the current ``stream`` value, the exception message
        and the literal ``"Error"`` are returned instead.
    """
    began = time.time()
    try:
        sample_rate, samples = new_chunk

        # Collapse multi-dimensional input by averaging over axis 1
        # (assumes a (samples, channels) layout -- TODO confirm).
        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        # Work in float32 and scale this chunk so its peak magnitude is 1;
        # an all-zero chunk is left as is to avoid dividing by zero.
        samples = samples.astype(np.float32)
        peak = np.max(np.abs(samples))
        if peak > 0:
            samples /= peak

        # `stream` is rebound before the pipeline call, so the error branch
        # below hands back the buffer including this chunk if `pipe` fails.
        stream = samples if stream is None else np.concatenate([stream, samples])

        result = pipe({"sampling_rate": sample_rate, "raw": stream})
        transcription = result["text"]

        latency = time.time() - began
        return stream, transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during Transcription: {e}")
        return stream, str(e), "Error"
@spaces.GPU
def transcribe(inputs, previous_transcription):
    """Transcribe an uploaded clip and append the text to the transcript.

    The audio is written to a uniquely named WAV file in the current working
    directory, handed to the shared ``pipe`` by path, and the file is removed
    again before returning so repeated requests do not accumulate files on
    disk.

    Args:
        inputs: A ``(sample_rate, audio_data)`` pair as unpacked below.
        previous_transcription: Text already shown in the output box; the
            new transcription is appended to it.

    Returns:
        A 2-tuple ``(transcript, latency)``. On success ``latency`` is the
        elapsed seconds formatted with two decimals. If anything raises,
        ``previous_transcription`` and the literal ``"Error"`` are returned.
    """
    start_time = time.time()
    # Built outside the `try` so the cleanup in `finally` can always refer
    # to it, even when unpacking `inputs` fails before anything is written.
    filename = f"{uuid.uuid4().hex}.wav"
    try:
        sample_rate, audio_data = inputs
        scipy.io.wavfile.write(filename, sample_rate, audio_data)
        transcription = pipe(filename)["text"]
        previous_transcription += transcription
        latency = time.time() - start_time
        return previous_transcription, f"{latency:.2f}"
    except Exception as e:
        print(f"Error during Transcription: {e}")
        return previous_transcription, "Error"
    finally:
        # Best-effort cleanup: a failed delete must never replace the result
        # (or the error tuple) that is already being returned.
        try:
            os.remove(filename)
        except FileNotFoundError:
            # The request failed before the WAV file was created.
            pass
        except OSError as cleanup_error:
            print(f"Could not remove temporary file {filename}: {cleanup_error}")
def clear() -> str:
    """Return an empty string (wired to the transcription textbox to blank it)."""
    return ""
def clear_state() -> None:
    """Return ``None`` (wired to the streaming state to drop buffered audio)."""
    return None
# Microphone tab: streamed audio is passed to `stream_transcribe`, which
# keeps the accumulated audio in `state` between calls.
with gr.Blocks() as microphone:
    with gr.Column():
        gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.")

        with gr.Row():
            input_audio_microphone = gr.Audio(streaming=True)
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(
                label="Latency (seconds)",
                value="0.0",
                scale=0,
            )

        with gr.Row():
            clear_button = gr.Button("Clear Output")

        # Holds the audio buffer returned by the previous streaming call.
        state = gr.State()

        input_audio_microphone.stream(
            stream_transcribe,
            inputs=[state, input_audio_microphone],
            outputs=[state, output, latency_textbox],
        )

        # Clearing is two chained steps: reset the audio buffer first, then
        # blank the transcription textbox.
        state_reset = clear_button.click(clear_state, outputs=[state])
        state_reset.then(clear, outputs=[output])
# File tab: an uploaded clip is transcribed on demand by `transcribe`, and
# the result is appended to whatever the transcription box already shows.
with gr.Blocks() as file:
    with gr.Column():
        gr.Markdown(f"# Realtime Whisper Large V3 Turbo\nTranscribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.")

        with gr.Row():
            # The name is kept from the microphone tab; here the component
            # only accepts uploads.
            input_audio_microphone = gr.Audio(sources="upload", type="numpy")
            output = gr.Textbox(label="Transcription", value="")
            latency_textbox = gr.Textbox(
                label="Latency (seconds)",
                value="0.0",
                scale=0,
            )

        with gr.Row():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Output")

        # The current textbox content is passed in so the handler can append
        # the new transcription to it.
        submit_button.click(
            transcribe,
            inputs=[input_audio_microphone, output],
            outputs=[output, latency_textbox],
        )
        clear_button.click(clear, outputs=[output])
# Top-level app: the two Blocks defined above are presented as tabs.
with gr.Blocks() as demo:
    gr.TabbedInterface(
        [microphone, file],
        ["Microphone", "Transcribe from file"],
    )


# Only start the server when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)