FD900 committed on
Commit
fa9bc69
·
verified ·
1 Parent(s): fa39ad6

Update tools/speech_recognition_tool.py

Browse files
Files changed (1) hide show
  1. tools/speech_recognition_tool.py +45 -82
tools/speech_recognition_tool.py CHANGED
@@ -1,107 +1,70 @@
1
- from smolagents import Tool
2
- import torch
3
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
 
4
  import warnings
5
 
6
-
7
class SpeechRecognitionTool(Tool):
    """smolagents Tool that transcribes spoken audio to text with Whisper.

    The ASR pipeline (openai/whisper-large-v3-turbo) is heavyweight, so it is
    built once and cached on the class as ``cls.pipe``; all instances share it.
    """

    name = 'speech_to_text'
    description = 'Transcribes spoken audio to text with optional time markers.'

    inputs = {
        'audio': {
            'type': 'string',
            'description': 'Local path to the audio file to transcribe.',
        },
        'with_time_markers': {
            'type': 'boolean',
            'description': 'Include timestamps in output.',
            'nullable': True,
            'default': False,
        },
    }

    output_type = 'string'

    # Inference chunk length in seconds; also the amount by which Whisper's
    # per-chunk timestamps wrap, used by _normalize_chunks below.
    chunk_length_s = 30

    def __new__(cls, *args, **kwargs):
        # BUGFIX: previously the model and pipeline were rebuilt on every
        # instantiation. Guard so the expensive Whisper load happens only once
        # per process, no matter how many tool instances are created.
        if getattr(cls, 'pipe', None) is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            model_id = 'openai/whisper-large-v3-turbo'

            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            ).to(device)

            processor = AutoProcessor.from_pretrained(model_id)

            # Silence noisy transformers deprecation chatter during inference.
            logging.set_verbosity_error()
            warnings.filterwarnings("ignore", category=FutureWarning)

            cls.pipe = pipeline(
                task='automatic-speech-recognition',
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=dtype,
                device=device,
                chunk_length_s=cls.chunk_length_s,
                return_timestamps=True,
            )

        return super().__new__(cls, *args, **kwargs)

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        """
        Run speech recognition on the input audio file.

        Args:
            audio (str): Path to a local .wav or .mp3 file
            with_time_markers (bool): Whether to return chunked timestamps

        Returns:
            str: Transcript or chunked transcript with [start]\n[text]\n[end]
        """
        result = self.pipe(audio)

        if not with_time_markers:
            return result['text'].strip()

        chunks = self._normalize_chunks(result['chunks'])

        lines = [
            f"[{ch['start']:.2f}]\n{ch['text']}\n[{ch['end']:.2f}]"
            for ch in chunks
        ]
        return "\n".join(lines).strip()

    def _normalize_chunks(self, chunks):
        """Convert per-chunk (wrapping) timestamps into absolute times.

        Whisper reports timestamps relative to each ``chunk_length_s`` window,
        so they restart near zero at every window boundary. Detect each wrap
        (a timestamp smaller than the previous one) and add the accumulated
        offset. Empty-text chunks are dropped.
        """
        offset = 0.0        # accumulated absolute offset across windows
        chunk_offset = 0.0  # last timestamp seen, used to detect wraps
        norm_chunks = []

        for chunk in chunks:
            ts_start, ts_end = chunk['timestamp']
            # NOTE(review): the pipeline can emit (start, None) for a trailing
            # chunk; fall back to the start time so arithmetic below is safe.
            if ts_end is None:
                ts_end = ts_start

            if ts_start < chunk_offset:
                offset += self.chunk_length_s
                chunk_offset = ts_start

            start = offset + ts_start
            if ts_end < ts_start:
                # Segment spans a window boundary: its end wrapped.
                offset += self.chunk_length_s
            end = offset + ts_end
            chunk_offset = ts_end

            if chunk['text'].strip():
                norm_chunks.append({
                    'start': start,
                    'end': end,
                    'text': chunk['text'].strip(),
                })

        return norm_chunks
 
 
 
1
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
2
+ import torch
3
  import warnings
4
 
5
class SpeechRecognitionTool:
    """Transcribes speech from audio files using Whisper.

    Loads ``openai/whisper-large-v3-turbo`` once at construction time and
    exposes :meth:`transcribe` for plain or timestamped transcripts.
    """

    name = 'speech_to_text'
    description = 'Transcribes speech from audio input.'

    def __init__(self):
        # Prefer GPU with half precision when available; CPU falls back to fp32.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        dtype = torch.float16 if device == 'cuda' else torch.float32
        model_id = 'openai/whisper-large-v3-turbo'

        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to(device)

        self.processor = AutoProcessor.from_pretrained(model_id)

        # Silence transformers deprecation noise during inference.
        logging.set_verbosity_error()
        warnings.filterwarnings("ignore", category=FutureWarning)

        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            torch_dtype=dtype,
            device=device,
            chunk_length_s=30,
            return_timestamps=True,
        )

    def transcribe(self, audio_path: str, with_timestamps: bool = False) -> str:
        """Transcribe a local audio file.

        Args:
            audio_path: Path to the audio file to transcribe.
            with_timestamps: When True, return one ``[start]\\ntext\\n[end]``
                block per recognized segment instead of a flat transcript.

        Returns:
            The stripped transcript, optionally with time markers.
        """
        result = self.pipeline(audio_path)

        if not with_timestamps:
            return result['text'].strip()

        # Join instead of += in a loop: avoids quadratic string building.
        blocks = [
            f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]"
            for chunk in self._parse_timed_chunks(result['chunks'])
        ]
        return "\n".join(blocks).strip()

    def _parse_timed_chunks(self, chunks):
        """Convert Whisper's per-window timestamps into absolute times.

        Timestamps restart near zero every 30-second inference window, so a
        timestamp smaller than the previous one marks a wrap; the accumulated
        offset is added to produce absolute start/end times. Chunks with empty
        text are dropped.
        """
        absolute_offset = 0.0   # accumulated offset across 30 s windows
        current_offset = 0.0    # last timestamp seen, for wrap detection
        normalized = []
        max_chunk = 30.0        # must match chunk_length_s in __init__

        for c in chunks:
            start, end = c['timestamp']
            # BUGFIX: the pipeline can emit (start, None) for a trailing
            # chunk; without this fallback, `end < start` raises TypeError.
            if end is None:
                end = start

            if start < current_offset:
                absolute_offset += max_chunk
                current_offset = start
            start_time = absolute_offset + start

            if end < start:
                # Segment spans a window boundary: its end wrapped.
                absolute_offset += max_chunk
            end_time = absolute_offset + end
            current_offset = end

            text = c['text'].strip()
            if text:
                normalized.append({"start": start_time, "end": end_time, "text": text})

        return normalized