""" Core pipeline: ASR (Whisper) + MT (NLLB-200) functions. TTS is handled by tts_engine.py. """ import torch import numpy as np import re import time import os import subprocess import tempfile import logging import soundfile as sf logger = logging.getLogger(__name__) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 # Models (loaded once at startup) asr_pipe = None mt_tokenizer = None mt_model = None tts_pipe_local = None # Local TTS for Yoruba/Hausa/Igbo/Zulu def load_models(): """Load all models at startup.""" global asr_pipe, mt_tokenizer, mt_model, tts_pipe_local from transformers import ( pipeline as hf_pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, ) print(f"Device: {DEVICE} | Dtype: {TORCH_DTYPE}") print("Loading models...") # ASR ASR_MODEL_ID = "PlotweaverAI/whisper-small-de-en" print(f" Loading ASR: {ASR_MODEL_ID}") asr_pipe = hf_pipeline( "automatic-speech-recognition", model=ASR_MODEL_ID, device=DEVICE, torch_dtype=TORCH_DTYPE, ) print(" ASR loaded") # MT MT_MODEL_ID = "PlotweaverAI/nllb-200-distilled-600M-african-6lang" print(f" Loading MT: {MT_MODEL_ID}") mt_tokenizer = AutoTokenizer.from_pretrained(MT_MODEL_ID) mt_model = AutoModelForSeq2SeqLM.from_pretrained( MT_MODEL_ID, torch_dtype=TORCH_DTYPE ).to(DEVICE) mt_tokenizer.src_lang = "eng_Latn" print(" MT loaded") # Local TTS (Yoruba) TTS_MODEL_ID = "PlotweaverAI/yoruba-mms-tts-new" print(f" Loading local TTS: {TTS_MODEL_ID}") tts_pipe_local = hf_pipeline( "text-to-speech", model=TTS_MODEL_ID, device=DEVICE, torch_dtype=TORCH_DTYPE, ) print(" Local TTS loaded") # Diagnostics print(f"\n=== Device diagnostics ===") print(f"CUDA available: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"CUDA device: {torch.cuda.get_device_name(0)}") print(f"ASR on: {next(asr_pipe.model.parameters()).device}") print(f"MT on: {next(mt_model.parameters()).device}") print(f"TTS on: {next(tts_pipe_local.model.parameters()).device}") print(f"YourVoic API key: {'set' if os.environ.get('YOURVOIC_API_KEY') else 'NOT SET'}") print(f"==========================\n") print("All models loaded!") # ---- Text Processing ---- def split_into_sentences(text): """Split raw ASR text into individual sentences.""" text = text.strip() if not text: return [] text = '. '.join(s.strip().capitalize() for s in text.split('. ') if s.strip()) if re.search(r'[.!?]', text): sentences = re.split(r'(?<=[.!?])\s+', text) return [s.strip() for s in sentences if s.strip()] words = text.split() MAX_WORDS = 12 sentences = [] for i in range(0, len(words), MAX_WORDS): chunk = ' '.join(words[i:i + MAX_WORDS]) if not chunk.endswith(('.', '!', '?')): chunk += '.' chunk = chunk[0].upper() + chunk[1:] if len(chunk) > 1 else chunk.upper() sentences.append(chunk) return sentences # ---- ASR ---- def transcribe(audio_array, sample_rate=16000): """ASR: English audio to text. Handles both short and long audio.""" if len(audio_array) < 1600: return "" duration_s = len(audio_array) / sample_rate if sample_rate != 16000: import torchaudio.functional as F_audio audio_tensor = torch.from_numpy(audio_array).float() audio_tensor = F_audio.resample(audio_tensor, sample_rate, 16000) audio_array = audio_tensor.numpy() sample_rate = 16000 if duration_s <= 28: result = asr_pipe( {"raw": audio_array, "sampling_rate": sample_rate}, return_timestamps=False, ) return result["text"].strip() # Long-form: native Whisper generate model = asr_pipe.model processor = asr_pipe.feature_extractor tokenizer = asr_pipe.tokenizer inputs = processor( audio_array, sampling_rate=16000, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, ) input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE) attention_mask = inputs.attention_mask.to(DEVICE) if "attention_mask" in inputs else None generate_kwargs = {"return_timestamps": True, "language": "en", "task": "transcribe"} if attention_mask is not None: generate_kwargs["attention_mask"] = attention_mask with torch.no_grad(): predicted_ids = model.generate(input_features, **generate_kwargs) transcription = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0] return transcription.strip() # ---- MT ---- def translate_sentence(text, target_nllb_code, fast=True, max_length=256): """Translate a single sentence from English to target language.""" inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE) tgt_lang_id = mt_tokenizer.convert_tokens_to_ids(target_nllb_code) generate_kwargs = { "forced_bos_token_id": tgt_lang_id, "repetition_penalty": 1.5, "no_repeat_ngram_size": 3, } if fast: generate_kwargs.update({"max_length": 128, "num_beams": 1, "do_sample": False}) else: generate_kwargs.update({"max_length": max_length, "num_beams": 4, "early_stopping": True}) with torch.no_grad(): output_ids = mt_model.generate(**inputs, **generate_kwargs) return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True) def translate_text(text, target_nllb_code, fast=True): """Split and translate full text sentence-by-sentence.""" sentences = split_into_sentences(text) if not sentences: return "", [], [] translations = [] for s in sentences: yo = translate_sentence(s, target_nllb_code, fast=fast) translations.append(yo) return ' '.join(translations), sentences, translations # ---- Video Processing ---- def extract_audio_from_video(video_path, output_path, target_sr=16000): """Extract audio track from video as 16kHz mono WAV.""" cmd = [ "ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", str(target_sr), "-ac", "1", output_path, ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"ffmpeg extraction failed: {result.stderr[:200]}") return output_path def get_media_duration(path): """Get duration in seconds.""" cmd = [ "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", path, ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"ffprobe failed: {result.stderr[:200]}") return float(result.stdout.strip()) def stretch_audio_to_duration(input_path, output_path, target_duration_s): """Stretch/compress audio to match target duration.""" current_duration = get_media_duration(input_path) if current_duration <= 0: raise RuntimeError("Invalid audio duration") ratio = current_duration / target_duration_s filters = [] remaining = ratio while remaining > 2.0: filters.append("atempo=2.0") remaining /= 2.0 while remaining < 0.5: filters.append("atempo=0.5") remaining /= 0.5 filters.append(f"atempo={remaining:.4f}") cmd = ["ffmpeg", "-y", "-i", input_path, "-filter:a", ",".join(filters), output_path] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"ffmpeg tempo failed: {result.stderr[:200]}") return output_path def mux_video_audio(video_path, audio_path, output_path, extend_video=False, target_duration=None): """Combine video with new audio. Optionally extend video by freezing last frame.""" if extend_video and target_duration: cmd = [ "ffmpeg", "-y", "-i", video_path, "-i", audio_path, "-filter_complex", f"[0:v]tpad=stop_mode=clone:stop_duration={target_duration}[v]", "-map", "[v]", "-map", "1:a:0", "-c:v", "libx264", "-preset", "fast", "-c:a", "aac", "-t", str(target_duration), output_path, ] else: cmd = [ "ffmpeg", "-y", "-i", video_path, "-i", audio_path, "-c:v", "copy", "-c:a", "aac", "-map", "0:v:0", "-map", "1:a:0", "-shortest", output_path, ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"ffmpeg mux failed: {result.stderr[:200]}") return output_path