offiongbassey's picture
Create app.py
a37473a verified
import gradio as gr
import torch
import numpy as np
import librosa
import os
import ctranslate2
from transformers import (
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoTokenizer,
VitsModel,
)
# Hugging Face Hub ids for each stage of the speech -> text -> speech chain.
ASR_MODEL = "offiongbassey/efik_whisper_asr"  # speech-to-text (loaded below)
MT_MODEL = "offiongbassey/efik-mt"            # translation model (source of the CT2 export)
CT2_DIR = "./ct2_mt"                          # where the CTranslate2 export lives
TTS_EFIK = "offiongbassey/efik-mms-tts"       # text-to-speech used for Efik output
TTS_ENG = "facebook/mms-tts-eng"              # text-to-speech used for English output

# Run on the GPU when one is visible; half precision is used only there.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
# --- Speech recognition: audio -> source-language text ----------------------
print("Loading ASR...")
processor = AutoProcessor.from_pretrained(ASR_MODEL)
# Load the weights in the precision chosen above, move them to the target
# device, and switch to inference mode (Module.eval() returns the module).
asr_model = (
    AutoModelForSpeechSeq2Seq.from_pretrained(
        ASR_MODEL,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    .to(device)
    .eval()
)
print("ASR Loaded")
# --- Machine translation (CTranslate2 backend) ------------------------------
print("Loading MT tokenizer...")
mt_tokenizer = AutoTokenizer.from_pretrained(MT_MODEL)
print("MT tokenizer loaded")

# A finished CTranslate2 export contains model.bin. Checking for that file,
# rather than just the directory, means a conversion that died half-way is
# redone on the next start instead of being treated as complete.
if not os.path.isfile(os.path.join(CT2_DIR, "model.bin")):
    print("Converting MT model to CTranslate2 format...")
    # Use the converter's Python API instead of shelling out:
    # - a failure raises immediately (os.system silently ignored the exit status);
    # - no shell command string is built;
    # - force=True overwrites any partial output left by an earlier attempt.
    ctranslate2.converters.TransformersConverter(MT_MODEL).convert(
        CT2_DIR,
        quantization="int8",
        force=True,
    )
    print("Conversion done")

print("Loading CTranslate2 translator...")
translator = ctranslate2.Translator(CT2_DIR, device=device, compute_type="int8")
print("Translator loaded")
_tts_cache = {}
def get_tts(model_id: str):
"""Load and cache a VITS/MMS-TTS model + tokenizer."""
if model_id not in _tts_cache:
print(f"Loading TTS model: {model_id} ...")
tok = AutoTokenizer.from_pretrained(model_id)
model = VitsModel.from_pretrained(model_id).to(device)
model.eval()
_tts_cache[model_id] = (tok, model)
print(f"TTS model loaded: {model_id}")
return _tts_cache[model_id]
def fix_audio(audio, target_sr: int = 16000):
    """Convert a Gradio ``(sample_rate, samples)`` tuple into a mono float32 array.

    Processing steps:
    - average the channels of multi-channel input;
    - cast to float32;
    - resample to ``target_sr`` when the rates differ;
    - peak-normalise to [-1, 1], so integer PCM such as int16 is rescaled too.

    Args:
        audio: ``(sr, wav)`` as produced by ``gr.Audio(type="numpy")``.
            Multi-channel input is assumed to be laid out as
            (samples, channels).
        target_sr: Output sample rate. The default of 16000 matches the rate
            ``transcribe`` passes to the ASR processor.

    Returns:
        1-D float32 array at ``target_sr``. Empty input yields an empty array.
    """
    sr, wav = audio
    wav = np.asarray(wav)
    if wav.ndim > 1:
        # Assumes channels are on axis 1 (Gradio's numpy audio layout).
        wav = wav.mean(axis=1)
    wav = wav.astype(np.float32)
    if wav.size == 0:
        # Nothing was recorded; .max() below would raise on an empty array.
        return wav
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    peak = np.abs(wav).max()
    if peak > 0:  # all-zero (silent) input is returned unchanged
        wav = wav / peak
    return wav
def transcribe(audio):
    """Transcribe microphone/upload audio with the speech-to-text model.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when nothing was provided.

    Returns:
        The decoded transcription with special tokens stripped, or ``""``
        when there is no audio.
    """
    if audio is None:
        return ""
    wav = fix_audio(audio)
    if wav.size == 0:
        return ""
    inputs = processor(wav, sampling_rate=16000, return_tensors="pt")
    # Every tensor goes to the model's device. Floating-point features must
    # also match the model's dtype: on CUDA the model is loaded in float16,
    # but the processor produces float32 features. Without the cast, the
    # forward pass fails with a dtype mismatch.
    inputs = {
        k: v.to(device=device, dtype=dtype) if torch.is_floating_point(v) else v.to(device)
        for k, v in inputs.items()
    }
    with torch.no_grad():
        ids = asr_model.generate(**inputs, max_new_tokens=128, num_beams=1)
    return processor.batch_decode(ids, skip_special_tokens=True)[0]
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate *text* from ``src_lang`` to ``tgt_lang`` with the CTranslate2 model.

    ``src_lang`` / ``tgt_lang`` are NLLB language codes, e.g. ``"ibo_Latn"``,
    ``"eng_Latn"``. The MT model was trained on Efik; the same CTranslate2
    translator is reused for both directions.

    Returns:
        The decoded translation, or ``""`` for empty / whitespace-only input.
    """
    if not text or not text.strip():
        return ""
    # Let the (NLLB-style) tokenizer insert the source-language token itself.
    # Prefixing the code as plain text would keep the tokenizer's default
    # src_lang token as well, so the model would see two language tags.
    # The tokenizer is shared module state, so its previous setting is
    # restored afterwards. Note this is not safe under concurrent calls from
    # several threads.
    previous_src = mt_tokenizer.src_lang
    mt_tokenizer.src_lang = src_lang
    try:
        ids = mt_tokenizer.encode(text)
    finally:
        mt_tokenizer.src_lang = previous_src
    tokens = mt_tokenizer.convert_ids_to_tokens(ids)
    results = translator.translate_batch(
        [tokens],
        target_prefix=[[tgt_lang]],  # force decoding to start in the target language
        beam_size=4,
    )
    out = results[0].hypotheses[0]
    # The forced target-language token is echoed back; drop it before decoding.
    if out and out[0] == tgt_lang:
        out = out[1:]
    out_ids = mt_tokenizer.convert_tokens_to_ids(out)
    return mt_tokenizer.decode(out_ids, skip_special_tokens=True)
def synthesise(text: str, tts_model_id: str):
    """Synthesise *text* with the cached VITS/MMS-TTS model *tts_model_id*.

    Returns:
        ``(sample_rate, waveform_np)`` suitable for a Gradio ``Audio`` output,
        or ``None`` when there is nothing to speak. That covers empty or
        whitespace-only text, and text that tokenizes to zero tokens.
    """
    if not text or not text.strip():
        return None
    tok, model = get_tts(tts_model_id)
    inputs = tok(text, return_tensors="pt").to(device)
    # MMS-TTS tokenizers can drop characters outside their vocabulary (e.g.
    # digit-only text). An empty id sequence would make the VITS forward
    # pass fail, so treat it as "no audio" instead.
    if inputs["input_ids"].shape[-1] == 0:
        return None
    with torch.no_grad():
        output = model(**inputs)
    # VitsModel returns output.waveform with shape (batch, time); take item 0.
    wav = output.waveform[0].squeeze().cpu().float().numpy()
    return (model.config.sampling_rate, wav)
# UI choice label -> settings for that translation direction:
#   src_lang / tgt_lang   : language codes handed to translate()
#                           ("ibo_Latn" is the code this MT model uses for Efik)
#   src_label / tgt_label : display labels for the source and target text
#   tts_model             : Hub id of the TTS model that voices the translation
DIRECTIONS = {
    "Efik β†’ English": dict(
        src_lang="ibo_Latn",  # token used in the Efik-MT model
        tgt_lang="eng_Latn",
        src_label="Efik Text",
        tgt_label="English Translation",
        tts_model=TTS_ENG,
    ),
    "English β†’ Efik": dict(
        src_lang="eng_Latn",
        tgt_lang="ibo_Latn",
        src_label="English Text",
        tgt_label="Efik Translation",
        tts_model=TTS_EFIK,
    ),
}
def pipeline(audio, direction: str):
    """Run transcription -> translation -> speech synthesis for one request.

    Returns a 3-tuple matching the UI outputs: (transcription, translation,
    synthesised audio). If any stage raises, the traceback is printed to
    stderr and ``("ERROR: <message>", "", None)`` is returned instead, so
    the UI stays responsive.
    """
    import traceback

    try:
        settings = DIRECTIONS[direction]
        text_in = transcribe(audio)
        text_out = translate(text_in, settings["src_lang"], settings["tgt_lang"])
        audio_out = synthesise(text_out, settings["tts_model"])
    except Exception as exc:
        traceback.print_exc()
        return "ERROR: " + str(exc), "", None
    return text_in, text_out, audio_out
# Default direction = first DIRECTIONS entry, so the Radio's initial value
# can never drift out of sync with its choices.
_default_direction = next(iter(DIRECTIONS))


def _relabel(selected: str):
    """Retitle the two text outputs using the chosen direction's labels."""
    cfg = DIRECTIONS[selected]
    return gr.update(label=cfg["src_label"]), gr.update(label=cfg["tgt_label"])


with gr.Blocks(title="Efik Speech Translator") as demo:
    gr.Markdown("# 🎀 Efik Speech Translator")
    gr.Markdown(
        "Record or upload audio β†’ transcribe β†’ translate β†’ hear the result.\n\n"
        "Use the toggle below to switch translation direction."
    )
    direction = gr.Radio(
        choices=list(DIRECTIONS.keys()),
        value=_default_direction,
        label="Translation Direction",
        interactive=True,
    )
    mic = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio")
    btn = gr.Button("πŸš€ Translate", variant="primary")
    with gr.Column():
        # Initial labels come from the default direction's config; _relabel
        # keeps them in step when the user flips the toggle.
        out_transcribed = gr.Textbox(
            label=DIRECTIONS[_default_direction]["src_label"], interactive=False
        )
        out_translated = gr.Textbox(
            label=DIRECTIONS[_default_direction]["tgt_label"], interactive=False
        )
        out_audio = gr.Audio(label="Generated Speech", interactive=False, autoplay=True)
    direction.change(
        fn=_relabel,
        inputs=direction,
        outputs=[out_transcribed, out_translated],
    )
    btn.click(
        fn=pipeline,
        inputs=[mic, direction],
        outputs=[out_transcribed, out_translated, out_audio],
    )
demo.launch()