Spaces:
Sleeping
Sleeping
| # model.py - Fixed for latest sherpa-onnx (2025+) + Vietnamese model 2025-04-20 | |
| from huggingface_hub import hf_hub_download | |
| import sherpa_onnx | |
# Available models: display language -> list of Hugging Face Hub repo ids.
# You can also add "csukuangfj/sherpa-onnx-zipformer-vi-int8-2025-04-20" to the
# Vietnamese list if you want the smaller int8 variant.
# NOTE(review): get_pretrained_model() selects file names by looking for the
# substring "vi-2025-04-20", which that int8 repo id does not contain --
# confirm the file names in that repo before enabling it.
language_to_models = dict(
    Vietnamese=["csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20"],
    English=["k2-fsa/sherpa-onnx-zipformer-en-2023-06-26"],
    Chinese=["k2-fsa/sherpa-onnx-zipformer-zh-14m-2023-02-23"],
)
def _download_transducer_files(repo_id: str, suffix: str):
    """Download the encoder, decoder and joiner files that share ``suffix``.

    Returns an ``(encoder, decoder, joiner)`` tuple of the local paths returned
    by ``hf_hub_download``. Any download error propagates to the caller.
    """
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=f"{part}-{suffix}")
        for part in ("encoder", "decoder", "joiner")
    )


def get_pretrained_model(
    repo_id: str,
    decoding_method: str = "modified_beam_search",
    num_active_paths: int = 4,
):
    """Download a transducer model from the Hugging Face Hub and build a recognizer.

    Args:
        repo_id: Hub repository id that holds ``tokens.txt`` and the
            encoder/decoder/joiner ONNX files.
        decoding_method: Passed through as ``decoding_method`` to
            ``sherpa_onnx.OfflineRecognizer.from_transducer``.
        num_active_paths: Passed through as ``max_active_paths``.

    Returns:
        The recognizer created by ``OfflineRecognizer.from_transducer``.

    Raises:
        Exception: Whatever ``hf_hub_download`` raises when ``tokens.txt`` or
            the last candidate set of model files cannot be downloaded.
    """
    print(f"🔄 Đang tải model từ {repo_id}...")

    # Download tokens.txt.
    tokens = hf_hub_download(repo_id=repo_id, filename="tokens.txt")

    # Work out which encoder/decoder/joiner file names to try, in order.
    if "vi-2025-04-20" in repo_id.lower():
        # New Vietnamese model (csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20).
        suffixes = ("epoch-12-avg-8.onnx",)
    else:
        # Other models (usually the older names): prefer int8, then fp32.
        suffixes = ("epoch-99-avg-1.int8.onnx", "epoch-99-avg-1.onnx")

    *preferred, last_resort = suffixes
    for suffix in preferred:
        try:
            encoder, decoder, joiner = _download_transducer_files(repo_id, suffix)
            break
        except Exception:
            # Best-effort fallback to the next file-name variant. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
            continue
    else:
        # No preferred variant worked (or there was none): errors from the
        # final candidate are allowed to propagate to the caller.
        encoder, decoder, joiner = _download_transducer_files(repo_id, last_resort)

    # Build the recognizer through the factory method.
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=decoding_method,
        max_active_paths=num_active_paths,
        debug=False,
    )
    print(f"✅ Model {repo_id} đã tải xong!")
    return recognizer
def decode(recognizer, filename: str) -> str:
    """Transcribe one WAV file and return the recognized text, stripped.

    Args:
        recognizer: Object providing ``create_stream()`` and
            ``decode_stream(stream)``, e.g. the recognizer returned by
            ``get_pretrained_model``.
        filename: Path of the WAV file to decode.

    Returns:
        ``stream.result.text`` with surrounding whitespace removed.

    Raises:
        ValueError: In the ``accept_waveform`` path, if the file is not
            16-bit PCM.
        wave.Error: In the ``accept_waveform`` path, if the file is not a WAV
            file the stdlib ``wave`` module can read.
    """
    stream = recognizer.create_stream()

    if hasattr(stream, "accept_wave_file"):
        # Keep the original call for stream types that provide it.
        stream.accept_wave_file(filename)
    else:
        # NOTE(review): sherpa-onnx offline streams appear to expose only
        # accept_waveform(sample_rate, samples) -- confirm against the
        # installed version. Read the file ourselves and feed float samples.
        import wave

        import numpy as np

        with wave.open(filename, "rb") as wav:
            if wav.getsampwidth() != 2:
                raise ValueError(
                    f"{filename}: only 16-bit PCM WAV is supported, "
                    f"got {8 * wav.getsampwidth()}-bit samples"
                )
            sample_rate = wav.getframerate()
            n_channels = wav.getnchannels()
            frames = wav.readframes(wav.getnframes())

        # Scale int16 PCM to float32 in [-1, 1).
        samples = np.frombuffer(frames, dtype="<i2").astype(np.float32) / 32768.0
        if n_channels > 1:
            # Down-mix interleaved channels to mono by averaging.
            samples = samples.reshape(-1, n_channels).mean(axis=1)

        # The file's own sample rate is passed along; any resampling to the
        # model's rate is left to the recognizer (presumably done internally).
        stream.accept_waveform(sample_rate, samples)

    recognizer.decode_stream(stream)
    return stream.result.text.strip()
def get_punct_model() -> None:
    """Placeholder for a punctuation model (going by the name); always returns ``None``."""
    return None