Spaces:
Running
Running
| from faster_whisper import WhisperModel | |
| import os | |
class STTManager:
    """Thin wrapper around faster-whisper that lazily loads a WhisperModel."""

    def __init__(self, model_size="base", device="cpu", compute_type="int8"):
        """
        model_size: tiny, base, small, medium, large-v3
        device: "cpu" or "cuda"
        compute_type: e.g. int8, float16, int8_float16
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        # Model creation is deferred to load_model() so construction stays
        # cheap and callers can show a progress hint before the load.
        self.model = None

    def load_model(self):
        """Load the model on first use and return it (cached afterwards)."""
        if self.model is not None:
            return self.model
        # float16 is not usable on the CPU backend; fall back to int8.
        compute = self.compute_type
        if self.device == "cpu" and compute == "float16":
            compute = "int8"
        self.model = WhisperModel(
            self.model_size,
            device=self.device,
            compute_type=compute,
            download_root=os.path.join(os.getcwd(), "models"),
        )
        return self.model

    def transcribe(self, audio_path):
        """Transcribe the audio file; returns (segments generator, info)."""
        return self.load_model().transcribe(audio_path, beam_size=5)
def is_cuda_available():
    """Return True if ctranslate2 reports at least one CUDA device.

    ctranslate2 is an optional dependency of the faster-whisper backend;
    if it is missing or the CUDA probe fails for any reason we
    conservatively report False.
    """
    try:
        import ctranslate2
        return ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # SystemExit and KeyboardInterrupt.
        return False
def get_downloaded_models():
    """Return the sorted, de-duplicated list of model sizes found locally.

    Scans <cwd>/models for directories produced by a faster-whisper
    download (e.g. "faster-whisper-base", or HF-cache style names such as
    "models--Systran--faster-whisper-large-v3") as well as plain
    directories named after a size, and maps each to its canonical size
    name from valid_names. Returns [] when the models directory is absent.
    """
    model_dir = os.path.join(os.getcwd(), "models")
    if not os.path.exists(model_dir):
        return []
    valid_names = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
    found = set()
    for entry in os.listdir(model_dir):
        if "faster-whisper-" in entry:
            # BUG FIX: the old code used entry.split("-")[-1], which turned
            # "faster-whisper-large-v3" into "v3" and then dropped it as
            # invalid, so large-v* models never appeared in the result.
            candidate = entry.split("faster-whisper-")[-1]
        elif os.path.isdir(os.path.join(model_dir, entry)):
            candidate = entry
        else:
            continue
        for v in valid_names:
            if v in candidate:
                found.add(v)
                break
    return sorted(found)