HK0712 committed on
Commit
e75d71c
·
1 Parent(s): 76b1f2f

improved cantonese version v2

Browse files
Files changed (1) hide show
  1. analyzer/ASR_zh_hk.py +60 -70
analyzer/ASR_zh_hk.py CHANGED
@@ -14,64 +14,48 @@ print(f"INFO: ASR_zh_hk.py is configured to use device: {DEVICE}")
14
 
15
  MODEL_NAME = "HK0712/Wav2Vec2_Cantonese"
16
 
17
- # --- 1. 輔助函數:粵拼智慧切分器 (Linguistic Split) ---
18
  def _tokenize_jyutping_smart(jyutping_str: str) -> list:
19
  """
20
  將單個粵拼音節 (如 'gwong2') 根據聲韻學結構切分為 token。
21
  Target: 'gwong2' -> ['gw', 'o', 'ng', '2']
22
- 這樣前端顯示時會是 "gw o ng 2",比 "g w o n g 2" 易讀得多。
23
  """
24
  try:
25
- # pycantonese.parse_jyutping 回傳的是一個列表,包含 Jyutping 物件
26
- # 例如: parse_jyutping('gwong2') -> [Jyutping(onset='gw', nucleus='o', coda='ng', tone='2')]
27
  parsed = pycantonese.parse_jyutping(jyutping_str)
28
-
29
  tokens = []
30
  for jp in parsed:
31
  if jp.onset: tokens.append(jp.onset)
32
  if jp.nucleus: tokens.append(jp.nucleus)
33
  if jp.coda: tokens.append(jp.coda)
34
  if jp.tone: tokens.append(jp.tone)
35
-
36
  return tokens
37
  except:
38
- # 萬一解析失敗(例如模型輸出的拼音不標準),回退到簡單切分
39
- # 但保留數字作為獨立 token
40
  return re.findall(r'[a-z]+|[0-9]', jyutping_str)
41
 
42
- # --- 2. 智慧 G2P 歸屬邏輯 (中文版) ---
43
  def _get_target_jyutping_by_char(sentence: str) -> (list, list):
44
  """
45
  將中文句子轉換為「字」級別的粵拼目標。
46
  """
47
- # pycantonese.characters_to_jyutping 會處理變調與分詞
48
- # 範例: "廣東話" -> [('廣東話', 'gwong2dung1waa2')]
49
  segmented_result = pycantonese.characters_to_jyutping(sentence)
50
 
51
  original_chars_flat = []
52
  target_jyutping_groups = []
53
-
54
- # 簡單的正則表達式,用來把連在一起的拼音分開 (e.g. 'gwong2dung1' -> 'gwong2', 'dung1')
55
  jyutping_syllable_pattern = re.compile(r'([a-z]+[1-6])')
56
 
57
  for word_segment, jyutping_segment in segmented_result:
58
- if not jyutping_segment:
59
- continue
60
 
61
  syllables = jyutping_syllable_pattern.findall(jyutping_segment)
62
 
63
- # 嘗試將分詞後的結果對齊回單個漢字
64
  if len(word_segment) == len(syllables):
65
  for char, syl in zip(word_segment, syllables):
66
  original_chars_flat.append(char)
67
- # 使用智慧切分:'gwong2' -> ['gw', 'o', 'ng', '2']
68
  target_jyutping_groups.append(_tokenize_jyutping_smart(syl))
69
  else:
70
- # 長度不匹配時的備用方案 (逐字處理)
71
  print(f"WARNING: Mismatch length for {word_segment}. Fallback to char-by-char G2P.")
72
  for char in word_segment:
73
  original_chars_flat.append(char)
74
- # 對單字再做一次 G2P
75
  single_res = pycantonese.characters_to_jyutping(char)
76
  if single_res and single_res[0][1]:
77
  target_jyutping_groups.append(_tokenize_jyutping_smart(single_res[0][1]))
@@ -80,7 +64,7 @@ def _get_target_jyutping_by_char(sentence: str) -> (list, list):
80
 
81
  return original_chars_flat, target_jyutping_groups
82
 
83
- # --- 3. 核心分析函數 (主入口) ---
84
  def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dict:
85
  if "model" not in cache:
86
  print(f"Cache miss (ASR_zh_hk). Loading model '{MODEL_NAME}'...")
@@ -119,66 +103,28 @@ def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dic
119
  logits = model(input_values).logits
120
  predicted_ids = torch.argmax(logits, dim=-1)
121
 
122
- # 3. 獲取使用者輸出 (User Output)
123
- # 模型輸出: "gwong2 dung1 waa2" (字串)
124
  raw_output_str = processor.decode(predicted_ids[0])
