Transcriber / app.py
Jonascaps1's picture
Update app.py
e6c1f00 verified
import os
import torch
import gradio as gr
import logging
import subprocess
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from pathlib import Path
from tempfile import NamedTemporaryFile
from datetime import timedelta
# Setup logging
# Configure logging once at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Configuration
# Model identifier passed to from_pretrained() in initialize_pipeline().
MODEL_ID = "KBLab/kb-whisper-large"
# Length of each audio slice produced by split_audio(), in milliseconds.
CHUNK_DURATION_MS = 10000
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when CUDA is available, full precision otherwise.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# Lower-case file extensions accepted by convert_to_wav().
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
# Check for ffmpeg availability
def check_ffmpeg() -> bool:
    """Report whether the ``ffmpeg`` executable can be run.

    Runs ``ffmpeg -version`` and logs the outcome.

    Returns:
        True when the command exits successfully; False when it exits with a
        non-zero status or cannot be started at all.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (binary missing) as well as other
        # launch failures such as PermissionError (binary present but not
        # executable), so every "cannot run ffmpeg" case reports False
        # instead of leaking a raw exception to the caller.
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
    logger.info("ffmpeg is installed and accessible.")
    return True
# Initialize model and pipeline
def initialize_pipeline():
    """Load the speech model and wrap it in a speech-recognition pipeline.

    The model weights and the processor are both loaded from ``MODEL_ID``;
    the model is moved to ``DEVICE`` and uses ``TORCH_DTYPE``.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        "automatic-speech-recognition" task.
    """
    speech_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,
    )
    speech_model = speech_model.to(DEVICE)

    speech_processor = AutoProcessor.from_pretrained(MODEL_ID)

    # The processor bundles the tokenizer and the feature extractor; the
    # pipeline takes them as separate arguments.
    pipeline_options = {
        "model": speech_model,
        "tokenizer": speech_processor.tokenizer,
        "feature_extractor": speech_processor.feature_extractor,
        "device": DEVICE,
        "torch_dtype": TORCH_DTYPE,
    }
    return pipeline("automatic-speech-recognition", **pipeline_options)
# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    """Return the path of a WAV version of *audio_path*.

    A ``.wav`` input is returned unchanged. Any other supported format is
    decoded and exported as ``<stem>.converted.wav`` next to the input file.

    Args:
        audio_path: Path to the input audio file.

    Returns:
        Path to a WAV file (either the input itself or the converted copy).

    Raises:
        RuntimeError: If ffmpeg is not available.
        ValueError: If the file extension is not in ``SUPPORTED_FORMATS``.
    """
    if not check_ffmpeg():
        raise RuntimeError("ffmpeg is required")

    source = Path(audio_path)
    ext = str(source.suffix).lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported format: {ext}")

    # Already WAV: nothing to convert.
    if ext == ".wav":
        return audio_path

    decoded = AudioSegment.from_file(audio_path)
    wav_path = str(source.with_suffix(".converted.wav"))
    decoded.export(wav_path, format="wav")
    return wav_path
# Split audio into chunks
def split_audio(audio_path: str) -> list:
    """Load *audio_path* and cut it into consecutive fixed-length slices.

    Args:
        audio_path: Path to the audio file to load.

    Returns:
        A list of slices of the loaded audio, each covering at most
        ``CHUNK_DURATION_MS`` positions; the last slice may be shorter.
        The list is empty when the loaded audio has length 0.
    """
    audio = AudioSegment.from_file(audio_path)
    segments = []
    for start in range(0, len(audio), CHUNK_DURATION_MS):
        stop = start + CHUNK_DURATION_MS
        segments.append(audio[start:stop])
    return segments
# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* as a ``timedelta`` string.

    Args:
        index: Zero-based position of the chunk.
        chunk_duration_ms: Length of every chunk, in milliseconds.

    Returns:
        ``str(timedelta)`` of the offset, e.g. ``"0:00:30"`` for the fourth
        10-second chunk.
    """
    offset = timedelta(milliseconds=index * chunk_duration_ms)
    return str(offset)
