"""Gradio app that transcribes Swedish audio with the KBLab Whisper model."""

import logging
import os
import subprocess
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile

import gradio as gr
import torch
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"
# Length of each slice sent to the model; pydub slices and len() are in milliseconds.
CHUNK_DURATION_MS = 10000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 only when CUDA is available; CPU inference runs in float32.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}


def check_ffmpeg() -> bool:
    """Return True if ``ffmpeg -version`` runs successfully from PATH.

    The outcome is logged either way; callers decide whether a missing
    ffmpeg is fatal.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        logger.info("ffmpeg is installed and accessible.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False


def initialize_pipeline():
    """Load MODEL_ID onto DEVICE and wrap it in a transformers ASR pipeline."""
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        torch_dtype=TORCH_DTYPE,
    )


def _remove_if_exists(path: str) -> None:
    """Delete *path*, ignoring the case where it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def convert_to_wav(audio_path: str) -> str:
    """Return the path of a WAV version of *audio_path*.

    ``.wav`` input is returned unchanged. Any other supported format is decoded
    with pydub and written to a new temporary ``.wav`` file. The caller owns that
    file and should delete it; the returned path differs from *audio_path*
    exactly when a new file was created.

    Raises:
        RuntimeError: if ffmpeg cannot be found.
        ValueError: if the file extension is not in SUPPORTED_FORMATS.
        CouldntDecodeError: if pydub fails to decode the input.
    """
    if not check_ffmpeg():
        raise RuntimeError("ffmpeg is required")

    ext = Path(audio_path).suffix.lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported format: {ext}")
    if ext == ".wav":
        return audio_path

    # Decode first so a bad input never leaves an empty temp file behind.
    audio = AudioSegment.from_file(audio_path)
    # Write to the temp dir rather than next to the input: the upload directory
    # may not be writable, and the caller removes this file when finished.
    with NamedTemporaryFile(suffix=".converted.wav", delete=False) as tmp:
        wav_path = tmp.name
    try:
        # export() returns the open output handle; close it before anyone reads it.
        audio.export(wav_path, format="wav").close()
    except Exception:
        _remove_if_exists(wav_path)
        raise
    return wav_path


def split_audio(audio_path: str) -> list:
    """Load *audio_path* and return consecutive CHUNK_DURATION_MS-long AudioSegments.

    The final chunk may be shorter. An empty recording yields an empty list.
    """
    audio = AudioSegment.from_file(audio_path)
    return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]


def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* formatted by ``timedelta`` (e.g. ``0:00:10``)."""
    start_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=start_ms))


def _transcribe_chunk(chunk) -> str:
    """Run the global PIPELINE on one pydub AudioSegment and return the stripped text."""
    # The pipeline is called with a file path, so the chunk goes through a
    # temporary WAV file. The handle is closed before export, and the file is
    # always removed afterwards.
    with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        chunk.export(temp_file_path, format="wav").close()
        result = PIPELINE(
            temp_file_path,
            generate_kwargs={"task": "transcribe", "language": "sv"},
        )
        return result["text"].strip()
    finally:
        _remove_if_exists(temp_file_path)


def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
    """Stream a Swedish transcription of *audio_path*, then offer it as a download.

    Yields ``(text, download_path)`` pairs for the two Gradio outputs:
      * after each chunk: the running transcript and ``None`` (no file yet);
      * once at the end: the full transcript and the path of a UTF-8 ``.txt``
        file. That file contains ``[H:MM:SS] text`` lines when
        *include_timestamps* is true, otherwise the space-joined transcript.

    Input that is missing, unsupported, or cannot be decoded produces a single
    user-facing message, with ``None`` as the file, instead of an exception.
    """
    if not audio_path or not os.path.exists(audio_path):
        yield "Please upload a valid audio file.", None
        return

    try:
        wav_path = convert_to_wav(audio_path)
        try:
            chunks = split_audio(wav_path)
        finally:
            # Chunks now live in memory, so the converted copy is no longer needed.
            if wav_path != audio_path:
                _remove_if_exists(wav_path)
    except (RuntimeError, ValueError, CouldntDecodeError) as err:
        logger.exception("Could not load audio file %s", audio_path)
        yield f"Could not read audio file: {err}", None
        return

    transcript = []
    timestamped_transcript = []
    total = len(chunks)
    for i, chunk in enumerate(chunks):
        text = _transcribe_chunk(chunk)
        if text:
            transcript.append(text)
            if include_timestamps:
                timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                timestamped_transcript.append(f"[{timestamp}] {text}")
        progress((i + 1) / total)
        yield " ".join(transcript), None  # stream text only; file is produced once at the end

    full_text = " ".join(transcript)
    content = "\n".join(timestamped_transcript) if include_timestamps else full_text
    # delete=False: Gradio must be able to serve the file after this handle closes.
    with NamedTemporaryFile(suffix=".txt", delete=False, mode="w", encoding="utf-8") as f:
        f.write(content)
        download_path = f.name
    yield full_text, download_path


# Initialize pipeline globally (loads the model when the module is imported).
PIPELINE = initialize_pipeline()


def create_interface():
    """Build the Gradio Blocks UI that wires the upload/checkbox inputs to transcribe()."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        with gr.Row():
            with gr.Column():
                # Advertise every accepted format, not only .m4a.
                audio_input = gr.Audio(
                    type="filepath",
                    label=f"Upload Audio ({', '.join(sorted(SUPPORTED_FORMATS))})",
                )
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download")
                transcribe_btn = gr.Button("Transcribe")
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")

        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output],
        )
    return demo


if __name__ == "__main__":
    create_interface().launch()