125
 
126
- # 清理並準備對齊
127
- # 我們需要把用戶的輸出也變成 ['gw', 'o', 'ng', '2', 'd', 'u', 'ng', '1'...] 的流
128
- # 這樣才能跟 Target 的結構對齊
129
-
130
- # 步驟 A: 移除空格,變成連續字串 "gwong2dung1waa2"
131
- # 注意:這一步假設模型輸出的拼音是標準的。如果模型輸出亂碼,tokenize 可能會切得不完美,
132
- # 但 Needleman-Wunsch 算法會處理這些 mismatch,所以沒關係。
133
- user_jyutping_clean = raw_output_str.replace(" ", "")
134
-
135
- # 步驟 B: 使用相同的邏輯切分用戶輸入
136
- # 因為用戶輸入是一長串,我們用正則表達式把 [a-z] 和 [0-9] 分開,或者嘗試 parse
137
- # 這裡用一個簡單的策略:把它當作一連串的 components
138
- # 為了最佳對齊,我們這裡還是用 "Character + Number" 的粒度比較好,
139
- # 因為用戶可能讀錯導致無法形成合法的 onset/nucleus。
140
- #
141
- # ★ 關鍵決策:為了避免用戶讀錯導致 crash,用戶端我們使用較細的粒度 (Regex Split),
142
- # 然後讓對齊算法去匹配 Target 的 "gw", "o", "ng"。
143
- # 等等,如果 Target 是 "gw" (1個token),User 是 "g", "w" (2個 tokens),對齊會錯位。
144
- #
145
- # ★ 修正策略:
146
- # 我們也嘗試用 pycantonese.parse_jyutping 去解析用戶的整句輸出。
147
- # 如果解析成功,我們就用結構化 token。如果失敗(亂讀),回退到字母切分。
148
-
149
  user_tokens = []
150
- # 嘗試把用戶輸出拆成音節 (e.g. "gwong2", "dung1")
151
  user_syllables = re.findall(r'[a-z]+[0-9]', raw_output_str)
152
 
153
  if user_syllables:
154
- # 如果能抓到音節,就用結構化切分
155
  for syl in user_syllables:
156
  user_tokens.extend(_tokenize_jyutping_smart(syl))
157
  else:
158
- # 如果抓不到(例如沒聲調),就退化成字母切分
159
- # 但這會導致跟 Target (gw) 對不上。
160
- # 為了保險,我們這裡對於 Target 也許應該退化成簡單切分?
161
- # 不,Target 是 Ground Truth,應該保持結構。
162
- #
163
- # 最終方案:讓 User stream 盡量 "粘" 在一起。
164
- # 實際上,Wav2Vec2 輸出的通常是標準拼音。我們直接用 smart parse。
165
  user_tokens = _tokenize_jyutping_smart(raw_output_str)
166
 
167
-
168
  # 4. 對齊 (Alignment)
169
  word_alignments = _get_phoneme_alignments_by_word(user_tokens, target_jyutping_by_char)
170
 
171
  return _format_to_json_structure(word_alignments, target_sentence, target_chars)
172
 
173
- # --- 4. 對齊與格式化 (保持原樣或微調) ---
174
- # 這裡的邏輯與之前相同,不需要大改,因為它只是比較兩個 list 的相似度。
175
- # 只要 user_tokens 和 target_jyutping_by_char 的元素 (token) 粒度一致即可。
176
- # ... ( _get_phoneme_alignments_by_word 與 _format_to_json_structure 代碼同上) ...
177
-
178
- # 為了節省篇幅,請使用上一版提供的 _get_phoneme_alignments_by_word 和 _format_to_json_structure
179
- # 只需要替換上面的 _tokenize_jyutping_smart 和 analyze 函數即可。
180
- # 下面我會把完整的 _get_phoneme_alignments_by_word 貼上以確保完整性。
181
 
 
182
  def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
183
  target_phonemes_flat = []
184
  word_boundaries_indices = []
@@ -189,27 +135,70 @@ def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
189
  current_idx += len(word_ipa_tokens)
190
  word_boundaries_indices.append(current_idx - 1)
191
 
192
- # DP Matrix
193
  dp = np.zeros((len(user_phonemes) + 1, len(target_phonemes_flat) + 1))
194
  for i in range(1, len(user_phonemes) + 1): dp[i][0] = i
195
  for j in range(1, len(target_phonemes_flat) + 1): dp[0][j] = j
196
 
 
197
  for i in range(1, len(user_phonemes) + 1):
