# ASR_Morisyen / app.py
# Author: eleferrand
# Last commit: b6d95a4 "Update app.py" (verified)
import gradio as gr
from transformers import Wav2Vec2Processor, AutoModelForCTC
from transformers import Wav2Vec2ProcessorWithLM
import torch
from pyctcdecode import build_ctcdecoder
import librosa
import logging
# Lower the "asyncio" logger to CRITICAL so its ERROR-level records are not
# printed; meant to hide noisy, usually harmless messages during shutdown.
# NOTE(review): "Exception ignored in ..." lines are normally written by
# sys.unraisablehook, not through logging — confirm this silences them.
logging.getLogger("asyncio").setLevel(logging.CRITICAL)

# Hugging Face Hub id of the fine-tuned CTC acoustic model.
MODEL_ID = "eleferrand/w2v-Morisyen"

# Run the model on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ARPA language model used by the beam-search decoder (relative path, resolved
# against the current working directory).
lm_model = "Morisyen.arpa"
path_checkpoint = MODEL_ID

model = AutoModelForCTC.from_pretrained(path_checkpoint).to(device)
processor = Wav2Vec2Processor.from_pretrained(path_checkpoint)

# Build the decoder alphabet: one label per token, ordered by token id
# (assumed contiguous from 0) so that label i lines up with logit column i.
vocab = processor.tokenizer.get_vocab()
# The tokenizer's word delimiter is "|"; pyctcdecode expects a literal space,
# so rename the entry while keeping its id. (Raises KeyError if "|" is absent.)
vocab[" "] = vocab.pop("|")
# Labels are lower-cased — presumably to match the casing of the ARPA LM;
# TODO confirm.
sorted_dict = {k.lower(): v for k, v in sorted(vocab.items(), key=lambda item: item[1])}
if len(sorted_dict) != len(vocab):
    # Two tokens collapsed onto one key after lower-casing, which would leave
    # the alphabet shorter than the vocabulary and misaligned with the logits.
    raise ValueError(
        "Lower-casing the tokenizer vocabulary produced duplicate labels; "
        "cannot build an id-aligned CTC alphabet."
    )
print(sorted_dict)

decoder = build_ctcdecoder(
    list(sorted_dict.keys()),
    kenlm_model_path=lm_model,
    alpha=0.5,  # language-model weight (pyctcdecode's `alpha`)
    beta=1.5,  # length-score weight (pyctcdecode's `beta`)
)

# Bundle the model's feature extractor and tokenizer with the LM decoder so
# that audio preprocessing and beam-search decoding go through one object.
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder,
)
# Keyword -> direction phrase, checked in order; the first keyword found as a
# substring of the transcript wins. "dwat" must stay before "dwa" because
# "dwa" is a substring of "dwat".
_DIRECTIONS = (
    # NOTE(review): this phrase has a leading space unlike the others —
    # intentional?
    ("dwat", " ale a dwat..."),
    ("gos", "ale a gos..."),
    ("dwa", "ale tout dwa"),
    ("se bon", "end"),
)

# Sampling rate (Hz) that audio is resampled to on load and declared to the
# feature extractor.
_SAMPLE_RATE = 16000


def _direction_for(transc):
    """Return the phrase for the first keyword in *transc*, else *transc* itself."""
    for keyword, direction in _DIRECTIONS:
        if keyword in transc:
            return direction
    return transc


def transcribe(audio_path, request: gr.Request):
    """Transcribe a recorded clip and, for browser users, map it to a direction.

    Args:
        audio_path: Path to the audio file from the Gradio ``Audio`` input
            (``type="filepath"``), or ``None`` when nothing was recorded.
        request: The incoming Gradio request; used only to guess whether the
            caller is a browser or an API client. ``None`` is treated as an
            API client.

    Returns:
        A 2-tuple of strings:
        - no audio: ``("No audio recorded", "")``;
        - browser caller: ``(direction, transcript)``, where ``direction`` is
          the phrase for the first matching keyword in ``_DIRECTIONS``, or the
          transcript itself when no keyword matches;
        - API caller: ``(transcript, "something")``.
    """
    if audio_path is None:
        return "No audio recorded", ""

    # Heuristic: a missing "sec-ch-ua" client-hint header means "API client".
    # NOTE(review): only Chromium-based browsers send this header, so e.g.
    # Firefox/Safari users are likely classified as API clients — confirm.
    is_api = request is None or request.headers.get("sec-ch-ua") is None

    waveform, _ = librosa.load(audio_path, sr=_SAMPLE_RATE)
    # A single feature-extractor pass; the result is a batch of one clip.
    features = processor_with_lm(
        waveform, sampling_rate=_SAMPLE_RATE, return_tensors="pt"
    )
    # Inference only: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(features.input_values.to(device)).logits
    transc = processor_with_lm.decode(logits[0].cpu().numpy()).text

    if is_api:
        # API clients get the raw transcript plus a fixed placeholder output.
        return transc, "something"
    return _direction_for(transc), transc
# Gradio interface: one audio input, handed to `transcribe` as a file path,
# and two text outputs filled in the order `transcribe` returns them. Browser
# users see (direction, transcript); API clients receive
# (transcript, "something"). The endpoint is exposed as "transcribe".
demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs=[gr.Text(), gr.Text()],
    api_name="transcribe",
)

# show_error=True displays exceptions raised inside `transcribe` in the web UI
# instead of a generic error message.
demo.launch(show_error=True)