# Transcribe audio with streaming + working download
def _load_chunks(audio_path: str) -> list:
    """Convert *audio_path* to WAV if needed and return its audio chunks."""
    wav_path = convert_to_wav(audio_path)
    try:
        return split_audio(wav_path)
    finally:
        # The chunks are held in memory, so a converted copy is no longer
        # needed; never delete the caller's own file.
        if wav_path != audio_path and os.path.exists(wav_path):
            os.remove(wav_path)


def _transcribe_chunk(chunk) -> str:
    """Run the model on one audio chunk and return its stripped text."""
    # The pipeline is given a file path, so the chunk is written to a
    # temporary WAV file that is removed whether or not inference succeeds.
    with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        chunk.export(temp_file_path, format="wav")
        result = PIPELINE(
            temp_file_path,
            generate_kwargs={"task": "transcribe", "language": "sv"},
        )
        return result["text"].strip()
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)


def _write_transcript_file(content: str) -> str:
    """Write *content* to a persistent UTF-8 ``.txt`` file and return its path."""
    with NamedTemporaryFile(
        suffix=".txt",
        delete=False,
        mode="w",
        encoding="utf-8",
    ) as f:
        f.write(content)
        return f.name


def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
    """Transcribe an audio file chunk by chunk, streaming partial results.

    This is a generator. After each chunk it yields the transcript so far
    with no download file; the final yield carries the full transcript and
    the path of a text file containing it.

    Args:
        audio_path: Path to the uploaded audio file.
        include_timestamps: When true, the download file has one line per
            non-empty chunk, prefixed with that chunk's start time. The
            streamed text never contains timestamps.
        progress: Progress tracker, called with the fraction of chunks done.

    Yields:
        ``(text, download_path)`` tuples. ``download_path`` is ``None`` for
        every intermediate result and for error messages.
    """
    if not audio_path or not os.path.exists(audio_path):
        yield "Please upload a valid audio file.", None
        return

    # Report preparation failures (missing ffmpeg, unsupported extension,
    # undecodable file) as a message, like the invalid-upload case above,
    # instead of letting the exception escape the generator.
    try:
        chunks = _load_chunks(audio_path)
    except (RuntimeError, ValueError, CouldntDecodeError) as err:
        logger.exception("Could not prepare %s for transcription", audio_path)
        yield f"Could not process the audio file: {err}", None
        return

    if not chunks:
        # Zero-length audio: nothing to transcribe, so do not create an
        # empty download file.
        yield "The audio file contains no audio.", None
        return

    transcript = []
    timestamped_transcript = []
    total = len(chunks)
    for i, chunk in enumerate(chunks):
        text = _transcribe_chunk(chunk)
        if text:
            transcript.append(text)
            if include_timestamps:
                timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                timestamped_transcript.append(f"[{timestamp}] {text}")
        progress((i + 1) / total)
        yield " ".join(transcript), None  # STREAM TEXT ONLY

    # Create the downloadable file only once, after all chunks are done.
    full_text = " ".join(transcript)
    content = "\n".join(timestamped_transcript) if include_timestamps else full_text
    download_path = _write_transcript_file(content)
    yield full_text, download_path  # FINAL OUTPUT
# Initialize pipeline globally
# Built once at import time; transcribe() reads this module-level object for
# every chunk instead of calling initialize_pipeline() again.
PIPELINE = initialize_pipeline()
# Gradio Interface
def create_interface():
    """Build the Gradio UI for the transcriber.

    The layout is a title followed by one row with two columns: inputs
    (audio upload, timestamp checkbox, button) on the left and outputs
    (live transcript, download file) on the right. The button is wired to
    ``transcribe``.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Derive the upload label from SUPPORTED_FORMATS so it lists every
    # accepted extension rather than only ".m4a".
    accepted = ", ".join(sorted(SUPPORTED_FORMATS))

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    type="filepath",
                    label=f"Upload Audio ({accepted})",
                )
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download")
                transcribe_btn = gr.Button("Transcribe")
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")
        # transcribe yields (text, file) pairs, matching the two outputs.
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output],
        )
    return demo
if __name__ == "__main__":
    # Build and launch the UI only when this file is run as a script.
    app = create_interface()
    app.launch()