198
  for j in range(1, len(target_phonemes_flat) + 1):
199
- cost = 0 if user_phonemes[i-1] == target_phonemes_flat[j-1] else 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  dp[i][j] = min(dp[i-1][j] + 1, dp[i][j-1] + 1, dp[i-1][j-1] + cost)
201
 
 
202
  i, j = len(user_phonemes), len(target_phonemes_flat)
203
  user_path, target_path = [], []
204
  while i > 0 or j > 0:
205
- cost = float('inf') if i == 0 or j == 0 else (0 if user_phonemes[i-1] == target_phonemes_flat[j-1] else 1)
206
- if i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + cost:
207
- user_path.insert(0, user_phonemes[i-1]); target_path.insert(0, target_phonemes_flat[j-1]); i -= 1; j -= 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
209
- user_path.insert(0, user_phonemes[i-1]); target_path.insert(0, '-'); i -= 1
 
 
 
210
  else:
211
- user_path.insert(0, '-'); target_path.insert(0, target_phonemes_flat[j-1]); j -= 1
 
 
212
 
 
213
  alignments_by_word = []
214
  word_start_idx_in_path = 0
215
  target_phoneme_counter_in_path = 0
@@ -238,6 +227,7 @@ def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
238
 
239
  return alignments_by_word
240
 
 
241
  def _format_to_json_structure(alignments, sentence, original_words) -> dict:
242
  total_phonemes = 0
243
  total_errors = 0
 
14
 
15
  MODEL_NAME = "HK0712/Wav2Vec2_Cantonese"
16
 
17
+ # --- 1. 輔助函數:粵拼智慧切分器 ---
18
  def _tokenize_jyutping_smart(jyutping_str: str) -> list:
19
  """
20
  將單個粵拼音節 (如 'gwong2') 根據聲韻學結構切分為 token。
21
  Target: 'gwong2' -> ['gw', 'o', 'ng', '2']
 
22
  """
23
  try:
 
 
24
  parsed = pycantonese.parse_jyutping(jyutping_str)
 
25
  tokens = []
26
  for jp in parsed:
27
  if jp.onset: tokens.append(jp.onset)
28
  if jp.nucleus: tokens.append(jp.nucleus)
29
  if jp.coda: tokens.append(jp.coda)
30
  if jp.tone: tokens.append(jp.tone)
 
31
  return tokens
32
  except:
 
 
33
  return re.findall(r'[a-z]+|[0-9]', jyutping_str)
34
 
35
+ # --- 2. 智慧 G2P 歸屬邏輯 ---
36
  def _get_target_jyutping_by_char(sentence: str) -> (list, list):
37
  """
38
  將中文句子轉換為「字」級別的粵拼目標。
39
  """
 
 
40
  segmented_result = pycantonese.characters_to_jyutping(sentence)
41
 
42
  original_chars_flat = []
43
  target_jyutping_groups = []
 
 
44
  jyutping_syllable_pattern = re.compile(r'([a-z]+[1-6])')
45
 
46
  for word_segment, jyutping_segment in segmented_result:
47
+ if not jyutping_segment: continue
 
48
 
49
  syllables = jyutping_syllable_pattern.findall(jyutping_segment)
50
 
 
51
  if len(word_segment) == len(syllables):
52
  for char, syl in zip(word_segment, syllables):
53
  original_chars_flat.append(char)
 
54
  target_jyutping_groups.append(_tokenize_jyutping_smart(syl))
55
  else:
 
56
  print(f"WARNING: Mismatch length for {word_segment}. Fallback to char-by-char G2P.")
57
  for char in word_segment:
58
  original_chars_flat.append(char)
 
59
  single_res = pycantonese.characters_to_jyutping(char)
60
  if single_res and single_res[0][1]:
61
  target_jyutping_groups.append(_tokenize_jyutping_smart(single_res[0][1]))
 
64
 
65
  return original_chars_flat, target_jyutping_groups
66
 
67
+ # --- 3. 核心分析函數 ---
68
  def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dict:
69
  if "model" not in cache:
70
  print(f"Cache miss (ASR_zh_hk). Loading model '{MODEL_NAME}'...")
 
103
  logits = model(input_values).logits
104
  predicted_ids = torch.argmax(logits, dim=-1)
105
 
106
+ # 3. 獲取使用者輸出
 
107
  raw_output_str = processor.decode(predicted_ids[0])
108
 
109
+ # 處理 User Tokens
110
+ # 嘗試抓取標準音節,如果失敗則退化為 smart parse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  user_tokens = []
 
