Spaces:
Sleeping
Sleeping
| """ | |
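Core pipeline: model loading, ASR (Whisper) and MT (NLLB-200) helpers, and
ffmpeg-based audio/video utilities.
Speech synthesis is handled by tts_engine.py; this module only loads the local
TTS pipeline (tts_pipe_local).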
| """ | |
| import torch | |
| import numpy as np | |
| import re | |
| import time | |
| import os | |
| import subprocess | |
| import tempfile | |
| import logging | |
| import soundfile as sf | |
logger = logging.getLogger(__name__)

# Run on the GPU in half precision when CUDA is visible, otherwise CPU in fp32.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.float16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32

# Model handles. They stay None until load_models() populates them once at
# startup; the ASR/MT helpers in this module read them as globals.
asr_pipe = None  # Whisper automatic-speech-recognition pipeline
mt_tokenizer = None  # NLLB tokenizer (load_models sets src_lang to English)
mt_model = None  # NLLB seq2seq translation model
tts_pipe_local = None  # local text-to-speech pipeline (Yoruba MMS-TTS checkpoint in load_models)
def load_models():
    """Load every model this module uses into its module-level globals.

    Populates:
        asr_pipe: Hugging Face automatic-speech-recognition pipeline (Whisper).
        mt_tokenizer / mt_model: NLLB-200 tokenizer and seq2seq model; the
            tokenizer's source language is fixed to English ("eng_Latn").
        tts_pipe_local: Hugging Face text-to-speech pipeline.

    All models are placed on ``DEVICE`` with ``TORCH_DTYPE``. Progress and a
    device-diagnostics summary are printed to stdout. ``transformers`` is
    imported inside the function, so it is only required once models load.
    """
    global asr_pipe, mt_tokenizer, mt_model, tts_pipe_local
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        pipeline as hf_pipeline,
    )

    def build_pipeline(task, model_id):
        # Both HF pipelines share the same device / precision settings.
        return hf_pipeline(
            task,
            model=model_id,
            device=DEVICE,
            torch_dtype=TORCH_DTYPE,
        )

    def weights_device(module):
        # Device of the first parameter, i.e. where the weights ended up.
        return next(module.parameters()).device

    print(f"Device: {DEVICE} | Dtype: {TORCH_DTYPE}")
    print("Loading models...")

    # Speech recognition.
    asr_id = "PlotweaverAI/whisper-small-de-en"
    print(f" Loading ASR: {asr_id}")
    asr_pipe = build_pipeline("automatic-speech-recognition", asr_id)
    print(" ASR loaded")

    # Machine translation; every input sentence is English.
    mt_id = "PlotweaverAI/nllb-200-distilled-600M-african-6lang"
    print(f" Loading MT: {mt_id}")
    mt_tokenizer = AutoTokenizer.from_pretrained(mt_id)
    mt_model = AutoModelForSeq2SeqLM.from_pretrained(
        mt_id, torch_dtype=TORCH_DTYPE
    ).to(DEVICE)
    mt_tokenizer.src_lang = "eng_Latn"
    print(" MT loaded")

    # Local text-to-speech (Yoruba checkpoint).
    tts_id = "PlotweaverAI/yoruba-mms-tts-new"
    print(f" Loading local TTS: {tts_id}")
    tts_pipe_local = build_pipeline("text-to-speech", tts_id)
    print(" Local TTS loaded")

    # Where did everything land?
    print("\n=== Device diagnostics ===")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    placements = (
        ("ASR", asr_pipe.model),
        ("MT", mt_model),
        ("TTS", tts_pipe_local.model),
    )
    for label, module in placements:
        print(f"{label} on: {weights_device(module)}")
    key_state = "set" if os.environ.get("YOURVOIC_API_KEY") else "NOT SET"
    print(f"YourVoic API key: {key_state}")
    print("==========================\n")
    print("All models loaded!")
| # ---- Text Processing ---- | |
def split_into_sentences(text, max_words=12):
    """Split raw ASR text into individual, capitalised sentences.

    If the text contains any of ``.``, ``!`` or ``?``, it is split after each
    such mark that is followed by whitespace. Otherwise (unpunctuated ASR
    output), it is cut into chunks of at most ``max_words`` words, each
    terminated with a full stop.

    Only the first character of each sentence is upper-cased. The rest is
    left untouched so proper nouns, "I" and acronyms keep their case.
    ``str.capitalize`` would lower-case them, which degrades translation.

    Args:
        text: Raw transcription text.
        max_words: Chunk size used for unpunctuated text (default 12).

    Returns:
        List of non-empty sentence strings; ``[]`` for blank input.

    Raises:
        ValueError: if ``max_words`` is less than 1.
    """
    if max_words < 1:
        raise ValueError(f"max_words must be >= 1, got {max_words!r}")

    text = text.strip()
    if not text:
        return []

    # Normalise ". " boundaries: trim each piece and drop empty ones.
    text = '. '.join(part.strip() for part in text.split('. ') if part.strip())

    if re.search(r'[.!?]', text):
        pieces = re.split(r'(?<=[.!?])\s+', text)
        sentences = [p.strip() for p in pieces if p.strip()]
    else:
        # No punctuation at all, so no chunk can already end in one; every
        # chunk gets a closing full stop.
        words = text.split()
        sentences = [
            ' '.join(words[i:i + max_words]) + '.'
            for i in range(0, len(words), max_words)
        ]

    # Upper-case just the first character; never lower-case the remainder.
    return [s[:1].upper() + s[1:] for s in sentences]
| # ---- ASR ---- | |
def transcribe(audio_array, sample_rate=16000):
    """Transcribe English speech audio to text with the loaded Whisper model.

    Clips of at most 28 s are decoded by one call to the ASR pipeline. Longer
    clips go through Whisper's native ``generate`` with timestamps enabled
    (long-form decoding), forcing English transcription.

    Args:
        audio_array: 1-D numpy array of samples (assumed mono float audio —
            TODO confirm with callers).
        sample_rate: Sampling rate of ``audio_array`` in Hz. Any other rate
            is resampled to 16 kHz before decoding.

    Returns:
        The stripped transcription, or ``""`` for clips shorter than 0.1 s.
    """
    # Treat < 0.1 s as nothing to transcribe. The cut-off is measured in
    # seconds (not samples) so it is the same at every input sample rate.
    duration_s = len(audio_array) / sample_rate
    if duration_s < 0.1:
        return ""

    if sample_rate != 16000:
        import torchaudio.functional as F_audio
        resampled = F_audio.resample(
            torch.from_numpy(audio_array).float(), sample_rate, 16000
        )
        audio_array = resampled.numpy()
        sample_rate = 16000

    if duration_s <= 28:
        result = asr_pipe(
            {"raw": audio_array, "sampling_rate": sample_rate},
            return_timestamps=False,
        )
        return result["text"].strip()

    # Long-form: native Whisper generate over the full, untruncated features.
    model = asr_pipe.model
    feature_extractor = asr_pipe.feature_extractor
    tokenizer = asr_pipe.tokenizer
    inputs = feature_extractor(
        audio_array, sampling_rate=16000, return_tensors="pt",
        truncation=False, padding="longest", return_attention_mask=True,
    )
    # Whisper decodes 30 s windows of 3000 mel frames. Clips between 28 s and
    # 30 s yield fewer frames than that, and generate() rejects short-form
    # input that is not exactly 3000 frames long. Re-extract them padded to
    # one full window (the recipe given in the HF Whisper docs).
    if inputs.input_features.shape[-1] < 3000:
        inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt")

    input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)
    # NOTE(review): the pipeline path above does not force language/task;
    # this path forces English transcription — confirm that is intended.
    generate_kwargs = {"return_timestamps": True, "language": "en", "task": "transcribe"}
    if "attention_mask" in inputs:
        generate_kwargs["attention_mask"] = inputs.attention_mask.to(DEVICE)

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **generate_kwargs)
    return tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
| # ---- MT ---- | |
def translate_sentence(text, target_nllb_code, fast=True, max_length=256):
    """Translate one English sentence into the target language with NLLB.

    Args:
        text: English source sentence.
        target_nllb_code: NLLB language-code token for the output language
            (same format as the ``"eng_Latn"`` source code).
        fast: If True, use greedy decoding capped at 128 tokens; otherwise
            use 4-beam search.
        max_length: Maximum generated length. In fast mode the effective
            cap is ``min(128, max_length)``.

    Returns:
        The decoded translation, or ``""`` for blank input.

    Raises:
        ValueError: if ``target_nllb_code`` is not a token in the MT vocabulary.
    """
    # Nothing to translate: avoid running generate on an empty prompt.
    if not text.strip():
        return ""

    tgt_lang_id = mt_tokenizer.convert_tokens_to_ids(target_nllb_code)
    # An unknown code maps to the unknown-token id instead of raising, which
    # would silently force a meaningless first token; fail loudly instead.
    if tgt_lang_id is None or tgt_lang_id == mt_tokenizer.unk_token_id:
        raise ValueError(f"Unknown NLLB target language code: {target_nllb_code!r}")

    inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)

    generate_kwargs = {
        # NLLB selects the output language via the forced first token.
        "forced_bos_token_id": tgt_lang_id,
        "repetition_penalty": 1.5,
        "no_repeat_ngram_size": 3,
    }
    if fast:
        # Greedy decoding. The 128-token cap still applies, but a smaller
        # caller-supplied max_length is honoured instead of being ignored.
        generate_kwargs.update(
            {"max_length": min(128, max_length), "num_beams": 1, "do_sample": False}
        )
    else:
        generate_kwargs.update(
            {"max_length": max_length, "num_beams": 4, "early_stopping": True}
        )

    with torch.no_grad():
        output_ids = mt_model.generate(**inputs, **generate_kwargs)
    return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
def translate_text(text, target_nllb_code, fast=True):
    """Translate a block of English text one sentence at a time.

    The text is segmented with ``split_into_sentences`` and each sentence is
    passed to ``translate_sentence`` in order.

    Returns:
        A 3-tuple ``(joined, sentences, translations)``:
        - joined: the translations joined with single spaces;
        - sentences: the source sentences;
        - translations: the per-sentence translations, in the same order.
        Blank input yields ``("", [], [])``.
    """
    source_sentences = split_into_sentences(text)
    if not source_sentences:
        return "", [], []

    translated = [
        translate_sentence(sentence, target_nllb_code, fast=fast)
        for sentence in source_sentences
    ]
    return " ".join(translated), source_sentences, translated
| # ---- Video Processing ---- | |
def extract_audio_from_video(video_path, output_path, target_sr=16000):
    """Extract a video's audio track as mono 16-bit PCM audio (e.g. a WAV file).

    Args:
        video_path: Input media file.
        output_path: Destination file; overwritten if it exists (``-y``).
        target_sr: Output sample rate in Hz (default 16000).

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status. The message
            carries the tail of ffmpeg's error output.
    """
    cmd = [
        # Without -hide_banner / -loglevel error, stderr starts with ffmpeg's
        # version banner, so the reported snippet never contained the error.
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-i", video_path,
        "-vn",  # drop the video stream
        "-acodec", "pcm_s16le",
        "-ar", str(target_sr),
        "-ac", "1",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # The relevant message is at the end of stderr; keep the tail.
        raise RuntimeError(f"ffmpeg extraction failed: {result.stderr.strip()[-500:]}")
    return output_path
def get_media_duration(path):
    """Return the container-level duration of a media file, in seconds.

    Runs ``ffprobe`` asking only for ``format=duration`` and prints the bare
    value (no section wrappers, no key).

    Raises:
        RuntimeError: if ffprobe exits with a non-zero status.
        ValueError: if ffprobe's output is not a number (raised by ``float``).
    """
    probe_cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    probe = subprocess.run(probe_cmd, capture_output=True, text=True)
    if probe.returncode:
        raise RuntimeError(f"ffprobe failed: {probe.stderr[:200]}")
    raw_seconds = probe.stdout.strip()
    return float(raw_seconds)
def stretch_audio_to_duration(input_path, output_path, target_duration_s):
    """Time-stretch an audio file so it lasts ``target_duration_s`` seconds.

    Uses ffmpeg's ``atempo`` filter with tempo factor
    ``current_duration / target_duration_s``. The factor is split into a
    chain of stages so each stage stays within [0.5, 2.0].

    Args:
        input_path: Audio file to stretch; its duration is read via ffprobe.
        output_path: Destination file; overwritten if it exists (``-y``).
        target_duration_s: Desired duration in seconds; must be positive
            and finite.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``target_duration_s`` is not a positive finite number.
        RuntimeError: if the input duration is not positive or ffmpeg fails.
    """
    # Validate before probing. Zero would divide by zero. A negative, infinite
    # or NaN target makes the ratio negative, zero or NaN; the first two never
    # climb back to 0.5 in the halving loop below, so it would spin forever.
    if not 0 < target_duration_s < float("inf"):
        raise ValueError(
            f"target_duration_s must be a positive finite number, got {target_duration_s!r}"
        )

    current_duration = get_media_duration(input_path)
    if current_duration <= 0:
        raise RuntimeError("Invalid audio duration")

    # A tempo factor > 1 shortens the audio and < 1 lengthens it. Chain
    # stages so each atempo value stays within [0.5, 2.0], the range older
    # ffmpeg builds accept.
    remaining = current_duration / target_duration_s
    filters = []
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    filters.append(f"atempo={remaining:.4f}")

    cmd = [
        # Suppress the banner so stderr holds only the actual error.
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-i", input_path,
        "-filter:a", ",".join(filters),
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg tempo failed: {result.stderr.strip()[-500:]}")
    return output_path
def mux_video_audio(video_path, audio_path, output_path, extend_video=False, target_duration=None):
    """Combine the video stream of ``video_path`` with the audio of ``audio_path``.

    Two modes:
    - Default: the video stream is copied without re-encoding, the audio is
      encoded to AAC, and the output stops at the shorter stream
      (``-shortest``).
    - ``extend_video=True`` with a truthy ``target_duration``: the last video
      frame is cloned (``tpad``) to extend the picture. The video is
      re-encoded with libx264 and the output is cut to ``target_duration``
      seconds.

    Returns:
        ``output_path`` (overwritten if it exists, ``-y``).

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status. The message
            carries the tail of ffmpeg's error output.
    """
    # -hide_banner / -loglevel error keep the version banner out of stderr so
    # the error message below contains the actual failure.
    base = [
        "ffmpeg", "-hide_banner", "-loglevel", "error", "-y",
        "-i", video_path, "-i", audio_path,
    ]
    if extend_video and target_duration:
        # tpad appends target_duration seconds of the cloned last frame (more
        # than needed); -t then trims the output to exactly target_duration.
        cmd = base + [
            "-filter_complex", f"[0:v]tpad=stop_mode=clone:stop_duration={target_duration}[v]",
            "-map", "[v]", "-map", "1:a:0",
            "-c:v", "libx264", "-preset", "fast", "-c:a", "aac",
            "-t", str(target_duration), output_path,
        ]
    else:
        cmd = base + [
            "-c:v", "copy", "-c:a", "aac",
            "-map", "0:v:0", "-map", "1:a:0", "-shortest", output_path,
        ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg mux failed: {result.stderr.strip()[-500:]}")
    return output_path