offiongbassey's picture
Added c2translate
afa11f8 verified
raw
history blame
11.5 kB
"""
Live Football Commentary Pipeline — English → Yoruba
=====================================================
Gradio app for HuggingFace Spaces.
Pipeline: ASR (Whisper) → MT (NLLB-200 via CTranslate2) → TTS (MMS-TTS Yoruba)
"""
import torch
import numpy as np
import re
import time
import gradio as gr
import ctranslate2
from transformers import AutoTokenizer
from transformers import pipeline as hf_pipeline
# =============================================================================
# Configuration
# =============================================================================
# Model identifiers for the three pipeline stages.
ASR_MODEL_ID = "PlotweaverAI/whisper-small-de-en"
MT_MODEL_ID = "PlotweaverAI/nllb-200-distilled-600M-african-6lang"
TTS_MODEL_ID = "PlotweaverAI/yoruba-mms-tts-new"

# Local directory where the CTranslate2-converted MT model is saved.
CT2_MODEL_DIR = "./nllb_ct2"

# Source / target language codes used by the translation step.
MT_SRC_LANG = "eng_Latn"
MT_TGT_LANG = "yor_Latn"

# Probe CUDA once and derive every device-dependent setting from the answer.
_CUDA_AVAILABLE = torch.cuda.is_available()
if _CUDA_AVAILABLE:
    DEVICE = CT2_DEVICE = "cuda"
    TORCH_DTYPE = torch.float16
    CT2_COMPUTE_TYPE = "int8_float16"
else:
    DEVICE = CT2_DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
    CT2_COMPUTE_TYPE = "int8"
# =============================================================================
# Convert MT model to CTranslate2 format (runs once at startup if needed)
# =============================================================================
import os
import subprocess

# A converted CTranslate2 model directory contains ``model.bin``. Checking for
# that file (instead of only the directory) means an interrupted conversion is
# retried on the next start rather than leaving a half-written directory that
# fails when the translator is loaded.
if not os.path.isfile(os.path.join(CT2_MODEL_DIR, "model.bin")):
    print(f"Converting {MT_MODEL_ID} to CTranslate2 format...")
    subprocess.run(
        [
            "ct2-transformers-converter",
            "--model", MT_MODEL_ID,
            "--output_dir", CT2_MODEL_DIR,
            # int8 = fastest on CPU; use int8_float16 on GPU
            "--quantization", "int8",
            # Overwrite any leftover (possibly incomplete) output directory.
            "--force",
        ],
        check=True,  # abort startup if the converter exits with an error
    )
    print("Conversion done ✓")
# =============================================================================
# Load models (runs once at startup)
# =============================================================================
print(f"Device: {DEVICE} | CT2 Compute: {CT2_COMPUTE_TYPE}")
print("Loading models...")

# ASR
print(f" Loading ASR: {ASR_MODEL_ID}")
asr_pipe = hf_pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_ID,
    device=DEVICE,
    torch_dtype=TORCH_DTYPE,
)
print(" ASR loaded ✓")

# MT — CTranslate2 Translator (replaces AutoModelForSeq2SeqLM).
# The tokenizer still comes from the original Hugging Face checkpoint; only
# the translation model itself is loaded from the converted directory.
print(f" Loading MT (CTranslate2): {CT2_MODEL_DIR}")
mt_tokenizer = AutoTokenizer.from_pretrained(MT_MODEL_ID)
mt_translator = ctranslate2.Translator(
    CT2_MODEL_DIR,
    device=CT2_DEVICE,
    compute_type=CT2_COMPUTE_TYPE,
    inter_threads=2,  # number of translations that may run in parallel
)
print(" MT (CTranslate2) loaded ✓")

# TTS
print(f" Loading TTS: {TTS_MODEL_ID}")
tts_pipe = hf_pipeline(
    "text-to-speech",
    model=TTS_MODEL_ID,
    device=DEVICE,
    torch_dtype=TORCH_DTYPE,
)
print(" TTS loaded ✓")
print("All models loaded!")
# =============================================================================
# Pipeline functions
# =============================================================================
def _capitalize_first(sentence):
    """Upper-case only the first character, leaving the rest untouched."""
    return sentence[:1].upper() + sentence[1:]


