rikhoffbauer2 commited on
Commit
8b482f1
·
verified ·
1 Parent(s): 4b10521

Upload lyric_sync/transcribe.py

Browse files
Files changed (1) hide show
  1. lyric_sync/transcribe.py +396 -0
lyric_sync/transcribe.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Word-level transcription of vocal audio.
3
+
4
+ Supports multiple backends:
5
+ - WhisperX (recommended): Whisper transcription + wav2vec2 phoneme alignment
6
+ - Whisper (transformers pipeline): Simpler, less precise alignment
7
+ - Granite Speech: IBM's timestamp-capable model (experimental for singing)
8
+
9
+ WhisperX is recommended because its two-stage approach (transcription + forced
10
+ phoneme alignment) is more robust for singing than Whisper's attention-based
11
+ word timestamps.
12
+ """
13
+
14
+ import logging
15
+ import re
16
+ from dataclasses import dataclass, field
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class TimedWord:
26
+ """A single word with timing information."""
27
+ word: str
28
+ start: float # seconds
29
+ end: float # seconds
30
+ confidence: float = 1.0
31
+
32
+ @property
33
+ def duration(self) -> float:
34
+ return self.end - self.start
35
+
36
+ def __repr__(self):
37
+ return f"TimedWord('{self.word}', {self.start:.3f}-{self.end:.3f})"
38
+
39
+
40
+ @dataclass
41
+ class TranscriptionResult:
42
+ """Full transcription with word-level timings."""
43
+ text: str
44
+ words: list[TimedWord] = field(default_factory=list)
45
+ language: str = "en"
46
+
47
+ @property
48
+ def duration(self) -> float:
49
+ if not self.words:
50
+ return 0.0
51
+ return self.words[-1].end - self.words[0].start
52
+
53
+
54
+ class WhisperXTranscriber:
55
+ """
56
+ Word-level transcription using WhisperX.
57
+
58
+ Two-stage approach:
59
+ 1. Whisper large-v2/v3 for text transcription (batched)
60
+ 2. wav2vec2 phoneme model for forced word-level alignment
61
+
62
+ This decoupled approach is robust to the timing drift that Whisper's
63
+ native word_timestamps exhibit on singing (stretched syllables).
64
+
65
+ Reference: arxiv:2303.00747 (WhisperX paper)
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ model_size: str = "large-v2",
71
+ device: str = "cuda",
72
+ compute_type: str = "float16",
73
+ language: str = "en",
74
+ batch_size: int = 16,
75
+ ):
76
+ """
77
+ Args:
78
+ model_size: Whisper model size. "large-v2" recommended for lyrics (per arxiv:2506.15514).
79
+ device: "cuda" or "cpu"
80
+ compute_type: "float16" (GPU) or "int8" (CPU) or "float32"
81
+ language: Language code for transcription
82
+ batch_size: Batch size for transcription (reduce if OOM)
83
+ """
84
+ self.model_size = model_size
85
+ self.device = device
86
+ self.compute_type = compute_type
87
+ self.language = language
88
+ self.batch_size = batch_size
89
+ self._model = None
90
+ self._align_model = None
91
+ self._align_metadata = None
92
+
93
+ def _load_models(self):
94
+ """Lazy-load WhisperX models."""
95
+ if self._model is not None:
96
+ return
97
+
98
+ import whisperx
99
+
100
+ self._model = whisperx.load_model(
101
+ self.model_size,
102
+ self.device,
103
+ compute_type=self.compute_type,
104
+ language=self.language,
105
+ )
106
+ self._align_model, self._align_metadata = whisperx.load_align_model(
107
+ language_code=self.language,
108
+ device=self.device,
109
+ )
110
+ logger.info(f"Loaded WhisperX: {self.model_size} + alignment model on {self.device}")
111
+
112
+ def transcribe(
113
+ self,
114
+ audio: np.ndarray,
115
+ sr: int = 16000,
116
+ initial_prompt: str = "Song lyrics: ",
117
+ ) -> TranscriptionResult:
118
+ """
119
+ Transcribe audio with word-level timestamps.
120
+
121
+ Args:
122
+ audio: Mono float32 numpy array
123
+ sr: Sample rate (16000 for Whisper)
124
+ initial_prompt: Prompt to bias Whisper toward lyrics domain
125
+
126
+ Returns:
127
+ TranscriptionResult with word-level timings
128
+ """
129
+ import whisperx
130
+
131
+ self._load_models()
132
+
133
+ # WhisperX expects audio loaded via its own loader at 16kHz
134
+ # But we can pass raw numpy if it's already 16kHz mono float32
135
+ if sr != 16000:
136
+ import torchaudio
137
+ import torch
138
+ audio_t = torch.from_numpy(audio).unsqueeze(0)
139
+ audio_t = torchaudio.functional.resample(audio_t, sr, 16000)
140
+ audio = audio_t.squeeze(0).numpy()
141
+
142
+ # Step 1: Transcribe
143
+ result = self._model.transcribe(
144
+ audio,
145
+ batch_size=self.batch_size,
146
+ language=self.language,
147
+ chunk_length=30, # 30s context — best for singing (arxiv:2506.15514)
148
+ initial_prompt=initial_prompt,
149
+ )
150
+
151
+ # Step 2: Forced word-level alignment via wav2vec2
152
+ result = whisperx.align(
153
+ result["segments"],
154
+ self._align_model,
155
+ self._align_metadata,
156
+ audio,
157
+ self.device,
158
+ return_char_alignments=False,
159
+ )
160
+
161
+ # Convert to our format
162
+ words = []
163
+ for ws in result.get("word_segments", []):
164
+ if "start" in ws and "end" in ws:
165
+ words.append(TimedWord(
166
+ word=ws["word"].strip(),
167
+ start=ws["start"],
168
+ end=ws["end"],
169
+ confidence=ws.get("score", 1.0),
170
+ ))
171
+
172
+ full_text = " ".join(w.word for w in words)
173
+ return TranscriptionResult(text=full_text, words=words, language=self.language)
174
+
175
+
176
+ class WhisperTranscriber:
177
+ """
178
+ Simpler fallback: Whisper via transformers pipeline with word timestamps.
179
+
180
+ Uses Whisper's built-in cross-attention DTW for word-level timestamps.
181
+ Less precise than WhisperX on singing but has fewer dependencies.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ model_id: str = "openai/whisper-large-v3",
187
+ device: str = "cuda",
188
+ torch_dtype: str = "float16",
189
+ ):
190
+ self.model_id = model_id
191
+ self.device = device
192
+ self.torch_dtype = torch_dtype
193
+ self._pipe = None
194
+
195
+ def _load_model(self):
196
+ if self._pipe is not None:
197
+ return
198
+
199
+ import torch
200
+ from transformers import pipeline
201
+
202
+ dtype_map = {"float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16}
203
+ self._pipe = pipeline(
204
+ task="automatic-speech-recognition",
205
+ model=self.model_id,
206
+ torch_dtype=dtype_map.get(self.torch_dtype, torch.float16),
207
+ device=self.device if self.device != "cpu" else -1,
208
+ model_kwargs={"attn_implementation": "sdpa"},
209
+ )
210
+ logger.info(f"Loaded Whisper pipeline: {self.model_id} on {self.device}")
211
+
212
+ def transcribe(
213
+ self,
214
+ audio: np.ndarray,
215
+ sr: int = 16000,
216
+ language: str = "english",
217
+ ) -> TranscriptionResult:
218
+ """
219
+ Transcribe with word-level timestamps via transformers pipeline.
220
+
221
+ Args:
222
+ audio: Mono float32 numpy array at sr Hz
223
+ sr: Sample rate
224
+ language: Language for transcription
225
+ """
226
+ self._load_model()
227
+
228
+ result = self._pipe(
229
+ {"array": audio, "sampling_rate": sr},
230
+ return_timestamps="word",
231
+ generate_kwargs={
232
+ "language": language,
233
+ "task": "transcribe",
234
+ "condition_on_previous_text": False, # Reduces hallucination on music
235
+ },
236
+ chunk_length_s=30,
237
+ stride_length_s=5,
238
+ )
239
+
240
+ words = []
241
+ for chunk in result.get("chunks", []):
242
+ text = chunk["text"].strip()
243
+ ts = chunk.get("timestamp", (None, None))
244
+ if text and ts[0] is not None and ts[1] is not None:
245
+ words.append(TimedWord(
246
+ word=text,
247
+ start=ts[0],
248
+ end=ts[1],
249
+ ))
250
+
251
+ full_text = " ".join(w.word for w in words)
252
+ return TranscriptionResult(text=full_text, words=words, language=language[:2])
253
+
254
+
255
+ class GraniteSpeechTranscriber:
256
+ """
257
+ Experimental: IBM Granite Speech 4.1 2B Plus with word timestamps.
258
+
259
+ Uses in-model [T:NNN] timestamp tokens. Promising but:
260
+ - Only works up to ~5 minutes in timestamp mode
261
+ - Trained on speech only (not singing)
262
+ - Only outputs word-end times (not start)
263
+
264
+ Reference: arxiv:2604.22817 (In-Sync paper)
265
+ """
266
+
267
+ def __init__(self, device: str = "cuda"):
268
+ self.device = device
269
+ self.model_id = "ibm-granite/granite-speech-4.1-2b-plus"
270
+ self._model = None
271
+ self._processor = None
272
+
273
+ def _load_model(self):
274
+ if self._model is not None:
275
+ return
276
+
277
+ import torch
278
+ from transformers import AutoModelForCausalLM, AutoProcessor
279
+
280
+ self._processor = AutoProcessor.from_pretrained(self.model_id)
281
+ self._model = AutoModelForCausalLM.from_pretrained(
282
+ self.model_id,
283
+ torch_dtype=torch.bfloat16,
284
+ device_map="auto",
285
+ )
286
+ logger.info(f"Loaded Granite Speech: {self.model_id}")
287
+
288
+ def transcribe(
289
+ self,
290
+ audio: np.ndarray,
291
+ sr: int = 16000,
292
+ ) -> TranscriptionResult:
293
+ """
294
+ Transcribe with word-level end-timestamps via Granite's [T:NNN] tokens.
295
+ """
296
+ import torch
297
+
298
+ self._load_model()
299
+
300
+ conversation = [
301
+ {
302
+ "role": "user",
303
+ "content": [
304
+ {"type": "audio", "audio": audio, "sampling_rate": sr},
305
+ {"type": "text", "text": "Please transcribe the speech into written format and add word-level timestamps."},
306
+ ],
307
+ }
308
+ ]
309
+
310
+ inputs = self._processor.apply_chat_template(
311
+ conversation,
312
+ add_generation_prompt=True,
313
+ tokenize=True,
314
+ return_dict=True,
315
+ return_tensors="pt",
316
+ ).to(self._model.device, dtype=torch.bfloat16)
317
+
318
+ output_ids = self._model.generate(**inputs, max_new_tokens=2048)
319
+ output_text = self._processor.decode(
320
+ output_ids[0][inputs["input_ids"].shape[1]:],
321
+ skip_special_tokens=True,
322
+ )
323
+
324
+ words = self._parse_granite_timestamps(output_text)
325
+ full_text = " ".join(w.word for w in words)
326
+ return TranscriptionResult(text=full_text, words=words)
327
+
328
+ @staticmethod
329
+ def _parse_granite_timestamps(text: str) -> list[TimedWord]:
330
+ """
331
+ Parse Granite [T:NNN] format where NNN is centiseconds.
332
+ Handles 10-second rollover.
333
+
334
+ Format: "word1 [T:012] word2 [T:045] ..."
335
+ """
336
+ pattern = r"(\S+)\s*\[T:(\d{3})\]"
337
+ matches = re.findall(pattern, text)
338
+
339
+ words = []
340
+ rollover = 0
341
+ prev_cs = 0
342
+
343
+ for word_text, cs_str in matches:
344
+ cs = int(cs_str)
345
+ # Detect rollover (centiseconds resets)
346
+ if cs < prev_cs - 50:
347
+ rollover += 1
348
+ prev_cs = cs
349
+
350
+ end_time = (cs + rollover * 1000) / 100.0
351
+
352
+ # Granite only gives end times; estimate start from previous word's end
353
+ start_time = words[-1].end if words else max(0.0, end_time - 0.3)
354
+
355
+ if word_text != "_": # underscore = sentence boundary marker
356
+ words.append(TimedWord(
357
+ word=word_text,
358
+ start=start_time,
359
+ end=end_time,
360
+ ))
361
+
362
+ return words
363
+
364
+
365
+ def transcribe_vocals(
366
+ audio: np.ndarray,
367
+ sr: int = 16000,
368
+ backend: str = "whisperx",
369
+ device: str = "cuda",
370
+ language: str = "en",
371
+ **kwargs,
372
+ ) -> TranscriptionResult:
373
+ """
374
+ Transcribe vocals with word-level timestamps.
375
+
376
+ Args:
377
+ audio: Mono float32 numpy array
378
+ sr: Sample rate
379
+ backend: "whisperx" (recommended), "whisper", or "granite"
380
+ device: "cuda" or "cpu"
381
+ language: Language code
382
+ **kwargs: Additional args passed to the backend
383
+
384
+ Returns:
385
+ TranscriptionResult with word-level timings
386
+ """
387
+ if backend == "whisperx":
388
+ transcriber = WhisperXTranscriber(device=device, language=language, **kwargs)
389
+ elif backend == "whisper":
390
+ transcriber = WhisperTranscriber(device=device, **kwargs)
391
+ elif backend == "granite":
392
+ transcriber = GraniteSpeechTranscriber(device=device)
393
+ else:
394
+ raise ValueError(f"Unknown backend: {backend}. Use 'whisperx', 'whisper', or 'granite'.")
395
+
396
+ return transcriber.transcribe(audio, sr=sr)