# whisper-base / app.py
import os

# Cap BLAS/OpenMP thread counts before torch loads; this Space runs on a 2-core CPU.
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["MKL_NUM_THREADS"] = "2"

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

torch.set_num_threads(2)
device = "cpu"
torch_dtype = torch.float32
model_id = "openai/whisper-base"
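# CPU-only Space, so float32 weights and PyTorch's SDPA attention are used.
# Note: `dtype=` is the current name of the from_pretrained precision argument
# in recent transformers releases; older versions may expect `torch_dtype=`
# instead (an assumption about the version pinned for this Space).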
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
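# AutoProcessor bundles the Whisper tokenizer and feature extractor,
# which are handed to the pipeline separately below.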
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    dtype=torch_dtype,
    device=device,
    ignore_warning=True,
)
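# A minimal usage sketch (hypothetical path "sample.wav"), assuming the standard
# ASR pipeline output shape: a dict with "text", plus a "chunks" list of
# {"timestamp": (start, end), "text": ...} segments when return_timestamps=True:
#   result = pipe("sample.wav", generate_kwargs={"task": "transcribe"})
#   print(result["text"])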
def _build_generate_kwargs(task, language):
    # Greedy decoding plus Whisper's standard hallucination filters
    # (compression-ratio, average log-prob, and no-speech thresholds).
    return {
        "task": task,
        "language": None if language == "auto" else language,
        "num_beams": 1,
        "do_sample": False,
        "temperature": 0.0,
        "max_new_tokens": 220,
        "compression_ratio_threshold": 1.35,
        "logprob_threshold": -1.0,
        "no_speech_threshold": 0.6,
    }

def _format_result(result, return_timestamps):
    if return_timestamps and "chunks" in result:
        formatted_result = []
        for chunk in result["chunks"]:
            start, end = chunk["timestamp"]
            # The end timestamp can be None when Whisper truncates a segment.
            end_str = f"{end:.2f}s" if end is not None else "?"
            formatted_result.append(f"[{start:.2f}s - {end_str}] {chunk['text']}")
        return "\n".join(formatted_result)
    return result["text"]

def transcribe_audio(audio_file, task="transcribe", language="auto", return_timestamps=False):
    if audio_file is None:
        return "No audio file provided."
    try:
        with torch.inference_mode():
            result = pipe(
                audio_file,
                return_timestamps=return_timestamps,
                generate_kwargs=_build_generate_kwargs(task, language),
            )
        return _format_result(result, return_timestamps)
    except Exception as e:
        return f"Error processing audio: {e}"
def transcribe_microphone(audio_data, task="transcribe", language="auto", return_timestamps=False):
    if audio_data is None:
        return "No audio recorded."
    try:
        sample_rate, audio_array = audio_data
        audio_array = audio_array.astype(np.float32)
        # Downmix stereo recordings to mono; Whisper expects a 1-D waveform.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Peak-normalize to [-1, 1], guarding against all-silence input.
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array = audio_array / peak
        with torch.inference_mode():
            result = pipe(
                {"array": audio_array, "sampling_rate": sample_rate},
                return_timestamps=return_timestamps,
                generate_kwargs=_build_generate_kwargs(task, language),
            )
        return _format_result(result, return_timestamps)
    except Exception as e:
        return f"Error processing audio: {e}"
languages = [
("Auto Detect", "auto"),
("English", "en"),
("Chinese", "zh"),
("German", "de"),
("Spanish", "es"),
("Russian", "ru"),
("Korean", "ko"),
("French", "fr"),
("Japanese", "ja"),
("Portuguese", "pt"),
("Turkish", "tr"),
("Polish", "pl"),
("Catalan", "ca"),
("Dutch", "nl"),
("Arabic", "ar"),
("Swedish", "sv"),
("Italian", "it"),
("Indonesian", "id"),
("Hindi", "hi"),
("Finnish", "fi"),
("Vietnamese", "vi"),
("Hebrew", "he"),
("Ukrainian", "uk"),
("Greek", "el"),
("Malay", "ms"),
("Czech", "cs"),
("Romanian", "ro"),
("Danish", "da"),
("Hungarian", "hu"),
("Tamil", "ta"),
("Norwegian", "no"),
("Thai", "th"),
("Urdu", "ur"),
("Croatian", "hr"),
("Bulgarian", "bg"),
("Lithuanian", "lt"),
("Latin", "la"),
]
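# Each entry is a (display label, Whisper language code) pair consumed by
# gr.Dropdown; the value forwarded to generate() is the short code, e.g. "en".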
with gr.Blocks(title="Whisper Base - Speech to Text") as demo:
    gr.Markdown("# 🎤 Whisper Base - Speech to Text")
    gr.Markdown("Upload an audio file or record directly to get accurate transcription using OpenAI's Whisper Base model (74M parameters).")
    with gr.Tab("Upload Audio File"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.Audio(
                    label="Upload Audio File",
                    type="filepath"
                )
                task_file = gr.Radio(
                    choices=[("Transcribe", "transcribe"), ("Translate to English", "translate")],
                    value="transcribe",
                    label="Task"
                )
                language_file = gr.Dropdown(
                    choices=languages,
                    value="auto",
                    label="Source Language"
                )
                timestamps_file = gr.Checkbox(
                    label="Return Timestamps",
                    value=False
                )
                submit_file = gr.Button("Transcribe Audio File", variant="primary")
            with gr.Column():
                output_file = gr.Textbox(
                    label="Transcription Result",
                    lines=10,
                    max_lines=20
                )
with gr.Tab("Record Audio"):
with gr.Row():
with gr.Column():
audio_mic = gr.Audio(
label="Record Audio",
sources=["microphone"]
)
task_mic = gr.Radio(
choices=[("Transcribe", "transcribe"), ("Translate to English", "translate")],
value="transcribe",
label="Task"
)
language_mic = gr.Dropdown(
choices=languages,
value="auto",
label="Source Language"
)
timestamps_mic = gr.Checkbox(
label="Return Timestamps",
value=False
)
submit_mic = gr.Button("Transcribe Recording", variant="primary")
with gr.Column():
output_mic = gr.Textbox(
label="Transcription Result",
lines=10,
max_lines=20
)
    submit_file.click(
        transcribe_audio,
        inputs=[audio_file, task_file, language_file, timestamps_file],
        outputs=output_file
    )
    submit_mic.click(
        transcribe_microphone,
        inputs=[audio_mic, task_mic, language_mic, timestamps_mic],
        outputs=output_mic
    )
gr.Markdown("### Features:")
gr.Markdown("- **Balanced Performance**: Powered by Whisper Base model (74M parameters)")
gr.Markdown("- **CPU Optimized**: Optimized for 2-core CPU with 16GB RAM")
gr.Markdown("- **Multi-language**: Supports 99+ languages")
gr.Markdown("- **Translation**: Can translate speech to English")
gr.Markdown("- **Timestamps**: Optional word-level or sentence-level timestamps")
gr.Markdown("- **Good Accuracy**: Better accuracy than Tiny with reasonable speed")
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )