# Source: Live_Commentary_App / pipeline.py (uploaded by PlotweaverModel, "Upload 8 files", commit bad74fd, 8.87 kB)
"""
Core pipeline: ASR (Whisper) + MT (NLLB-200) functions.
TTS is handled by tts_engine.py.
"""
import torch
import numpy as np
import re
import time
import os
import subprocess
import tempfile
import logging
import soundfile as sf
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Use the GPU in half precision when CUDA is present, otherwise CPU in float32.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.float16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32

# Model handles (loaded once at startup); all are None until load_models() runs.
asr_pipe = mt_tokenizer = mt_model = None
tts_pipe_local = None  # Local TTS for Yoruba/Hausa/Igbo/Zulu
def load_models():
    """Load the ASR, MT and local TTS models into the module-level handles.

    Fills ``asr_pipe``, ``mt_tokenizer``, ``mt_model`` and ``tts_pipe_local``
    from the model IDs below, placing each model on ``DEVICE`` with
    ``TORCH_DTYPE``. Progress messages and a short device report are printed
    to stdout.
    """
    global asr_pipe, mt_tokenizer, mt_model, tts_pipe_local

    # transformers is imported at call time rather than at module import.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from transformers import pipeline as hf_pipeline

    # Device/dtype placement shared by both pipeline constructions.
    placement = {"device": DEVICE, "torch_dtype": TORCH_DTYPE}

    print(f"Device: {DEVICE} | Dtype: {TORCH_DTYPE}")
    print("Loading models...")

    # --- Speech recognition ---
    asr_model_id = "PlotweaverAI/whisper-small-de-en"
    print(f" Loading ASR: {asr_model_id}")
    asr_pipe = hf_pipeline(
        "automatic-speech-recognition", model=asr_model_id, **placement
    )
    print(" ASR loaded")

    # --- Machine translation ---
    mt_model_id = "PlotweaverAI/nllb-200-distilled-600M-african-6lang"
    print(f" Loading MT: {mt_model_id}")
    mt_tokenizer = AutoTokenizer.from_pretrained(mt_model_id)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
        mt_model_id, torch_dtype=TORCH_DTYPE
    )
    mt_model = seq2seq.to(DEVICE)
    # The source side of every translation is English.
    mt_tokenizer.src_lang = "eng_Latn"
    print(" MT loaded")

    # --- Local text-to-speech (Yoruba) ---
    tts_model_id = "PlotweaverAI/yoruba-mms-tts-new"
    print(f" Loading local TTS: {tts_model_id}")
    tts_pipe_local = hf_pipeline("text-to-speech", model=tts_model_id, **placement)
    print(" Local TTS loaded")

    # --- Diagnostics: report where each model actually ended up ---
    print("\n=== Device diagnostics ===")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    loaded = (
        ("ASR", asr_pipe.model),
        ("MT", mt_model),
        ("TTS", tts_pipe_local.model),
    )
    for label, module in loaded:
        print(f"{label} on: {next(module.parameters()).device}")
    key_state = "set" if os.environ.get("YOURVOIC_API_KEY") else "NOT SET"
    print(f"YourVoic API key: {key_state}")
    print("==========================\n")
    print("All models loaded!")
# ---- Text Processing ----
def split_into_sentences(text, max_words=12):
    """Split raw ASR text into individual sentences.

    If the text contains sentence punctuation (``.``, ``!`` or ``?``) it is
    split after each punctuation mark that is followed by whitespace.
    Otherwise it is cut into chunks of at most ``max_words`` words, each
    terminated with a period.

    Only the first character of each ``". "``-separated piece (and of each
    fallback chunk) is upper-cased. The remaining characters are left
    untouched so that proper nouns and acronyms keep their casing.

    Args:
        text: Raw transcript string.
        max_words: Chunk size used when the text has no punctuation.

    Returns:
        A list of sentence strings; empty for empty or whitespace-only text.

    Raises:
        ValueError: If ``max_words`` is smaller than 1.
    """
    if max_words < 1:
        raise ValueError("max_words must be at least 1")

    def upper_first(fragment):
        # Unlike str.capitalize(), this does not lower-case the remainder.
        return fragment[:1].upper() + fragment[1:]

    text = text.strip()
    if not text:
        return []

    pieces = (piece.strip() for piece in text.split('. '))
    text = '. '.join(upper_first(piece) for piece in pieces if piece)

    if re.search(r'[.!?]', text):
        candidates = re.split(r'(?<=[.!?])\s+', text)
        return [s.strip() for s in candidates if s.strip()]

    # No punctuation at all: fall back to fixed-size word chunks.
    words = text.split()
    sentences = []
    for start in range(0, len(words), max_words):
        chunk = ' '.join(words[start:start + max_words])
        # This path is only reached without any '.', '!' or '?', so every
        # chunk needs a terminating period.
        sentences.append(upper_first(chunk) + '.')
    return sentences
# ---- ASR ----
def transcribe(audio_array, sample_rate=16000):
    """Transcribe English speech to text with the loaded ASR model.

    Input shorter than 1600 samples yields ``""``. Audio at any other rate is
    resampled to 16 kHz first. Clips of at most 28 seconds go through
    ``asr_pipe``; longer clips are fed to the model's ``generate`` with
    timestamps enabled on un-truncated features.

    Args:
        audio_array: Audio samples (a NumPy array when resampling is needed).
        sample_rate: Sampling rate of ``audio_array`` in Hz.

    Returns:
        The stripped transcription string.
    """
    if len(audio_array) < 1600:
        return ""

    # Duration is measured at the incoming rate, before any resampling.
    duration_s = len(audio_array) / sample_rate

    if sample_rate != 16000:
        import torchaudio.functional as F_audio

        waveform = torch.from_numpy(audio_array).float()
        audio_array = F_audio.resample(waveform, sample_rate, 16000).numpy()
        sample_rate = 16000

    # Short clips: a single pipeline call is enough.
    if duration_s <= 28:
        payload = {"raw": audio_array, "sampling_rate": sample_rate}
        result = asr_pipe(payload, return_timestamps=False)
        return result["text"].strip()

    # Long-form: native Whisper generate
    features = asr_pipe.feature_extractor(
        audio_array,
        sampling_rate=16000,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
    )
    input_features = features.input_features.to(DEVICE, dtype=TORCH_DTYPE)

    generate_kwargs = {"return_timestamps": True, "language": "en", "task": "transcribe"}
    if "attention_mask" in features:
        generate_kwargs["attention_mask"] = features.attention_mask.to(DEVICE)

    with torch.no_grad():
        predicted_ids = asr_pipe.model.generate(input_features, **generate_kwargs)

    decoded = asr_pipe.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    return decoded[0].strip()
# ---- MT ----
def translate_sentence(text, target_nllb_code, fast=True, max_length=256):
    """Translate one English sentence into the target language.

    Args:
        text: Sentence to translate.
        target_nllb_code: NLLB language code token forced as the first
            generated token.
        fast: When true, decode greedily with ``max_length`` fixed at 128;
            otherwise use 4-beam search with early stopping.
        max_length: Generation length limit; only used when ``fast`` is false.

    Returns:
        The decoded translation with special tokens removed.
    """
    if fast:
        decoding = {"max_length": 128, "num_beams": 1, "do_sample": False}
    else:
        decoding = {"max_length": max_length, "num_beams": 4, "early_stopping": True}

    encoded = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    forced_bos = mt_tokenizer.convert_tokens_to_ids(target_nllb_code)

    with torch.no_grad():
        output_ids = mt_model.generate(
            **encoded,
            forced_bos_token_id=forced_bos,
            repetition_penalty=1.5,
            no_repeat_ngram_size=3,
            **decoding,
        )

    return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
def translate_text(text, target_nllb_code, fast=True):
    """Translate a whole text one sentence at a time.

    Args:
        text: English source text.
        target_nllb_code: Target language code passed to ``translate_sentence``.
        fast: Forwarded to ``translate_sentence``.

    Returns:
        A 3-tuple ``(joined_translation, source_sentences, translations)``.
        When the text yields no sentences this is ``("", [], [])``.
    """
    sentences = split_into_sentences(text)
    if not sentences:
        return "", [], []

    translations = [
        translate_sentence(sentence, target_nllb_code, fast=fast)
        for sentence in sentences
    ]
    return ' '.join(translations), sentences, translations
# ---- Video Processing ----
def extract_audio_from_video(video_path, output_path, target_sr=16000):
    """Extract the audio track of a video as mono 16-bit PCM.

    Runs ffmpeg with the video stream dropped (``-vn``), the ``pcm_s16le``
    codec, one channel, and a sample rate of ``target_sr``. An existing
    ``output_path`` is overwritten (``-y``).

    Args:
        video_path: Input media file.
        output_path: Destination audio file.
        target_sr: Output sample rate in Hz.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vn", "-acodec", "pcm_s16le", "-ar", str(target_sr), "-ac", "1",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # ffmpeg prints its version/build banner to stderr first and the
        # actual failure reason last, so report the tail, not the head.
        detail = result.stderr.strip()[-500:]
        raise RuntimeError(f"ffmpeg extraction failed: {detail}")
    return output_path
def get_media_duration(path):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        path: Media file to inspect.

    Returns:
        The ``format=duration`` value parsed as a float.

    Raises:
        RuntimeError: If ffprobe exits with a non-zero status.
    """
    probe = subprocess.run(
        [
            "ffprobe",
            "-v", "error",
            "-show_entries", "format=duration",
            # Print the bare value only: no section wrappers, no key name.
            "-of", "default=noprint_wrappers=1:nokey=1",
            path,
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode:
        raise RuntimeError(f"ffprobe failed: {probe.stderr[:200]}")
    return float(probe.stdout.strip())
def _atempo_chain(ratio):
    """Return ``atempo`` filter strings whose combined tempo equals ``ratio``.

    ``ratio`` must be positive. Factors of 2.0 / 0.5 are peeled off until the
    remainder lies within 0.5-2.0, so every stage stays inside that range.
    """
    filters = []
    remaining = ratio
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    filters.append(f"atempo={remaining:.4f}")
    return filters


def stretch_audio_to_duration(input_path, output_path, target_duration_s):
    """Stretch/compress audio so that it lasts ``target_duration_s`` seconds.

    The tempo ratio is ``current_duration / target_duration_s`` and is applied
    through a chain of ffmpeg ``atempo`` filters.

    Args:
        input_path: Source audio file.
        output_path: Destination file (overwritten if present).
        target_duration_s: Desired duration in seconds; must be positive.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If the target or current duration is not positive, or
            if ffmpeg exits with a non-zero status.
    """
    # Validate first: a zero target would divide by zero, and a negative one
    # would make the tempo-halving loop run forever.
    if target_duration_s is None or not target_duration_s > 0:
        raise RuntimeError("Invalid target duration")

    current_duration = get_media_duration(input_path)
    if current_duration <= 0:
        raise RuntimeError("Invalid audio duration")

    ratio = current_duration / target_duration_s
    filter_spec = ",".join(_atempo_chain(ratio))

    cmd = ["ffmpeg", "-y", "-i", input_path, "-filter:a", filter_spec, output_path]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # ffmpeg prints its banner first and the real error last; show the tail.
        detail = result.stderr.strip()[-500:]
        raise RuntimeError(f"ffmpeg tempo failed: {detail}")
    return output_path
def mux_video_audio(video_path, audio_path, output_path, extend_video=False, target_duration=None):
    """Combine a video's picture with a replacement audio track.

    When ``extend_video`` is true and ``target_duration`` is truthy, the video
    is re-encoded with libx264, padded by cloning its last frame, and the
    output is cut to ``target_duration`` seconds. In every other case the
    video stream is copied unchanged and the output ends with the shorter of
    the two streams (``-shortest``). Audio is encoded as AAC in both modes.

    Args:
        video_path: File providing the video stream.
        audio_path: File providing the audio stream.
        output_path: Destination file (overwritten if present).
        extend_video: Whether to freeze the last frame to lengthen the video.
        target_duration: Output length in seconds for the extend mode.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    if extend_video and target_duration:
        cmd = [
            "ffmpeg", "-y", "-i", video_path, "-i", audio_path,
            # Clone the final frame past the end; "-t" below trims the excess.
            "-filter_complex", f"[0:v]tpad=stop_mode=clone:stop_duration={target_duration}[v]",
            "-map", "[v]", "-map", "1:a:0",
            "-c:v", "libx264", "-preset", "fast", "-c:a", "aac",
            "-t", str(target_duration), output_path,
        ]
    else:
        cmd = [
            "ffmpeg", "-y", "-i", video_path, "-i", audio_path,
            "-c:v", "copy", "-c:a", "aac",
            "-map", "0:v:0", "-map", "1:a:0", "-shortest", output_path,
        ]

    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # ffmpeg prints its banner first and the real error last; show the tail.
        detail = result.stderr.strip()[-500:]
        raise RuntimeError(f"ffmpeg mux failed: {detail}")
    return output_path