# CUT-DATA / model.py
# hoanglinhn0's picture
# Update model.py
# 1a2af10 verified
# model.py - Fixed for latest sherpa-onnx (2025+) + Vietnamese model 2025-04-20
from huggingface_hub import hf_hub_download
import sherpa_onnx
# Model registry: maps a display language to the Hugging Face repo ids
# that can be passed to get_pretrained_model().
language_to_models = {
    "Vietnamese": [
        "csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20",
        # Optionally add "csukuangfj/sherpa-onnx-zipformer-vi-int8-2025-04-20"
        # here for a smaller int8 variant.
    ],
    "English": ["k2-fsa/sherpa-onnx-zipformer-en-2023-06-26"],
    "Chinese": ["k2-fsa/sherpa-onnx-zipformer-zh-14m-2023-02-23"],
}
def get_pretrained_model(
    repo_id: str,
    decoding_method: str = "modified_beam_search",
    num_active_paths: int = 4,
) -> "sherpa_onnx.OfflineRecognizer":
    """Download a transducer model from the Hugging Face Hub and build a recognizer.

    The function fetches ``tokens.txt`` plus the encoder/decoder/joiner ONNX
    files from ``repo_id`` and passes them to
    ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    File-name selection:
      * repos whose id contains ``vi-2025-04-20`` use the
        ``*-epoch-12-avg-8.onnx`` files;
      * every other repo tries the ``*-epoch-99-avg-1.int8.onnx`` files first
        and falls back to ``*-epoch-99-avg-1.onnx`` when an int8 file is
        missing from the repo.

    Args:
        repo_id: Hugging Face repository id of the model.
        decoding_method: Decoding method forwarded to the recognizer.
        num_active_paths: Forwarded as ``max_active_paths`` to the recognizer.

    Returns:
        The recognizer created by ``OfflineRecognizer.from_transducer``.

    Raises:
        huggingface_hub.utils.EntryNotFoundError: If neither the int8 nor the
            float model files exist in the repo. Other download errors
            (network, authentication, unknown repo) propagate unchanged
            instead of being masked by the fallback.
    """
    # Imported locally so this block does not depend on a file-level import.
    from huggingface_hub.utils import EntryNotFoundError

    print(f"🔄 Đang tải model từ {repo_id}...")

    tokens = hf_hub_download(repo_id=repo_id, filename="tokens.txt")

    def fetch_parts(stem: str):
        """Download encoder/decoder/joiner files named ``<part>-<stem>.onnx``."""
        return tuple(
            hf_hub_download(repo_id=repo_id, filename=f"{part}-{stem}.onnx")
            for part in ("encoder", "decoder", "joiner")
        )

    if "vi-2025-04-20" in repo_id.lower():
        # Newer Vietnamese model (csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20).
        encoder, decoder, joiner = fetch_parts("epoch-12-avg-8")
    else:
        # Other models: prefer the int8 files, fall back to the float ones.
        # Only a file missing from the repo triggers the fallback; a bare
        # ``except`` here would also swallow KeyboardInterrupt and hide real
        # network or authentication failures.
        try:
            encoder, decoder, joiner = fetch_parts("epoch-99-avg-1.int8")
        except EntryNotFoundError:
            encoder, decoder, joiner = fetch_parts("epoch-99-avg-1")

    # Build the recognizer through the factory method.
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=decoding_method,
        max_active_paths=num_active_paths,
        debug=False,
    )

    print(f"✅ Model {repo_id} đã tải xong!")
    return recognizer
def _read_wave_pcm16(filename: str):
    """Read a 16-bit PCM WAV file as normalized mono float samples.

    Args:
        filename: Path to the WAV file.

    Returns:
        A ``(sample_rate, samples)`` tuple; ``samples`` is a list of floats
        in ``[-1.0, 1.0)``. Multi-channel audio is averaged down to mono.

    Raises:
        ValueError: If the file does not use 16-bit samples.
    """
    # Local imports keep this block self-contained.
    import struct
    import wave

    with wave.open(filename, "rb") as wav:
        channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())

    if sample_width != 2:
        raise ValueError(
            f"Only 16-bit PCM WAV is supported, got {sample_width * 8}-bit: {filename}"
        )

    # WAV data is little-endian; "<h" decodes it correctly on any host.
    count = len(frames) // 2
    pcm = struct.unpack(f"<{count}h", frames[: count * 2])

    if channels > 1:
        # Samples are interleaved per frame; average each frame to mono.
        pcm = [
            sum(pcm[start:start + channels]) / channels
            for start in range(0, len(pcm) - channels + 1, channels)
        ]

    return sample_rate, [value / 32768.0 for value in pcm]


def decode(recognizer, filename: str) -> str:
    """Transcribe one WAV file and return the stripped text.

    Args:
        recognizer: Object providing ``create_stream()`` and
            ``decode_stream(stream)`` (e.g. a sherpa-onnx offline recognizer).
        filename: Path to the WAV file to transcribe.

    Returns:
        The recognized text with surrounding whitespace removed.

    Raises:
        ValueError: In the fallback path, if the file is not 16-bit PCM.
    """
    stream = recognizer.create_stream()

    if hasattr(stream, "accept_wave_file"):
        # Keep the original call when the binding provides it.
        stream.accept_wave_file(filename)
    else:
        # Fall back to accept_waveform, passing the file's own sample rate.
        # NOTE(review): sherpa-onnx documents that accept_waveform resamples
        # internally when the rate differs from the model's — confirm for the
        # installed version.
        sample_rate, samples = _read_wave_pcm16(filename)
        stream.accept_waveform(sample_rate, samples)

    recognizer.decode_stream(stream)
    return stream.result.text.strip()
def get_punct_model() -> None:
    """Return the punctuation model; this implementation always returns ``None``."""
    punct_model = None
    return punct_model