def split_into_sentences(text, max_words=12):
    """Split raw ASR text into individual sentences for MT.

    If the text contains sentence punctuation (``.``, ``!`` or ``?``) it is
    split after each such mark that is followed by whitespace. Otherwise the
    text is cut into chunks of at most ``max_words`` words and a full stop is
    appended to every chunk.

    Only the first character of each sentence is upper-cased; the remaining
    characters keep their original case so that proper nouns (player and team
    names) survive. ``str.capitalize`` is deliberately avoided because it
    lower-cases everything after the first character.

    Args:
        text: Raw text, possibly unpunctuated.
        max_words: Maximum number of words per chunk when the text has no
            sentence punctuation.

    Returns:
        A list of non-empty sentence strings; empty for blank input.

    Raises:
        ValueError: If ``max_words`` is smaller than 1.
    """
    if max_words < 1:
        raise ValueError("max_words must be a positive integer")

    text = text.strip()
    if not text:
        return []

    if re.search(r'[.!?]', text):
        pieces = re.split(r'(?<=[.!?])\s+', text)
        return [_capitalize_first(p.strip()) for p in pieces if p.strip()]

    # No punctuation at all: fall back to fixed-size word chunks.
    words = text.split()
    return [
        _capitalize_first(' '.join(words[start:start + max_words]) + '.')
        for start in range(0, len(words), max_words)
    ]
def transcribe(audio_array, sample_rate=16000):
    """ASR: English audio → English text.

    Passes the waveform together with its sampling rate to the module-level
    ASR pipeline and returns the recognised text with surrounding whitespace
    removed.
    """
    asr_input = {"raw": audio_array, "sampling_rate": sample_rate}
    asr_options = {
        "chunk_length_s": 15,
        "batch_size": 1,
        "return_timestamps": False,
    }
    transcription = asr_pipe(asr_input, **asr_options)
    recognised = transcription["text"]
    return recognised.strip()
def translate_batch_ct2(sentences):
    """MT: translate a batch of sentences from English → Yoruba with CTranslate2.

    All sentences are tokenized up front and handed to the translator in a
    single ``translate_batch`` call.

    Args:
        sentences: Iterable of source-language sentences.

    Returns:
        A list with one translated string per input sentence, in input order.
        An empty input yields an empty list without touching the models.
    """
    # Materialize once so generators work and the length is known.
    sentences = list(sentences)
    if not sentences:
        return []

    mt_tokenizer.src_lang = MT_SRC_LANG

    # CTranslate2 works with token strings, not IDs.
    source_tokens = [
        mt_tokenizer.convert_ids_to_tokens(
            mt_tokenizer.encode(sentence, add_special_tokens=True)
        )
        for sentence in sentences
    ]
    # One independent prefix list per sentence (no shared inner list), forcing
    # decoding to start with the target-language token.
    target_prefix = [[MT_TGT_LANG] for _ in sentences]

    results = mt_translator.translate_batch(
        source_tokens,
        target_prefix=target_prefix,
        beam_size=4,
        repetition_penalty=1.5,
        no_repeat_ngram_size=3,
        max_decoding_length=256,
    )

    translations = []
    for result in results:
        tokens = result.hypotheses[0]
        # Remove the language token prefix if present.
        if tokens and tokens[0] == MT_TGT_LANG:
            tokens = tokens[1:]
        translations.append(
            mt_tokenizer.decode(
                mt_tokenizer.convert_tokens_to_ids(tokens),
                skip_special_tokens=True,
            )
        )
    return translations
def translate_long_text(text):
    """Break ``text`` into sentences and translate them in one batch.

    Returns a 3-tuple: the translations joined with single spaces, the list
    of source sentences, and the list of translated sentences. When no
    sentence can be extracted the result is ``("", [], [])``.
    """
    source_sentences = split_into_sentences(text)
    if source_sentences:
        translated = translate_batch_ct2(source_sentences)
        joined = ' '.join(translated)
        return joined, source_sentences, translated
    return "", [], []
