Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import librosa | |
| from transformers import Wav2Vec2ProcessorWithLM, AutoModelForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor | |
| import torch | |
# Hugging Face Hub repository ID of the model, plus an access token passed to
# every from_pretrained call; both are read from the environment
# (os.getenv returns None for an unset API_TOKEN).
model_name = os.getenv("MODEL_NAME")
auth_token = os.getenv("API_TOKEN")
if not model_name:
    # Fail fast with an actionable message instead of an opaque error from
    # from_pretrained(None) in the loading code below.
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to the model repository ID."
    )

# Load models.
# The tokenizer is loaded separately so its BOS/EOS tokens can be disabled
# (presumably so they never appear in the decoded text — confirm against the
# model's vocabulary).
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
    model_name, eos_token=None, bos_token=None, use_auth_token=auth_token
)
# This first processor is only used to obtain its language-model decoder.
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name, use_auth_token=auth_token)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name, use_auth_token=auth_token)
decoder = processor.decoder
# Rebuild the processor around the customized tokenizer, reusing the decoder.
processor = Wav2Vec2ProcessorWithLM(
    feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder
)
model = AutoModelForCTC.from_pretrained(model_name, use_auth_token=auth_token)
def load_data(input_file, target_sr=16_000):
    """Load an audio file as a mono waveform at ``target_sr`` Hz.

    Args:
        input_file: Path to an audio file readable by ``librosa.load``.
        target_sr: Sample rate to return, in Hz. Defaults to 16 kHz, the
            ``sampling_rate`` that ``transcribe`` passes to the processor.

    Returns:
        A 1-D NumPy array holding the waveform sampled at ``target_sr``.
    """
    # librosa.load downmixes to mono and resamples in a single step when given
    # ``sr``/``mono``. This avoids two problems:
    #   * loading at librosa's 22,050 Hz default and then resampling again;
    #   * calling librosa.resample with positional rates, which librosa >= 0.10
    #     rejects because orig_sr/target_sr are keyword-only.
    # With mono=True the result is always 1-D, so no manual channel mixing is
    # needed. (librosa's multichannel layout is (channels, samples), so
    # indexing speech[:, 0] would have selected the wrong axis.)
    speech, _ = librosa.load(input_file, sr=target_sr, mono=True)
    return speech
def transcribe(input_file):
    """Transcribe an audio file to text using the CTC model and LM beam search.

    Args:
        input_file: Filesystem path to the audio, as supplied by the
            ``type="filepath"`` audio input. It may be ``None``/empty when the
            user submits without a file (Gradio behavior; confirm for the
            installed version).

    Returns:
        The decoded transcription, or an empty string when no file was given.
    """
    if not input_file:
        # Nothing to transcribe; avoid passing None into librosa.load.
        return ""

    audio = load_data(input_file)

    # Convert the waveform into model input values. load_data produces
    # 16 kHz audio, which is the rate declared here.
    input_values = processor(audio, return_tensors="pt", sampling_rate=16_000).input_values

    with torch.no_grad():
        # Keep the logits of the single batch item as a NumPy array for the decoder.
        logits = model(input_values).logits.cpu().numpy()[0]

    # Beam-search decode with the language-model-backed decoder.
    return decoder.decode(logits, beam_width=30)
# Sample clips offered in the UI; paths are relative to the app's working directory.
examples = [
    ["examples/example1.mp3"],
    ["examples/example2.mp3"],
]

# gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4, so the old
# gr.inputs.Audio call crashes at startup on current Gradio. gr.Audio is the
# supported component. Leaving out source/sources keeps the call valid on both
# major versions, because that parameter was renamed between them. Gradio 3
# defaults to upload; Gradio 4 defaults to upload plus microphone.
demo = gr.Interface(
    title="Rozpoznání mluvené řeči pro český jazyk",  # "Speech recognition for Czech"
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    examples=examples,
)
demo.launch()