PlotweaverModel committed on
Commit
f7b3ceb
·
verified ·
1 Parent(s): 4ebac7c

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -260
pipeline.py DELETED
@@ -1,260 +0,0 @@
1
- """
2
- Core pipeline: ASR (Whisper) + MT (NLLB-200) functions.
3
- TTS is handled by tts_engine.py.
4
- """
5
-
6
- import torch
7
- import numpy as np
8
- import re
9
- import time
10
- import os
11
- import subprocess
12
- import tempfile
13
- import logging
14
- import soundfile as sf
15
-
16
logger = logging.getLogger(__name__)

# Use the GPU in half precision when CUDA is present; otherwise fall back to
# CPU with full precision.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.float16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32

# Model handles. All start as None and are assigned by load_models().
asr_pipe = None        # speech-recognition pipeline
mt_tokenizer = None    # tokenizer for the translation model
mt_model = None        # seq2seq translation model
tts_pipe_local = None  # local text-to-speech pipeline
26
-
27
-
28
def load_models():
    """Load the ASR, MT and local TTS models into the module-level handles.

    Assigns ``asr_pipe``, ``mt_tokenizer``, ``mt_model`` and
    ``tts_pipe_local``, placing every model on ``DEVICE`` with
    ``TORCH_DTYPE``. The MT tokenizer's source language is set to English
    (``eng_Latn``). Progress and a device-placement summary are printed to
    stdout. Returns nothing.
    """
    global asr_pipe, mt_tokenizer, mt_model, tts_pipe_local
    # Imported here rather than at module level, as in the original code.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from transformers import pipeline as hf_pipeline

    asr_checkpoint = "PlotweaverAI/whisper-small-de-en"
    mt_checkpoint = "PlotweaverAI/nllb-200-distilled-600M-african-6lang"
    tts_checkpoint = "PlotweaverAI/yoruba-mms-tts-new"
    placement = {"device": DEVICE, "torch_dtype": TORCH_DTYPE}

    print(f"Device: {DEVICE} | Dtype: {TORCH_DTYPE}")
    print("Loading models...")

    # Speech recognition
    print(f" Loading ASR: {asr_checkpoint}")
    asr_pipe = hf_pipeline(
        "automatic-speech-recognition", model=asr_checkpoint, **placement
    )
    print(" ASR loaded")

    # Machine translation (source language fixed to English)
    print(f" Loading MT: {mt_checkpoint}")
    mt_tokenizer = AutoTokenizer.from_pretrained(mt_checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(
        mt_checkpoint, torch_dtype=TORCH_DTYPE
    )
    mt_model = seq2seq.to(DEVICE)
    mt_tokenizer.src_lang = "eng_Latn"
    print(" MT loaded")

    # Local text-to-speech (Yoruba checkpoint)
    print(f" Loading local TTS: {tts_checkpoint}")
    tts_pipe_local = hf_pipeline("text-to-speech", model=tts_checkpoint, **placement)
    print(" Local TTS loaded")

    # Summary of where each model's parameters actually live.
    has_cuda = torch.cuda.is_available()
    asr_device = next(asr_pipe.model.parameters()).device
    mt_device = next(mt_model.parameters()).device
    tts_device = next(tts_pipe_local.model.parameters()).device
    key_state = "set" if os.environ.get("YOURVOIC_API_KEY") else "NOT SET"

    print("\n=== Device diagnostics ===")
    print(f"CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"ASR on: {asr_device}")
    print(f"MT on: {mt_device}")
    print(f"TTS on: {tts_device}")
    print(f"YourVoic API key: {key_state}")
    print("==========================\n")
    print("All models loaded!")
83
-
84
-
85
- # ---- Text Processing ----
86
-
87
def _upper_first(s):
    """Upper-case the first character of ``s`` and leave the rest untouched."""
    return s[:1].upper() + s[1:]


def split_into_sentences(text, max_words=12):
    """Split raw ASR text into individual sentences.

    If the text contains any of ``.``, ``!`` or ``?`` it is split after those
    marks. Otherwise it is cut into chunks of at most ``max_words`` words,
    each given a trailing period.

    Every sentence gets its first letter upper-cased. The rest of the
    sentence keeps its original casing, so proper nouns and acronyms
    ("Lagos", "NASA") survive; ``str.capitalize()`` would lowercase them.

    Args:
        text: Raw transcript text.
        max_words: Chunk size used only when the text has no sentence
            punctuation. Must be >= 1.

    Returns:
        A list of non-empty sentence strings; ``[]`` for blank input.
    """
    text = text.strip()
    if not text:
        return []

    # Normalise the start of each ". "-separated piece before splitting.
    text = ". ".join(
        _upper_first(piece.strip()) for piece in text.split(". ") if piece.strip()
    )

    if re.search(r"[.!?]", text):
        pieces = re.split(r"(?<=[.!?])\s+", text)
        return [piece.strip() for piece in pieces if piece.strip()]

    # No punctuation at all: fall back to fixed-size word chunks. Because the
    # text holds no '.', '!' or '?', every chunk needs a period appended.
    words = text.split()
    return [
        _upper_first(" ".join(words[start:start + max_words])) + "."
        for start in range(0, len(words), max_words)
    ]
106
-
107
-
108
- # ---- ASR ----
109
-
110
def transcribe(audio_array, sample_rate=16000):
    """Transcribe English speech to text with the loaded ASR pipeline.

    Args:
        audio_array: NumPy array of audio samples (assumed 1-D mono —
            TODO confirm against callers).
        sample_rate: Sampling rate of ``audio_array`` in Hz. Audio at any
            other rate than 16 kHz is resampled with torchaudio first.

    Returns:
        The stripped transcription, or ``""`` when the clip is shorter than
        0.1 seconds.
    """
    target_sr = 16000

    # Skip clips shorter than 0.1 s. The limit is derived from the input
    # rate: a fixed 1600-sample limit only means 0.1 s for 16 kHz audio,
    # and this check runs before any resampling.
    if len(audio_array) < sample_rate // 10:
        return ""

    # Measured before resampling, while len() and sample_rate still agree.
    duration_s = len(audio_array) / sample_rate

    if sample_rate != target_sr:
        import torchaudio.functional as F_audio
        audio_tensor = torch.from_numpy(audio_array).float()
        audio_tensor = F_audio.resample(audio_tensor, sample_rate, target_sr)
        audio_array = audio_tensor.numpy()
        sample_rate = target_sr

    # Short clips go straight through the pipeline.
    # NOTE(review): the 28 s cut-off presumably keeps the clip inside
    # Whisper's 30 s window — confirm.
    if duration_s <= 28:
        result = asr_pipe(
            {"raw": audio_array, "sampling_rate": sample_rate},
            return_timestamps=False,
        )
        return result["text"].strip()

    # Long-form: call the model's generate() directly on untruncated features.
    model = asr_pipe.model
    processor = asr_pipe.feature_extractor
    tokenizer = asr_pipe.tokenizer

    inputs = processor(
        audio_array, sampling_rate=target_sr, return_tensors="pt",
        truncation=False, padding="longest", return_attention_mask=True,
    )
    input_features = inputs.input_features.to(DEVICE, dtype=TORCH_DTYPE)

    generate_kwargs = {"return_timestamps": True, "language": "en", "task": "transcribe"}
    if "attention_mask" in inputs:
        generate_kwargs["attention_mask"] = inputs.attention_mask.to(DEVICE)

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **generate_kwargs)

    transcription = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription.strip()
152
-
153
-
154
- # ---- MT ----
155
-
156
def translate_sentence(text, target_nllb_code, fast=True, max_length=256):
    """Translate one English sentence into the target language.

    Args:
        text: Source sentence. Blank input returns ``""`` without running
            the model.
        target_nllb_code: Target language token for the MT tokenizer
            (e.g. ``"yor_Latn"``).
        fast: When true, use greedy decoding with output capped at 128
            tokens; otherwise use 4-beam search with early stopping.
        max_length: Maximum generated length. In fast mode the effective
            limit is ``min(max_length, 128)``, so the default behaves as
            before while a smaller explicit value is honoured.

    Returns:
        The decoded translation with special tokens removed.

    Raises:
        ValueError: If ``target_nllb_code`` is not in the tokenizer's
            vocabulary (it would otherwise map to the unknown-token id and
            be forced as the first output token).
    """
    if not text or not text.strip():
        return ""

    tgt_lang_id = mt_tokenizer.convert_tokens_to_ids(target_nllb_code)
    unk_id = getattr(mt_tokenizer, "unk_token_id", None)
    if unk_id is not None and tgt_lang_id == unk_id:
        raise ValueError(f"Unknown NLLB language code: {target_nllb_code!r}")

    inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)

    generate_kwargs = {
        "forced_bos_token_id": tgt_lang_id,
        "repetition_penalty": 1.5,
        "no_repeat_ngram_size": 3,
    }
    if fast:
        generate_kwargs.update(
            {"max_length": min(max_length, 128), "num_beams": 1, "do_sample": False}
        )
    else:
        generate_kwargs.update(
            {"max_length": max_length, "num_beams": 4, "early_stopping": True}
        )

    with torch.no_grad():
        output_ids = mt_model.generate(**inputs, **generate_kwargs)

    return mt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
