"""Streamlit app for speech transcription with CrisperWhisper++."""

import argparse
import io
from typing import Any, Dict, List, Tuple, Union

import moviepy.editor as mp
import numpy as np
import streamlit as st
import torch
import torchaudio
import torchaudio.transforms as T
from scipy.io import wavfile
from streamlit_mic_recorder import mic_recorder
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Streamlit app for speech transcription."
    )
    parser.add_argument(
        "--model_id", type=str, required=True, help="Path to the model directory"
    )
    return parser.parse_args()
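
# Usage sketch ("app.py" is a placeholder for this file's actual name;
# Streamlit forwards everything after "--" to the script's own argparse):
#
#   streamlit run app.py -- --model_id /path/to/model
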
@st.cache_resource
def load_model_and_processor(
    model_id: str,
) -> Tuple[AutoModelForSpeechSeq2Seq, AutoProcessor]:
    """Load the model and processor once and cache them across reruns."""
    # `torch_dtype` and `device` are module-level globals defined further
    # down, before this function is first called.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    # Median-filter the cross-attention weights used for timestamp extraction.
    model.generation_config.median_filter_width = 3
    processor = AutoProcessor.from_pretrained(model_id)
    return model, processor
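
# Streamlit skips hashing for cached-function arguments whose names start with
# an underscore; that is why the unhashable model and processor are passed as
# `_model` and `_processor` below.
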
@st.cache_resource
def setup_pipeline(
    _model: AutoModelForSpeechSeq2Seq, _processor: AutoProcessor
) -> AutomaticSpeechRecognitionPipeline:
    """Build the cached ASR pipeline around the model and processor."""
    return pipeline(
        "automatic-speech-recognition",
        model=_model,
        tokenizer=_processor.tokenizer,
        feature_extractor=_processor.feature_extractor,
        chunk_length_s=30,
        batch_size=1,
        return_timestamps=True,
        torch_dtype=torch_dtype,
        device=device,
    )
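
# Whisper-style models see fixed 30-second windows; `chunk_length_s=30` makes
# the pipeline split longer inputs into such windows and stitch the resulting
# transcripts (and their timestamps) back together.
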
def wav_to_black_mp4(wav_path: str, output_path: str, fps: int = 25) -> None:
    """Convert a WAV file to a black-screen MP4 with the same audio."""
    waveform, sample_rate = torchaudio.load(wav_path)
    duration: float = waveform.shape[1] / sample_rate
    audio = mp.AudioFileClip(wav_path)
    black_clip = mp.ColorClip((256, 250), color=(0, 0, 0), duration=duration)
    final_clip = black_clip.set_audio(audio)
    final_clip.write_videofile(output_path, fps=fps)
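
# `fps` only affects the synthetic black frames; the clip's length is driven
# by the audio duration computed from the WAV above.
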
def timestamps_to_vtt(timestamps: List[Dict[str, Any]]) -> str:
    """Convert word-level timestamps to WebVTT cues."""
    vtt_content: str = "WEBVTT\n\n"
    for word in timestamps:
        start_time, end_time = word["timestamp"]
        # WebVTT timestamps are HH:MM:SS.mmm with two-digit hour fields.
        start_time_str = f"{int(start_time // 3600):02d}:{int(start_time // 60 % 60):02d}:{start_time % 60:06.3f}"
        end_time_str = f"{int(end_time // 3600):02d}:{int(end_time // 60 % 60):02d}:{end_time % 60:06.3f}"
        vtt_content += f"{start_time_str} --> {end_time_str}\n{word['text']}\n\n"
    return vtt_content
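
# Illustrative output for two words (timestamps are made up):
#
#   WEBVTT
#
#   00:00:01.220 --> 00:00:01.480
#   hello
#
#   00:00:01.520 --> 00:00:01.900
#   world
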
def process_audio_bytes(audio_bytes: bytes) -> torch.Tensor:
    """Normalize raw WAV bytes and resample them to 16 kHz."""
    audio_stream = io.BytesIO(audio_bytes)
    sr, y = wavfile.read(audio_stream)
    y = y.astype(np.float32)
    if y.ndim > 1:
        # Downmix multi-channel recordings to mono before resampling.
        y = y.mean(axis=1)
    # Zero-mean, unit-variance normalization, then scale down for headroom.
    y_normalized = (y - np.mean(y)) / np.std(y)
    transform = T.Resample(sr, 16000)
    waveform = transform(torch.unsqueeze(torch.tensor(y_normalized / 8), 0))
    # Side effect: persist the processed audio; wav_to_black_mp4 reads it back.
    torchaudio.save("sample.wav", waveform, sample_rate=16000)
    return waveform
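
# For z-scored speech, sample peaks sit at a few standard deviations, so the
# extra /8 above is presumably headroom to keep values near the conventional
# [-1, 1] float audio range.
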
def transcribe(audio_bytes: bytes) -> Dict[str, Any]:
    """Transcribe the given audio bytes with word-level timestamps."""
    waveform = process_audio_bytes(audio_bytes)
    # The pipeline takes a 1-D numpy array at the feature extractor's
    # sampling rate (16 kHz after process_audio_bytes).
    transcription = pipe(waveform[0, :].numpy(), return_timestamps="word")
    return transcription
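
# The pipeline returns a dict shaped like
#   {"text": "<full transcript>",
#    "chunks": [{"text": "<word>", "timestamp": (start_s, end_s)}, ...]},
# which is exactly what timestamps_to_vtt consumes.
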
args = parse_arguments()
model_id = args.model_id

device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype: torch.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model, processor = load_model_and_processor(model_id)
pipe = setup_pipeline(model, processor)
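
# Streamlit reruns this script top to bottom on every interaction;
# st.cache_resource above keeps model loading a one-time cost.
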
st.title("CrisperWhisper++ 🦻")
st.subheader("Caution when using. Make sure you can handle the crispness. ⚠️")
st.write("🎙️ Record audio to transcribe or 📁 upload an audio file.")

audio = mic_recorder(
    start_prompt="Start recording",
    stop_prompt="Stop recording",
    just_once=False,
    use_container_width=False,
    format="wav",
    callback=None,
    args=(),
    kwargs={},
    key=None,
)
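
# mic_recorder returns None until a recording has been completed; afterwards
# its output dict carries the encoded WAV under the "bytes" key.
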
audio_bytes: Union[bytes, None] = audio["bytes"] if audio else None

# scipy.io.wavfile can only parse WAV, so restrict uploads accordingly.
audio_file = st.file_uploader("Or upload an audio file", type=["wav"])

if audio_file is not None:
    audio_bytes = audio_file.getvalue()

if audio_bytes:
    try:
        transcription = transcribe(audio_bytes)
        vtt = timestamps_to_vtt(transcription["chunks"])

        with open("subtitles.vtt", "w") as file:
            file.write(vtt)

        # process_audio_bytes saved the normalized audio as sample.wav; the
        # black-screen video lets st.video render the audio with subtitles.
        wav_to_black_mp4("sample.wav", "video.mp4")

        st.video("video.mp4", subtitles="subtitles.vtt")
        st.subheader("Transcription:")
        st.markdown(
            f"""
            <div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px;">
                <p style="font-size: 16px; color: #333;">{transcription['text']}</p>
            </div>
            """,
            unsafe_allow_html=True,
        )
    except Exception as e:
        st.error(f"An error occurred during transcription: {e}")
st.markdown(
    """
    <hr>
    <footer>
        <p style="text-align: center;">© 2024 nyra health GmbH</p>
    </footer>
    """,
    unsafe_allow_html=True,
)