hoanglinhn0 commited on
Commit
1a2af10
·
verified ·
1 Parent(s): ed372e9

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +57 -55
model.py CHANGED
@@ -1,72 +1,74 @@
1
- # model.py - Fixed for latest sherpa-onnx API structure
2
  from huggingface_hub import hf_hub_download
3
  import sherpa_onnx
4
 
5
- def get_pretrained_model(repo_id: str, decoding_method: str, num_active_paths: int):
6
- # 1. Tải các file cần thiết
7
- tokens = hf_hub_download(repo_id, "tokens.txt")
8
-
9
- # Xác định tên file dựa trên repo (VN hoặc mặc định)
10
- if "vi-2025-04-20" in repo_id:
11
- # Model Tiếng Việt MỚI
12
- encoder = hf_hub_download(repo_id, "encoder-epoch-12-avg-8.onnx")
13
- decoder = hf_hub_download(repo_id, "decoder-epoch-12-avg-8.onnx")
14
- joiner = hf_hub_download(repo_id, "joiner-epoch-12-avg-8.onnx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  else:
16
- # Model Tiếng Anh/Trung/Khác
17
  try:
18
- encoder = hf_hub_download(repo_id, "encoder-epoch-99-avg-1.int8.onnx")
19
- decoder = hf_hub_download(repo_id, "decoder-epoch-99-avg-1.int8.onnx")
20
- joiner = hf_hub_download(repo_id, "joiner-epoch-99-avg-1.int8.onnx")
21
  except:
22
- encoder = hf_hub_download(repo_id, "encoder-epoch-99-avg-1.onnx")
23
- decoder = hf_hub_download(repo_id, "decoder-epoch-99-avg-1.onnx")
24
- joiner = hf_hub_download(repo_id, "joiner-epoch-99-avg-1.onnx")
25
-
26
- # 2. Cấu hình Transducer (Lớp trong cùng)
27
- transducer_config = sherpa_onnx.OfflineTransducerModelConfig(
28
- encoder_filename=encoder,
29
- decoder_filename=decoder,
30
- joiner_filename=joiner,
31
- )
32
 
33
- # 3. Cấu hình Model (Lớp giữa - chứa tokens và transducer)
34
- model_config = sherpa_onnx.OfflineModelConfig(
35
- transducer=transducer_config,
 
 
36
  tokens=tokens,
37
- num_threads=1,
 
 
 
 
38
  debug=False,
39
- model_type="zipformer",
40
  )
41
 
42
- # 4. Cấu hình Recognizer (Lớp ngoài cùng)
43
- # Lưu ý: 'num_active_paths' đổi thành 'max_active_paths'
44
- recognizer_config = sherpa_onnx.OfflineRecognizerConfig(
45
- model_config=model_config,
46
- decoding_method=decoding_method,
47
- max_active_paths=num_active_paths,
48
- )
49
 
50
- return sherpa_onnx.OfflineRecognizer(recognizer_config)
51
 
52
- def decode(recognizer, filename: str):
 
53
  s = recognizer.create_stream()
54
- s.accept_wave_file(filename)
55
  recognizer.decode_stream(s)
56
- return s.result.text
57
 
58
- def get_punct_model():
59
- return None
60
 
61
- # --- DANH SÁCH MODEL ---
62
- language_to_models = {
63
- "Vietnamese": [
64
- "csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20",
65
- ],
66
- "English": [
67
- "k2-fsa/sherpa-onnx-zipformer-en-2023-06-26",
68
- ],
69
- "Chinese": [
70
- "k2-fsa/sherpa-onnx-zipformer-zh-14m-2023-02-23",
71
- ],
72
- }
 
1
+ # model.py - Fixed for latest sherpa-onnx (2025+) + Vietnamese model 2025-04-20
2
  from huggingface_hub import hf_hub_download
3
  import sherpa_onnx
4
 
5
+ # Danh sách model
6
+ language_to_models = {
7
+ "Vietnamese": [
8
+ "csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20",
9
+ # Bạn thể thêm "csukuangfj/sherpa-onnx-zipformer-vi-int8-2025-04-20" nếu muốn int8 nhỏ hơn
10
+ ],
11
+ "English": [
12
+ "k2-fsa/sherpa-onnx-zipformer-en-2023-06-26",
13
+ ],
14
+ "Chinese": [
15
+ "k2-fsa/sherpa-onnx-zipformer-zh-14m-2023-02-23",
16
+ ],
17
+ }
18
+
19
+
20
+ def get_pretrained_model(
21
+ repo_id: str,
22
+ decoding_method: str = "modified_beam_search",
23
+ num_active_paths: int = 4,
24
+ ):
25
+ print(f"🔄 Đang tải model từ {repo_id}...")
26
+
27
+ # Tải tokens.txt
28
+ tokens = hf_hub_download(repo_id=repo_id, filename="tokens.txt")
29
+
30
+ # Xác định tên file encoder/decoder/joiner
31
+ if "vi-2025-04-20" in repo_id.lower():
32
+ # Model Vietnamese mới (csukuangfj/sherpa-onnx-zipformer-vi-2025-04-20)
33
+ encoder = hf_hub_download(repo_id=repo_id, filename="encoder-epoch-12-avg-8.onnx")
34
+ decoder = hf_hub_download(repo_id=repo_id, filename="decoder-epoch-12-avg-8.onnx")
35
+ joiner = hf_hub_download(repo_id=repo_id, filename="joiner-epoch-12-avg-8.onnx")
36
  else:
37
+ # Model khác (thường dùng tên cũ)
38
  try:
39
+ encoder = hf_hub_download(repo_id=repo_id, filename="encoder-epoch-99-avg-1.int8.onnx")
40
+ decoder = hf_hub_download(repo_id=repo_id, filename="decoder-epoch-99-avg-1.int8.onnx")
41
+ joiner = hf_hub_download(repo_id=repo_id, filename="joiner-epoch-99-avg-1.int8.onnx")
42
  except:
43
+ encoder = hf_hub_download(repo_id=repo_id, filename="encoder-epoch-99-avg-1.onnx")
44
+ decoder = hf_hub_download(repo_id=repo_id, filename="decoder-epoch-99-avg-1.onnx")
45
+ joiner = hf_hub_download(repo_id=repo_id, filename="joiner-epoch-99-avg-1.onnx")
 
 
 
 
 
 
 
46
 
47
+ # Tạo recognizer bằng factory method (cách mới nhất)
48
+ recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
49
+ encoder=encoder,
50
+ decoder=decoder,
51
+ joiner=joiner,
52
  tokens=tokens,
53
+ num_threads=2,
54
+ sample_rate=16000,
55
+ feature_dim=80,
56
+ decoding_method=decoding_method,
57
+ max_active_paths=num_active_paths,
58
  debug=False,
 
59
  )
60
 
61
+ print(f"✅ Model {repo_id} đã tải xong!")
62
+ return recognizer
 
 
 
 
 
63
 
 
64
 
65
+ def decode(recognizer, filename: str) -> str:
66
+ """Decode một file WAV (bất kỳ sample rate nào)"""
67
  s = recognizer.create_stream()
68
+ s.accept_wave_file(filename) # Tự động resample về 16kHz
69
  recognizer.decode_stream(s)
70
+ return s.result.text.strip()
71
 
 
 
72
 
73
+ def get_punct_model():
74
+ return None