File size: 3,226 Bytes
a6908e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from smolagents import Tool
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import warnings


class SpeechRecognitionTool(Tool):
    """smolagents Tool that transcribes audio files to text with Whisper.

    Wraps the `openai/whisper-large-v3-turbo` checkpoint behind a
    `transformers` automatic-speech-recognition pipeline. The pipeline is
    built lazily, once per class, and shared by all instances.
    """

    name = 'speech_to_text'
    description = 'Transcribes spoken audio to text with optional time markers.'

    inputs = {
        'audio': {
            'type': 'string',
            'description': 'Local path to the audio file to transcribe.',
        },
        'with_time_markers': {
            'type': 'boolean',
            'description': 'Include timestamps in output.',
            'nullable': True,
            'default': False,
        },
    }

    output_type = 'string'

    # Length in seconds of each inference window fed to the model; also the
    # amount by which _normalize_chunks advances its offset when the
    # per-window timestamps wrap back toward zero.
    chunk_length_s = 30

    # Class-level ASR pipeline, built on first instantiation (see __new__).
    pipe = None

    def __new__(cls, *args, **kwargs):
        # Build the expensive model/pipeline exactly once per class. The
        # previous version reloaded the full checkpoint and overwrote
        # cls.pipe on every instantiation.
        if cls.pipe is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            # fp16 on GPU, fp32 on CPU (fp16 matmuls are unsupported/slow on CPU).
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            model_id = 'openai/whisper-large-v3-turbo'

            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            ).to(device)

            processor = AutoProcessor.from_pretrained(model_id)

            # Silence transformers progress/deprecation chatter.
            logging.set_verbosity_error()
            warnings.filterwarnings("ignore", category=FutureWarning)

            cls.pipe = pipeline(
                task='automatic-speech-recognition',
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=dtype,
                device=device,
                chunk_length_s=cls.chunk_length_s,
                return_timestamps=True,
            )

        # object.__new__ accepts no extra arguments; constructor args are
        # delivered to Tool.__init__ by the normal instantiation protocol.
        return super().__new__(cls)

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        """
        Run speech recognition on the input audio file.

        Args:
            audio (str): Path to a local audio file (e.g. .wav or .mp3).
            with_time_markers (bool): Whether to return chunked timestamps.

        Returns:
            str: Plain transcript, or — when with_time_markers is True —
                one [start]\n[text]\n[end] group per chunk, newline-joined.
        """
        result = self.pipe(audio)

        if not with_time_markers:
            return result['text'].strip()

        # 'chunks' is present because the pipeline was built with
        # return_timestamps=True; guard anyway so a missing/empty list
        # yields an empty transcript instead of a KeyError.
        chunks = self._normalize_chunks(result.get('chunks') or [])

        lines = []
        for ch in chunks:
            lines.append(f"[{ch['start']:.2f}]\n{ch['text']}\n[{ch['end']:.2f}]")

        return "\n".join(lines).strip()

    def _normalize_chunks(self, chunks):
        """
        Convert per-window Whisper timestamps into absolute file times.

        The pipeline reports timestamps relative to each chunk_length_s
        inference window, so they wrap back toward zero at window
        boundaries. This accumulates a running offset: a start timestamp
        smaller than the previous end means a new window began; an end
        smaller than its own start means the chunk itself straddles a
        window boundary.

        Args:
            chunks: List of {'timestamp': (start, end), 'text': str} dicts
                as produced by the transformers ASR pipeline.

        Returns:
            List of {'start': float, 'end': float, 'text': str} dicts with
            absolute timestamps; blank-text chunks are dropped.
        """
        offset = 0.0
        chunk_offset = 0.0
        norm_chunks = []

        for chunk in chunks:
            ts_start, ts_end = chunk['timestamp']
            # Whisper can emit (start, None) for the final chunk of a file;
            # fall back to the start so the arithmetic below doesn't raise.
            if ts_end is None:
                ts_end = ts_start

            # Timestamp reset => we crossed into the next inference window.
            if ts_start < chunk_offset:
                offset += self.chunk_length_s
                chunk_offset = ts_start

            start = offset + ts_start
            if ts_end < ts_start:
                # End wrapped within this chunk (spans a window boundary).
                offset += self.chunk_length_s
            end = offset + ts_end
            chunk_offset = ts_end

            text = chunk['text'].strip()
            if text:
                norm_chunks.append({
                    'start': start,
                    'end': end,
                    'text': text,
                })

        return norm_chunks