175
-
176
-
177
def translate_text(text, target_nllb_code, fast=True):
    """Translate ``text`` one sentence at a time.

    The text is split with ``split_into_sentences`` and each sentence is
    passed to ``translate_sentence``.

    Returns:
        A 3-tuple ``(joined_translation, source_sentences,
        translated_sentences)``. When the text yields no sentences the
        result is ``("", [], [])``.
    """
    source_sentences = split_into_sentences(text)
    if source_sentences:
        translated = [
            translate_sentence(sentence, target_nllb_code, fast=fast)
            for sentence in source_sentences
        ]
        return " ".join(translated), source_sentences, translated
    return "", [], []
187
-
188
-
189
- # ---- Video Processing ----
190
-
191
def extract_audio_from_video(video_path, output_path, target_sr=16000):
    """Extract a video's audio track as mono 16-bit PCM using ffmpeg.

    Args:
        video_path: Input media file.
        output_path: Destination audio file; overwritten if it exists (``-y``).
        target_sr: Output sampling rate in Hz.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status. The message
            carries the first 200 characters of ffmpeg's stderr.
    """
    cmd = [
        # Without -hide_banner / -loglevel error, stderr starts with ffmpeg's
        # version and build banner, so the truncated excerpt in the error
        # below would show the banner instead of the failure.
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-y", "-i", video_path,
        "-vn", "-acodec", "pcm_s16le", "-ar", str(target_sr), "-ac", "1",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg extraction failed: {result.stderr[:200]}")
    return output_path
202
-
203
-
204
def get_media_duration(path):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        path: Media file to inspect.

    Returns:
        The container-level duration as a float.

    Raises:
        RuntimeError: If ffprobe exits with a non-zero status, or if its
            output is not a number (for example empty output or ``N/A``).
            Both failures raise the same exception type so callers need
            only one handler.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1", path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr[:200]}")

    raw = result.stdout.strip()
    try:
        return float(raw)
    except ValueError as err:
        raise RuntimeError(
            f"ffprobe returned no usable duration: {raw[:200]!r}"
        ) from err
