Spaces:
Sleeping
Sleeping
| # Standard library imports | |
| import os | |
| from typing import Annotated, List, Dict | |
| # Related third-party imports | |
| import torch | |
| from faster_whisper import decode_audio | |
| from ctc_forced_aligner import ( | |
| generate_emissions, | |
| get_alignments, | |
| get_spans, | |
| load_alignment_model, | |
| postprocess_results, | |
| preprocess_text, | |
| ) | |
| class ForcedAligner: | |
| """ | |
| ForcedAligner is a class for aligning audio to a provided transcript using a pre-trained alignment model. | |
| Attributes | |
| ---------- | |
| device : str | |
| Device to run the model on ('cuda' for GPU or 'cpu'). | |
| alignment_model : torch.nn.Module | |
| The pre-trained alignment model. | |
| alignment_tokenizer : Any | |
| Tokenizer for processing text in alignment. | |
| Methods | |
| ------- | |
| align(audio_path, transcript, language, batch_size) | |
| Aligns audio with a transcript and returns word-level timing information. | |
| """ | |
| def __init__(self, device: Annotated[str, "Device for model ('cuda' or 'cpu')"] = None): | |
| """ | |
| Initialize the ForcedAligner with the specified device. | |
| Parameters | |
| ---------- | |
| device : str, optional | |
| Device for running the model, by default 'cuda' if available, otherwise 'cpu'. | |
| """ | |
| self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') | |
| self.alignment_model, self.alignment_tokenizer = load_alignment_model( | |
| self.device, | |
| dtype=torch.float16 if self.device == 'cuda' else torch.float32, | |
| ) | |
| def align( | |
| self, | |
| audio_path: Annotated[str, "Path to the audio file"], | |
| transcript: Annotated[str, "Transcript of the audio content"], | |
| language: Annotated[str, "Language of the transcript"] = 'en', | |
| batch_size: Annotated[int, "Batch size for emission generation"] = 8, | |
| ) -> Annotated[List[Dict[str, float]], "List of word alignment data with timestamps"]: | |
| """ | |
| Aligns audio with a transcript and returns word-level timing information. | |
| Parameters | |
| ---------- | |
| audio_path : str | |
| Path to the audio file. | |
| transcript : str | |
| Transcript text corresponding to the audio. | |
| language : str, optional | |
| Language code for the transcript, default is 'en' (English). | |
| batch_size : int, optional | |
| Batch size for generating emissions, by default 8. | |
| Returns | |
| ------- | |
| List[Dict[str, float]] | |
| A list of dictionaries containing word timing information. | |
| Raises | |
| ------ | |
| FileNotFoundError | |
| If the specified audio file does not exist. | |
| Examples | |
| -------- | |
| >>> aligner = ForcedAligner() | |
| >>> aligner.align("path/to/audio.wav", "hello world") | |
| [{'word': 'hello', 'start': 0.0, 'end': 0.5}, {'word': 'world', 'start': 0.6, 'end': 1.0}] | |
| """ | |
| if not os.path.exists(audio_path): | |
| raise FileNotFoundError( | |
| f"The audio file at path '{audio_path}' was not found." | |
| ) | |
| speech_array = torch.from_numpy(decode_audio(audio_path)) | |
| emissions, stride = generate_emissions( | |
| self.alignment_model, | |
| speech_array.to(self.alignment_model.dtype).to(self.alignment_model.device), | |
| batch_size=batch_size, | |
| ) | |
| tokens_starred, text_starred = preprocess_text( | |
| transcript, | |
| romanize=True, | |
| language=language, | |
| ) | |
| segments, scores, blank_token = get_alignments( | |
| emissions, | |
| tokens_starred, | |
| self.alignment_tokenizer, | |
| ) | |
| spans = get_spans(tokens_starred, segments, blank_token) | |
| word_timestamps = postprocess_results(text_starred, spans, stride, scores) | |
| if self.device == 'cuda': | |
| del self.alignment_model | |
| torch.cuda.empty_cache() | |
| print(f"Word_Timestamps: {word_timestamps}") | |
| return word_timestamps | |
| if __name__ == "__main__": | |
| forced_aligner = ForcedAligner() | |
| try: | |
| path = "example_audio.wav" | |
| audio_transcript = "This is a test transcript." | |
| word_timestamp = forced_aligner.align(path, audio_transcript) | |
| print(word_timestamp) | |
| except FileNotFoundError as e: | |
| print(e) |