def synthesize(text):
    """TTS: Yoruba text → audio.

    Runs the module-level TTS pipeline and returns ``(waveform, sampling_rate)``
    where the waveform has all length-1 axes squeezed out.
    """
    tts_output = tts_pipe(text)
    waveform = np.squeeze(np.array(tts_output["audio"]))
    return waveform, tts_output["sampling_rate"]
# =============================================================================
# Gradio interface functions
# =============================================================================
def process_audio(audio_input):
    """Run the full audio → audio pipeline for the Gradio audio tab.

    Args:
        audio_input: ``None`` or a ``(sample_rate, samples)`` tuple as
            delivered by the ``gr.Audio(type="numpy")`` input component.

    Returns:
        A pair ``(audio, log)``. ``audio`` is ``(sample_rate, waveform)`` for
        the synthesized Yoruba speech, or ``None`` when a stage produced
        nothing. ``log`` is a Markdown string with per-stage timings and
        texts, or a warning message.
    """
    no_audio_msg = "⚠️ No audio provided. Please upload or record audio."
    if audio_input is None:
        return None, no_audio_msg

    sample_rate, audio_array = audio_input
    audio_array = audio_array.astype(np.float32)
    # An empty recording would make the peak computation below raise.
    if audio_array.size == 0:
        return None, no_audio_msg

    # Down-mix multi-channel audio to mono.
    # NOTE(review): assumes a (samples, channels) layout — confirm against Gradio.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Bring samples outside [-1, 1] (e.g. integer PCM) into range by peak.
    peak = np.abs(audio_array).max()
    if peak > 1.0:
        audio_array = audio_array / peak

    total_start = time.perf_counter()
    log_lines = []

    # Stage 1: speech recognition.
    t0 = time.perf_counter()
    english_text = transcribe(audio_array, sample_rate)
    log_lines.append(f"**🎤 ASR** ({time.perf_counter()-t0:.2f}s)")
    log_lines.append(f"English: {english_text}\n")
    if not english_text:
        return None, "⚠️ ASR returned empty text."

    # Stage 2: sentence-wise translation.
    t0 = time.perf_counter()
    yoruba_text, en_sentences, yo_sentences = translate_long_text(english_text)
    log_lines.append(f"**🔄 Translation (CTranslate2)** ({time.perf_counter()-t0:.2f}s)")
    for en_s, yo_s in zip(en_sentences, yo_sentences):
        log_lines.append(f" EN: {en_s}")
        log_lines.append(f" YO: {yo_s}")
    log_lines.append("")
    if not yoruba_text:
        return None, "⚠️ Translation returned empty text."

    # Stage 3: speech synthesis.
    t0 = time.perf_counter()
    yoruba_audio, output_sr = synthesize(yoruba_text)
    log_lines.append(
        f"**🔊 TTS** ({time.perf_counter()-t0:.2f}s) → "
        f"{len(yoruba_audio)/output_sr:.2f}s of audio"
    )
    log_lines.append(f"\n**Total: {time.perf_counter()-total_start:.2f}s**")
    return (output_sr, yoruba_audio), "\n".join(log_lines)