112
  user_syllables = re.findall(r'[a-z]+[0-9]', raw_output_str)
113
 
114
  if user_syllables:
 
115
  for syl in user_syllables:
116
  user_tokens.extend(_tokenize_jyutping_smart(syl))
117
  else:
118
+ # 如果用戶完全沒讀出聲調,或者是亂碼
 
 
 
 
 
 
119
  user_tokens = _tokenize_jyutping_smart(raw_output_str)
120
 
 
121
  # 4. 對齊 (Alignment)
122
  word_alignments = _get_phoneme_alignments_by_word(user_tokens, target_jyutping_by_char)
123
 
124
  return _format_to_json_structure(word_alignments, target_sentence, target_chars)
125
 
 
 
 
 
 
 
 
 
126
 
127
+ # --- 4. 對齊函數 (已強化:類型感知 Type-Aware) ---
128
  def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
129
  target_phonemes_flat = []
130
  word_boundaries_indices = []
 
135
  current_idx += len(word_ipa_tokens)
136
  word_boundaries_indices.append(current_idx - 1)
137
 
138
+ # DP Initialization
139
  dp = np.zeros((len(user_phonemes) + 1, len(target_phonemes_flat) + 1))
140
  for i in range(1, len(user_phonemes) + 1): dp[i][0] = i
141
  for j in range(1, len(target_phonemes_flat) + 1): dp[0][j] = j
142
 
143
+ # 【【【 Type-Aware Cost Calculation 】】】
144
  for i in range(1, len(user_phonemes) + 1):
145
  for j in range(1, len(target_phonemes_flat) + 1):
146
+ u_char = user_phonemes[i-1]
147
+ t_char = target_phonemes_flat[j-1]
148
+
149
+ # 判斷是否為數字 (聲調)
150
+ u_is_digit = u_char.isdigit()
151
+ t_is_digit = t_char.isdigit()
152
+
153
+ if u_char == t_char:
154
+ cost = 0
155
+ elif u_is_digit != t_is_digit:
156
+ # 💥 關鍵修改:如果類型不同 (數字 vs 字母),給予超大懲罰
157
+ # 這會強制算法選擇 Insertion 或 Deletion,而不是 Substitution
158
+ cost = 100
159
+ else:
160
+ # 類型相同但字符不同 (e.g. '2' vs '3', 'a' vs 'o') -> 一般錯誤
161
+ cost = 1
162
+
163
  dp[i][j] = min(dp[i-1][j] + 1, dp[i][j-1] + 1, dp[i-1][j-1] + cost)
164
 
165
+ # Backtracking (需要保持一致的 cost 邏輯)
166
  i, j = len(user_phonemes), len(target_phonemes_flat)
167
  user_path, target_path = [], []
168
  while i > 0 or j > 0:
169
+ # 重算當前格子的 cost 以決定路徑
170
+ if i > 0 and j > 0:
171
+ u_char = user_phonemes[i-1]
172
+ t_char = target_phonemes_flat[j-1]
173
+ u_is_digit = u_char.isdigit()
174
+ t_is_digit = t_char.isdigit()
175
+
176
+ if u_char == t_char:
177
+ match_cost = 0
178
+ elif u_is_digit != t_is_digit:
179
+ match_cost = 100
180
+ else:
181
+ match_cost = 1
182
+ else:
183
+ match_cost = float('inf') # 邊界情況
184
+
185
+ # 檢查是否來自對角線 (Substitution/Match)
186
+ if i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + match_cost:
187
+ user_path.insert(0, user_phonemes[i-1])
188
+ target_path.insert(0, target_phonemes_flat[j-1])
189
+ i -= 1; j -= 1
190
+ # 檢查是否來自上方 (Deletion / Missing in User)
191
  elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
192
+ user_path.insert(0, user_phonemes[i-1])
193
+ target_path.insert(0, '-')
194
+ i -= 1
195
+ # 檢查是否來自左方 (Insertion / Extra in User)
196
  else:
197
+ user_path.insert(0, '-')
198
+ target_path.insert(0, target_phonemes_flat[j-1])
199
+ j -= 1
200
 
201
+ # --- 下面的切分邏輯保持不變 ---
202
  alignments_by_word = []
203
  word_start_idx_in_path = 0
204
  target_phoneme_counter_in_path = 0
 
227
 
228
  return alignments_by_word
229
 
230
+ # --- 5. 格式化函數 (保持與英文版一致) ---
231
  def _format_to_json_structure(alignments, sentence, original_words) -> dict:
232
  total_phonemes = 0
233
  total_errors = 0