# model.py - Fixed for latest sherpa-onnx (2025+) + Vietnamese model 2025-04-20
"""Download sherpa-onnx transducer models from the Hugging Face Hub and decode audio."""

from huggingface_hub import hf_hub_download
import sherpa_onnx

# Available models, keyed by language name.
language_to_models = {
    "Vietnamese": [
        "csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20",
        # You can add "csukuangfj/sherpa-onnx-zipformer-vi-int8-2025-04-20"
        # if you want the smaller int8 variant.
        # NOTE(review): that repo id does not contain the substring
        # "vi-2025-04-20", so get_pretrained_model() would treat it as a
        # "legacy" repo and look for epoch-99 file names -- confirm the file
        # names in that repo before enabling it.
    ],
    "English": [
        "k2-fsa/sherpa-onnx-zipformer-en-2023-06-26",
    ],
    "Chinese": [
        "k2-fsa/sherpa-onnx-zipformer-zh-14m-2023-02-23",
    ],
}


def _download_transducer_files(repo_id: str, suffix: str) -> tuple:
    """Download the encoder, decoder and joiner files that share *suffix*.

    Files are fetched in encoder, decoder, joiner order; the first failing
    download raises and the remaining ones are not attempted.

    Args:
        repo_id: Hugging Face Hub repository id.
        suffix: Common file-name tail, e.g. ``"epoch-99-avg-1.onnx"``; the
            requested file names are ``"<part>-<suffix>"``.

    Returns:
        A ``(encoder, decoder, joiner)`` tuple of the values returned by
        ``hf_hub_download``.
    """
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=f"{part}-{suffix}")
        for part in ("encoder", "decoder", "joiner")
    )


def get_pretrained_model(
    repo_id: str,
    decoding_method: str = "modified_beam_search",
    num_active_paths: int = 4,
):
    """Download a transducer model from the Hub and build an offline recognizer.

    Args:
        repo_id: Hugging Face Hub repository id of the model.
        decoding_method: Passed through as ``decoding_method`` to
            ``OfflineRecognizer.from_transducer``.
        num_active_paths: Passed through as ``max_active_paths`` to
            ``OfflineRecognizer.from_transducer``.

    Returns:
        The recognizer created by ``sherpa_onnx.OfflineRecognizer.from_transducer``.

    Raises:
        Exception: Whatever ``hf_hub_download`` or sherpa-onnx raises when a
            required file cannot be fetched or the model cannot be loaded.
    """
    print(f"🔄 Đang tải model từ {repo_id}...")

    # Download tokens.txt first, as every repo is expected to provide it.
    tokens = hf_hub_download(repo_id=repo_id, filename="tokens.txt")

    # Work out the encoder/decoder/joiner file names for this repo.
    if "vi-2025-04-20" in repo_id.lower():
        # New Vietnamese model (csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20).
        encoder, decoder, joiner = _download_transducer_files(
            repo_id, "epoch-12-avg-8.onnx"
        )
    else:
        # Other models (usually use the older file names): prefer the int8
        # files and fall back to the float files if any int8 download fails.
        try:
            encoder, decoder, joiner = _download_transducer_files(
                repo_id, "epoch-99-avg-1.int8.onnx"
            )
        except Exception:
            # Deliberately broad so the best-effort fallback is kept, but no
            # longer a bare ``except:`` that would also swallow
            # KeyboardInterrupt / SystemExit during a long download.
            encoder, decoder, joiner = _download_transducer_files(
                repo_id, "epoch-99-avg-1.onnx"
            )

    # Build the recognizer through the factory method.
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=decoding_method,
        max_active_paths=num_active_paths,
        debug=False,
    )

    print(f"✅ Model {repo_id} đã tải xong!")
    return recognizer


def decode(recognizer, filename: str) -> str:
    """Decode one WAV file and return the recognized text, stripped.

    Args:
        recognizer: Object returned by ``get_pretrained_model``.
        filename: Path of the WAV file to transcribe.

    Returns:
        The stream's result text with leading/trailing whitespace removed.
    """
    s = recognizer.create_stream()
    # Original author's note: accepts any sample rate and resamples to 16 kHz.
    # NOTE(review): confirm that the installed sherpa-onnx stream type exposes
    # ``accept_wave_file`` and performs that resampling.
    s.accept_wave_file(filename)
    recognizer.decode_stream(s)
    return s.result.text.strip()


def get_punct_model() -> None:
    """Return the punctuation model; currently always ``None``."""
    return None