File size: 1,887 Bytes
7630e84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ac1e08
 
f655787
dd76b9f
c71655e
dd76b9f
 
c71655e
7630e84
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from typing import Dict

import torch
import whisper

import numpy as np  # for counting parameters

from utils import log

# Run on the GPU when torch reports a usable CUDA runtime; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


class TranscribeAudio:
    """Transcribe audio files with a Whisper model and save SRT/VTT subtitles.

    The model is loaded once in ``__init__`` onto the module-level ``device``
    and reused by every subsequent call.
    """

    def __init__(self, model_name: str = "base"):
        """Load the Whisper model and log its language support and size.

        Args:
            model_name: Model identifier passed to ``whisper.load_model``.
                Defaults to ``"base"``, the value that used to be hard-coded.
        """
        self.model = whisper.load_model(model_name, device=device)
        log(
            f"Model is {'multilingual' if self.model.is_multilingual else 'English-only'} "
            f"and has {sum(np.prod(p.shape) for p in self.model.parameters()):,} parameters."
        )

    def transcribe(self, audio_file_path: str, language: str = "en") -> Dict:
        """Transcribe one audio file.

        Args:
            audio_file_path: Path of the audio file handed to the model.
            language: Language code passed to the model as ``language``.

        Returns:
            The dictionary returned by ``self.model.transcribe``, unchanged.
        """
        log(f"Transcribing {audio_file_path} in {language}")
        transcribe_options = dict(
            task="transcribe",
            language=language,
            beam_size=5,
            best_of=5,
        )
        return self.model.transcribe(audio_file_path, **transcribe_options)

    def save_output(self, transcript_output: Dict, audio_file_path: str, output_dir: str = "") -> str:
        """Write the transcript as both an SRT and a VTT file.

        Args:
            transcript_output: Result dictionary as returned by ``transcribe``.
            audio_file_path: Path of the transcribed audio; the writers are
                given this path and the returned name is built from its stem.
            output_dir: Directory to write into. When empty (the default) the
                files go next to the audio file, as before this parameter
                existed. A non-empty directory is created if missing.

        Returns:
            Path of the VTT file.
        """
        stem, _ = os.path.splitext(audio_file_path)
        if output_dir:
            directory = output_dir
            # The writers are only handed a directory name, so make sure it
            # exists before they are asked to write into it.
            os.makedirs(directory, exist_ok=True)
            vtt_path = os.path.join(directory, f"{os.path.basename(stem)}.vtt")
        else:
            directory = os.path.dirname(stem)
            vtt_path = f"{stem}.vtt"

        log(f"Saving output to {directory} directory as {vtt_path}")
        srt_writer = whisper.utils.get_writer("srt", directory)
        vtt_writer = whisper.utils.get_writer("vtt", directory)

        # Save as an SRT file
        srt_writer(result=transcript_output, audio_path=audio_file_path)

        # Save as a VTT file
        vtt_writer(result=transcript_output, audio_path=audio_file_path)

        return vtt_path

    def __call__(self, audio_file_path: str, output_dir: str, input_language: str = "en") -> str:
        """Transcribe an audio file and save the subtitles.

        Args:
            audio_file_path: Path of the audio file to transcribe.
            output_dir: Directory that receives the SRT and VTT files. An
                empty string writes them next to the audio file.
            input_language: Language code forwarded to ``transcribe``.

        Returns:
            Path of the VTT file.
        """
        # Forward the caller's language and output directory; both were
        # previously accepted but silently dropped.
        transcript = self.transcribe(audio_file_path, language=input_language)
        return self.save_output(transcript, audio_file_path, output_dir)


if __name__ == '__main__':
    # Manual smoke run. The arguments are passed by keyword because the
    # positional order is (audio_file_path, output_dir); the previous call had
    # the directory and the audio file swapped.
    transcribe_audio = TranscribeAudio()
    transcribe_audio(audio_file_path='iPhone_14_Pro.mp3', output_dir='sample')