import streamlit as st
import io
import soundfile as sf
import numpy as np
import whisper
import torch
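# note: "whisper" here is OpenAI's openai-whisper package
# (pip install openai-whisper), not the unrelated "whisper" project on PyPI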
# pre-process
# raw bytes input case (e.g. the value of an uploaded file)
def trans_byte2arr(byte_data: bytes):
    # bytes objects have no .read(); wrap them in BytesIO directly
    arr_data, _ = sf.read(file=io.BytesIO(byte_data), dtype="float32")
    return merge_sig(arr_data)
def merge_sig(arr_data):
    if arr_data.ndim == 1:
        # already a mono signal
        return arr_data
    if arr_data.ndim == 2:
        # stereo file: average the left and right channels into one signal
        # (summing them instead can push samples outside [-1, 1])
        return arr_data.mean(axis=1)
    raise ValueError("this file is not an audio file")
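# e.g. the stereo frames [[0.2, 0.4], [0.6, 0.8]] merge to the mono signal [0.3, 0.7]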
# pre-process: crude downsample to the 16 kHz rate Whisper expects
def audio_speed_reduce(sig_data: np.ndarray, sample_rate: int):
    sig_data = merge_sig(sig_data)
    if sample_rate == 16000:
        return sig_data
    if sample_rate < 16000:
        # upsampling needs proper resampling (e.g. librosa.resample); not handled here
        raise ValueError("sample rates below 16 kHz are not supported")
    # average every reduce_size samples into one; this is only exact when the
    # source rate is an integer multiple of 16 kHz (e.g. 32000 or 48000)
    reduce_size = sample_rate // 16000
    remainder = len(sig_data) % reduce_size
    if remainder:
        # drop the trailing samples that do not fill a full block
        sig_data = sig_data[:-remainder]
    return sig_data.reshape(-1, reduce_size).mean(axis=1)
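# e.g. a 48000 Hz recording gives reduce_size = 3, so every 3 consecutive
# samples are averaged into one and the result is a 16000 Hz signal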
def convert_byte_audio(byte_data):
    # decode the audio bytes into a float32 array plus its sample rate
    arr_data, sr = sf.read(file=io.BytesIO(byte_data), dtype="float32")
    # merge channels and downsample to 16 kHz
    audio = audio_speed_reduce(arr_data, sr)
    return audio
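# sketch (assumed file name): the same pipeline also works on a local file
#   with open("sample.wav", "rb") as f:
#       audio = convert_byte_audio(f.read())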
def get_language_cls(audio_arr: np.ndarray, model: torch.nn.Module):
    # pad or trim the signal to whisper's 30-second window
    audio = whisper.pad_or_trim(audio_arr)
    # make a log-Mel spectrogram and move it to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect the spoken language
    _, probs = model.detect_language(mel)
    return probs
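# probs maps language codes to probabilities, e.g. {"en": 0.93, "ko": 0.03, ...},
# so max(probs, key=probs.get) below picks the most likely language code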
def transcribe(audio: np.ndarray, model: torch.nn.Module, task: str = "transcribe"):
    base_option = dict(beam_size=5, best_of=5)
    if task == "transcribe":
        base_option = dict(task="transcribe", **base_option)
    else:
        # whisper's only other task is "translate", which outputs English text
        base_option = dict(task="translate", **base_option)
    result = model.transcribe(audio, **base_option)
    return result["text"]
def load_model(model_name: str):
    model = whisper.load_model(model_name)
    return model
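# note: Streamlit reruns this whole script on every widget interaction, so the
# model is reloaded each time; decorating load_model with @st.cache_resource
# (available in Streamlit >= 1.18) would keep one copy in memory across reruns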
file_data = st.file_uploader("Upload your audio (.wav) file")
if file_data is not None and file_data.name.endswith(".wav"):
    # read the uploaded file as bytes
    bytes_data = file_data.getvalue()
    audio_arr = convert_byte_audio(bytes_data)
    # audio plotting
    # fig, ax = plt.subplots()
    # ax.plot(audio_arr)
    # st.pyplot(fig)
    st.audio(bytes_data)
    model_option = [
        "tiny",
        "base",
        "small",
        "medium",
        "large",
    ]
    selected_model_size = st.selectbox(
        "Which model size do you want?", ["None"] + model_option
    )
    if selected_model_size in model_option:
        model = load_model(selected_model_size)
        lang_button = st.button("What is the language?")
        if lang_button:
            with st.spinner("Detecting language..."):
                probs = get_language_cls(audio_arr=audio_arr, model=model)
            st.write(f"Detected language: {max(probs, key=probs.get)}")
        task_option = ["transcribe", "translate"]
        translate_task = st.selectbox("What is your task?", ["None"] + task_option)
        if translate_task != "None":
            with st.spinner("In progress..."):
                result = transcribe(audio=audio_arr, model=model, task=translate_task)
            st.write(result)
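# to run this app locally (assumed file name app.py):
#   pip install streamlit soundfile numpy torch openai-whisper
#   streamlit run app.py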