215
-
216
-
217
def _atempo_chain(ratio):
    """Return atempo filter strings that together apply ``ratio``, keeping each stage within [0.5, 2.0]."""
    filters = []
    remaining = ratio
    while remaining > 2.0:
        filters.append("atempo=2.0")
        remaining /= 2.0
    while remaining < 0.5:
        filters.append("atempo=0.5")
        remaining /= 0.5
    filters.append(f"atempo={remaining:.4f}")
    return filters


def stretch_audio_to_duration(input_path, output_path, target_duration_s):
    """Speed up or slow down audio so it lasts ``target_duration_s`` seconds.

    The tempo ratio is ``current_duration / target_duration_s``; it is
    applied as a chain of ffmpeg ``atempo`` stages.

    Args:
        input_path: Source audio file.
        output_path: Destination file; overwritten if it exists (``-y``).
        target_duration_s: Desired duration in seconds. Must be positive.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``target_duration_s`` is not positive. Zero would
            divide by zero, and a negative value would give a negative
            ratio that the tempo-chain loop can never bring back into range.
        RuntimeError: If the source duration is not positive or ffmpeg fails.
    """
    if target_duration_s <= 0:
        raise ValueError(
            f"target_duration_s must be positive, got {target_duration_s!r}"
        )

    current_duration = get_media_duration(input_path)
    if current_duration <= 0:
        raise RuntimeError("Invalid audio duration")

    filters = _atempo_chain(current_duration / target_duration_s)

    cmd = [
        # Suppress the banner so the stderr excerpt below shows the error.
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-y", "-i", input_path, "-filter:a", ",".join(filters), output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg tempo failed: {result.stderr[:200]}")
    return output_path
239
-
240
-
241
def mux_video_audio(video_path, audio_path, output_path, extend_video=False, target_duration=None):
    """Combine a video stream with a replacement audio track using ffmpeg.

    Args:
        video_path: File supplying the video stream.
        audio_path: File supplying the audio stream.
        output_path: Destination file; overwritten if it exists (``-y``).
        extend_video: When true (and ``target_duration`` is truthy), the
            video is padded by repeating its last frame, re-encoded with
            libx264, and the output is cut at ``target_duration`` seconds.
        target_duration: Output length in seconds for the extend mode.

    In the default mode the video stream is copied without re-encoding and
    the output ends with the shorter of the two inputs (``-shortest``).
    Audio is encoded as AAC in both modes.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # Suppress the banner so the stderr excerpt in the error shows the
    # failure rather than ffmpeg's version/build information.
    base = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        "-y", "-i", video_path, "-i", audio_path,
    ]
    if extend_video and target_duration:
        cmd = base + [
            "-filter_complex", f"[0:v]tpad=stop_mode=clone:stop_duration={target_duration}[v]",
            "-map", "[v]", "-map", "1:a:0",
            "-c:v", "libx264", "-preset", "fast", "-c:a", "aac",
            "-t", str(target_duration), output_path,
        ]
    else:
        cmd = base + [
            "-c:v", "copy", "-c:a", "aac",
            "-map", "0:v:0", "-map", "1:a:0", "-shortest", output_path,
        ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg mux failed: {result.stderr[:200]}")
    return output_path