File size: 2,350 Bytes
a6908e4
fa9bc69
a6908e4
25b2219
a6908e4
25b2219
a6908e4
fa9bc69
a6908e4
fa9bc69
 
 
a6908e4
 
fa9bc69
a6908e4
 
 
 
 
 
fa9bc69
a6908e4
 
 
 
fa9bc69
 
 
 
 
a6908e4
 
fa9bc69
a6908e4
 
 
fa9bc69
 
a6908e4
fa9bc69
a6908e4
 
fa9bc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import torch
import warnings
from tools.base_tool import BaseTool

class SpeechRecognitionTool(BaseTool):
    """Tool that transcribes speech from an audio file using Whisper.

    Wraps the Hugging Face automatic-speech-recognition pipeline around
    ``openai/whisper-large-v3-turbo`` and optionally reconstructs absolute
    timestamps from the pipeline's per-chunk (relative) timestamps.
    """

    name = 'speech_to_text'
    description = 'Transcribes speech from audio input.'

    # Pipeline chunk size in seconds. Whisper timestamps reset to zero at
    # every chunk boundary, so _parse_timed_chunks needs this same value to
    # rebuild absolute times — keep the two in sync via this one constant.
    CHUNK_LENGTH_S = 30.0

    def __init__(self):
        """Load the Whisper model/processor and build the ASR pipeline.

        Uses CUDA with float16 when a GPU is available, otherwise CPU with
        float32.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        dtype = torch.float16 if device == 'cuda' else torch.float32
        model_id = 'openai/whisper-large-v3-turbo'

        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(device)

        self.processor = AutoProcessor.from_pretrained(model_id)

        # Silence transformers log spam and FutureWarnings.
        # NOTE(review): both calls mutate process-global state, not just this
        # tool instance — confirm that is acceptable for the host application.
        logging.set_verbosity_error()
        warnings.filterwarnings("ignore", category=FutureWarning)

        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
            chunk_length_s=int(self.CHUNK_LENGTH_S),
            return_timestamps=True,
        )

    def transcribe(self, audio_path: str, with_timestamps: bool = False) -> str:
        """Transcribe the audio file at *audio_path*.

        Returns the plain transcript, or — when *with_timestamps* is True —
        one ``[start]\\ntext\\n[end]`` section per recognized segment, with
        times in seconds formatted to two decimals.
        """
        result = self.pipeline(audio_path)

        if not with_timestamps:
            return result['text'].strip()

        # Build the pieces once and join, instead of quadratic ``+=``.
        parts = [
            f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
            for chunk in self._parse_timed_chunks(result['chunks'])
        ]
        return "".join(parts).strip()

    def _parse_timed_chunks(self, chunks):
        """Convert per-chunk relative timestamps into absolute seconds.

        The pipeline's timestamps restart at zero each CHUNK_LENGTH_S window;
        a timestamp that jumps backwards signals a new window, so we add the
        window length to a running offset. Returns a list of
        ``{"start": float, "end": float, "text": str}`` dicts, skipping
        chunks whose text is empty after stripping.
        """
        absolute_offset = 0.0
        current_offset = 0.0
        normalized = []
        max_chunk = self.CHUNK_LENGTH_S

        for c in chunks:
            start, end = c['timestamp']
            # Start moved backwards: a new chunk window began.
            if start < current_offset:
                absolute_offset += max_chunk
                current_offset = start
            start_time = absolute_offset + start

            # BUG FIX: Whisper can emit ``end is None`` for the final chunk;
            # the original ``end < start`` comparison (and the later ``:.2f``
            # formatting) raised TypeError. Fall back to the chunk's start.
            if end is None:
                end = start
            elif end < start:
                # End wrapped into the next window within this same chunk.
                absolute_offset += max_chunk
            end_time = absolute_offset + end
            current_offset = end

            text = c['text'].strip()
            if text:
                normalized.append({"start": start_time, "end": end_time, "text": text})

        return normalized