# translator/stt_module.py — speech-to-text wrapper around faster-whisper
from faster_whisper import WhisperModel
import os
class STTManager:
    """Thin wrapper around faster-whisper: lazy model loading, transcription,
    CUDA detection, and discovery of locally downloaded models."""

    def __init__(self, model_size="base", device="cpu", compute_type="int8"):
        """
        Args:
            model_size: tiny, base, small, medium, large-v3
            device: "cpu" or "cuda"
            compute_type: int8, float16, int8_float16, etc.
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self.model = None  # created on first use by load_model()

    def load_model(self):
        """Lazily load the model so the caller can show a hint beforehand."""
        if self.model is None:
            # float16 is not supported on CPU; fall back to int8 automatically.
            stt_compute_type = self.compute_type
            if self.device == "cpu" and stt_compute_type == "float16":
                stt_compute_type = "int8"
            self.model = WhisperModel(
                self.model_size,
                device=self.device,
                compute_type=stt_compute_type,
                download_root=os.path.join(os.getcwd(), "models"),
            )
        return self.model

    def transcribe(self, audio_path):
        """Transcribe the audio file; returns (segments generator, info)."""
        model = self.load_model()
        segments, info = model.transcribe(audio_path, beam_size=5)
        return segments, info

    @staticmethod
    def is_cuda_available():
        """Return True if ctranslate2 reports at least one CUDA device."""
        try:
            import ctranslate2
            return ctranslate2.get_cuda_device_count() > 0
        except Exception:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # are no longer swallowed; ImportError etc. still mean "no CUDA".
            return False

    @staticmethod
    def get_downloaded_models():
        """Return the sorted list of model sizes already downloaded to ./models."""
        model_dir = os.path.join(os.getcwd(), "models")
        if not os.path.exists(model_dir):
            return []
        candidates = []
        for entry in os.listdir(model_dir):
            if "faster-whisper-" in entry:
                # BUG FIX: split on the full "faster-whisper-" prefix instead of
                # every "-". The old split("-")[-1] turned "faster-whisper-large-v3"
                # into "v3", which matched no valid name, so large-v* downloads
                # were silently dropped from the list.
                candidates.append(entry.split("faster-whisper-", 1)[-1])
            elif os.path.isdir(os.path.join(model_dir, entry)):
                candidates.append(entry)
        valid_names = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]
        # Map each candidate directory name to the first valid size it contains.
        final_models = []
        for cand in candidates:
            for name in valid_names:
                if name in cand:
                    final_models.append(name)
                    break
        return sorted(set(final_models))