def process_text(english_text):
    """Run the text → audio pipeline (translation + TTS) for the text tab.

    Args:
        english_text: English text entered by the user; may be ``None`` or
            blank.

    Returns:
        A pair ``(audio, log)``. ``audio`` is ``(sample_rate, waveform)`` for
        the synthesized Yoruba speech, or ``None`` when there was no input or
        the translation came back empty. ``log`` is a Markdown string with
        per-stage timings and texts, or a warning message.
    """
    if not english_text or not english_text.strip():
        return None, "⚠️ Please enter some English text."

    total_start = time.perf_counter()
    log_lines = []

    # Stage 1: sentence-wise translation.
    t0 = time.perf_counter()
    yoruba_text, en_sentences, yo_sentences = translate_long_text(english_text.strip())
    log_lines.append(f"**🔄 Translation (CTranslate2)** ({time.perf_counter()-t0:.2f}s)")
    for en_s, yo_s in zip(en_sentences, yo_sentences):
        log_lines.append(f" EN: {en_s}")
        log_lines.append(f" YO: {yo_s}")
    log_lines.append("")
    if not yoruba_text:
        return None, "⚠️ Translation returned empty text."

    # Stage 2: speech synthesis.
    t0 = time.perf_counter()
    yoruba_audio, output_sr = synthesize(yoruba_text)
    log_lines.append(
        f"**🔊 TTS** ({time.perf_counter()-t0:.2f}s) → "
        f"{len(yoruba_audio)/output_sr:.2f}s of audio"
    )
    log_lines.append(f"\n**Total: {time.perf_counter()-total_start:.2f}s**")
    return (output_sr, yoruba_audio), "\n".join(log_lines)
# =============================================================================
# Gradio UI
# =============================================================================
# Markdown shown at the top of the app.
DESCRIPTION = """
# 🏟️ Live Football Commentary — English → Yoruba
Translate English football commentary into Yoruba speech in real-time.
**Pipeline:** ASR (Whisper) → MT (NLLB-200 via CTranslate2) → TTS (MMS-TTS Yoruba)
"""

# Sample inputs offered in the text tab.
EXAMPLES_TEXT = [
    "And it's a brilliant goal from the striker!",
    "The referee has shown a yellow card. Corner kick for the home team.",
    "What a save by the goalkeeper! The match is heading into injury time.",
    "He dribbles past two defenders and shoots! The ball hits the back of the net!",
]
# Two-tab layout: the full audio pipeline, and a text-only shortcut that skips ASR.
with gr.Blocks(title="Football Commentary EN→YO", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.TabItem("🎙️ Audio → Audio (Full Pipeline)"):
            gr.Markdown("Upload or record English commentary. The pipeline will transcribe, translate, and synthesize Yoruba audio.")
            with gr.Row():
                with gr.Column():
                    audio_input = gr.Audio(label="English Commentary Audio", type="numpy", sources=["upload", "microphone"])
                    audio_submit_btn = gr.Button("Translate to Yoruba", variant="primary", size="lg")
                with gr.Column():
                    audio_output = gr.Audio(label="Yoruba Commentary Audio", type="numpy")
                    audio_log = gr.Markdown(label="Pipeline Log")
            audio_submit_btn.click(fn=process_audio, inputs=[audio_input], outputs=[audio_output, audio_log])
        with gr.TabItem("📝 Text → Audio (Translation + TTS)"):
            gr.Markdown("Type or paste English text to translate to Yoruba and hear the result.")
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="English Text", placeholder="Type English football commentary here...", lines=4)
                    text_submit_btn = gr.Button("Translate to Yoruba", variant="primary", size="lg")
                    gr.Examples(examples=[[e] for e in EXAMPLES_TEXT], inputs=[text_input], label="Example Commentary")
                with gr.Column():
                    text_audio_output = gr.Audio(label="Yoruba Audio", type="numpy")
                    text_log = gr.Markdown(label="Pipeline Log")
            text_submit_btn.click(fn=process_text, inputs=[text_input], outputs=[text_audio_output, text_log])
    # Footer with links to the three model repositories.
    gr.Markdown("""
---
**Models used:**
[ASR: PlotweaverAI/whisper-small-de-en](https://huggingface.co/PlotweaverAI/whisper-small-de-en) |
[MT: PlotweaverAI/nllb-200-distilled-600M-african-6lang](https://huggingface.co/PlotweaverAI/nllb-200-distilled-600M-african-6lang) |
[TTS: PlotweaverAI/yoruba-mms-tts-new](https://huggingface.co/PlotweaverAI/yoruba-mms-tts-new)
""")

if __name__ == "__main__":
    demo.launch()