cronos3k commited on
Commit
fd56ec5
·
verified ·
1 Parent(s): e7eeec3

Upload whisper_helper.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. whisper_helper.py +183 -0
whisper_helper.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Whisper STT helper for LongCat-AudioDiT.
3
+
4
+ Supports:
5
+ - faster-whisper backend (CTranslate2, recommended)
6
+ - Model variants: large-v3-turbo ("turbo"), large-v3 ("large-v3")
7
+
8
+ Usage:
9
+ helper = WhisperHelper(model_size="turbo", device="cuda")
10
+ text, language = helper.transcribe("audio.wav")
11
+ helper.unload()
12
+ """
13
+
14
+ import gc
15
+ import logging
16
+ from pathlib import Path
17
+ from typing import Optional, Tuple
18
+
19
+ import torch
20
+
21
logger = logging.getLogger(__name__)

# Approximate peak VRAM/RAM footprint in GB per model size (fp16 on GPU /
# int8 on CPU). Callers can use this to budget GPU memory before load().
WHISPER_VRAM_MAP = {
    "turbo": 1.6,  # large-v3-turbo
    "large-v3": 3.0,
    "medium": 1.5,
    "small": 0.5,
    "base": 0.3,
}

# HuggingFace repo IDs of CTranslate2 conversions consumed by faster-whisper.
FASTER_WHISPER_MODELS = {
    "turbo": "deepdml/faster-whisper-large-v3-turbo-ct2",
    "large-v3": "Systran/faster-whisper-large-v3",
    "medium": "Systran/faster-whisper-medium",
    "small": "Systran/faster-whisper-small",
    "base": "Systran/faster-whisper-base",
}


class WhisperHelper:
    """Thin wrapper around faster-whisper for on-demand STT.

    The model is loaded lazily (``load()`` or ``transcribe(auto_load=True)``)
    and can be evicted with ``unload()`` to free VRAM between uses.
    """

    def __init__(
        self,
        model_size: str = "turbo",
        device: str = "auto",
        compute_type: str = "auto",
        download_root: Optional[str] = None,
    ):
        """
        Args:
            model_size: "turbo", "large-v3", "medium", "small", "base".
                Other values are passed through to faster-whisper verbatim
                (e.g. a HuggingFace repo ID or a local path); the VRAM
                estimate then falls back to 3.0 GB.
            device: "auto", "cuda", "cpu". "auto" picks CUDA when available.
            compute_type: "auto", "float16", "int8_float16", "int8".
                "auto" resolves to float16 on CUDA, int8 on CPU.
            download_root: where to cache models (defaults to
                ./models/whisper/ next to this file).
        """
        self.model_size = model_size
        self._model = None  # faster_whisper.WhisperModel once loaded
        self._is_loaded = False

        # Resolve device
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Resolve compute type: float16 needs a GPU; int8 is the CPU default.
        if compute_type == "auto":
            if self.device == "cuda":
                self.compute_type = "float16"
            else:
                self.compute_type = "int8"
        else:
            self.compute_type = compute_type

        # Download root: always local to project, never user dirs
        if download_root is None:
            self.download_root = str(Path(__file__).parent / "models" / "whisper")
        else:
            self.download_root = download_root

        Path(self.download_root).mkdir(parents=True, exist_ok=True)

    @property
    def is_loaded(self) -> bool:
        """True once ``load()`` has succeeded and the model is resident."""
        return self._is_loaded and self._model is not None

    @property
    def vram_estimate_gb(self) -> float:
        """Approximate memory footprint in GB (3.0 for unknown model sizes)."""
        return WHISPER_VRAM_MAP.get(self.model_size, 3.0)

    def load(self) -> None:
        """Load Whisper model into VRAM/RAM. No-op if already loaded.

        Raises:
            ImportError: if faster-whisper is not installed (the original
                import failure is chained as ``__cause__``).
        """
        if self.is_loaded:
            return
        try:
            from faster_whisper import WhisperModel
        except ImportError as err:
            # Fix: chain the real import failure instead of discarding it
            # (bare re-raise inside except loses the underlying cause).
            raise ImportError(
                "faster-whisper is not installed. Run: pip install faster-whisper"
            ) from err

        # Known sizes map to curated CT2 repos; anything else is passed
        # through untouched so callers can supply custom IDs or local paths.
        model_id = FASTER_WHISPER_MODELS.get(self.model_size, self.model_size)
        logger.info(
            "Loading Whisper %s on %s (%s) from %s",
            self.model_size, self.device, self.compute_type, model_id,
        )
        self._model = WhisperModel(
            model_id,
            device=self.device,
            compute_type=self.compute_type,
            download_root=self.download_root,
        )
        self._is_loaded = True
        logger.info("Whisper %s loaded.", self.model_size)

    def unload(self) -> None:
        """Release VRAM/RAM used by the model. No-op if not loaded."""
        if not self.is_loaded:
            return
        # Drop the only reference (the prior `del` before reassignment was
        # redundant), then collect so CTranslate2 buffers are actually freed
        # before flushing the CUDA caching allocator.
        self._model = None
        self._is_loaded = False
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Whisper %s unloaded.", self.model_size)

    def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        vad_filter: bool = True,
        auto_load: bool = True,
    ) -> Tuple[str, str]:
        """
        Transcribe an audio file.

        Args:
            audio_path: path to audio file (wav, mp3, flac, …)
            language: ISO 639-1 code ("en", "zh", …) or None for auto-detect
            task: "transcribe" or "translate" (translate → English)
            beam_size: beam search width
            vad_filter: apply voice activity detection filter
            auto_load: load model if not already loaded

        Returns:
            (transcription_text, detected_language)

        Raises:
            RuntimeError: if the model is not loaded and ``auto_load`` is False.
        """
        if not self.is_loaded:
            if auto_load:
                self.load()
            else:
                raise RuntimeError("Whisper model not loaded. Call load() first.")

        # `segments` is a lazy generator; consuming it below drives decoding.
        segments, info = self._model.transcribe(
            audio_path,
            language=language,
            task=task,
            beam_size=beam_size,
            vad_filter=vad_filter,
        )

        # NOTE(review): segment texts from faster-whisper often carry a
        # leading space, so " ".join may yield double spaces between
        # segments — kept as-is to preserve existing output.
        text_parts = [seg.text for seg in segments]
        full_text = " ".join(text_parts).strip()
        detected_lang = info.language

        logger.info(
            "Transcribed %s: '%s...' [lang=%s, prob=%.2f]",
            Path(audio_path).name,
            full_text[:60],
            detected_lang,
            info.language_probability,
        )
        return full_text, detected_lang

    def __repr__(self) -> str:
        status = "loaded" if self.is_loaded else "unloaded"
        return f"WhisperHelper(size={self.model_size}, device={self.device}, {status})"