Spaces:
Running
Running
improved cantonese version v2
Browse files- analyzer/ASR_zh_hk.py +60 -70
analyzer/ASR_zh_hk.py
CHANGED
|
@@ -14,64 +14,48 @@ print(f"INFO: ASR_zh_hk.py is configured to use device: {DEVICE}")
|
|
| 14 |
|
| 15 |
MODEL_NAME = "HK0712/Wav2Vec2_Cantonese"
|
| 16 |
|
| 17 |
-
# --- 1. 輔助函數:粵拼智慧切分器
|
| 18 |
def _tokenize_jyutping_smart(jyutping_str: str) -> list:
|
| 19 |
"""
|
| 20 |
將單個粵拼音節 (如 'gwong2') 根據聲韻學結構切分為 token。
|
| 21 |
Target: 'gwong2' -> ['gw', 'o', 'ng', '2']
|
| 22 |
-
這樣前端顯示時會是 "gw o ng 2",比 "g w o n g 2" 易讀得多。
|
| 23 |
"""
|
| 24 |
try:
|
| 25 |
-
# pycantonese.parse_jyutping 回傳的是一個列表,包含 Jyutping 物件
|
| 26 |
-
# 例如: parse_jyutping('gwong2') -> [Jyutping(onset='gw', nucleus='o', coda='ng', tone='2')]
|
| 27 |
parsed = pycantonese.parse_jyutping(jyutping_str)
|
| 28 |
-
|
| 29 |
tokens = []
|
| 30 |
for jp in parsed:
|
| 31 |
if jp.onset: tokens.append(jp.onset)
|
| 32 |
if jp.nucleus: tokens.append(jp.nucleus)
|
| 33 |
if jp.coda: tokens.append(jp.coda)
|
| 34 |
if jp.tone: tokens.append(jp.tone)
|
| 35 |
-
|
| 36 |
return tokens
|
| 37 |
except:
|
| 38 |
-
# 萬一解析失敗(例如模型輸出的拼音不標準),回退到簡單切分
|
| 39 |
-
# 但保留數字作為獨立 token
|
| 40 |
return re.findall(r'[a-z]+|[0-9]', jyutping_str)
|
| 41 |
|
| 42 |
-
# --- 2. 智慧 G2P 歸屬邏輯
|
| 43 |
def _get_target_jyutping_by_char(sentence: str) -> (list, list):
|
| 44 |
"""
|
| 45 |
將中文句子轉換為「字」級別的粵拼目標。
|
| 46 |
"""
|
| 47 |
-
# pycantonese.characters_to_jyutping 會處理變調與分詞
|
| 48 |
-
# 範例: "廣東話" -> [('廣東話', 'gwong2dung1waa2')]
|
| 49 |
segmented_result = pycantonese.characters_to_jyutping(sentence)
|
| 50 |
|
| 51 |
original_chars_flat = []
|
| 52 |
target_jyutping_groups = []
|
| 53 |
-
|
| 54 |
-
# 簡單的正則表達式,用來把連在一起的拼音分開 (e.g. 'gwong2dung1' -> 'gwong2', 'dung1')
|
| 55 |
jyutping_syllable_pattern = re.compile(r'([a-z]+[1-6])')
|
| 56 |
|
| 57 |
for word_segment, jyutping_segment in segmented_result:
|
| 58 |
-
if not jyutping_segment:
|
| 59 |
-
continue
|
| 60 |
|
| 61 |
syllables = jyutping_syllable_pattern.findall(jyutping_segment)
|
| 62 |
|
| 63 |
-
# 嘗試將分詞後的結果對齊回單個漢字
|
| 64 |
if len(word_segment) == len(syllables):
|
| 65 |
for char, syl in zip(word_segment, syllables):
|
| 66 |
original_chars_flat.append(char)
|
| 67 |
-
# 使用智慧切分:'gwong2' -> ['gw', 'o', 'ng', '2']
|
| 68 |
target_jyutping_groups.append(_tokenize_jyutping_smart(syl))
|
| 69 |
else:
|
| 70 |
-
# 長度不匹配時的備用方案 (逐字處理)
|
| 71 |
print(f"WARNING: Mismatch length for {word_segment}. Fallback to char-by-char G2P.")
|
| 72 |
for char in word_segment:
|
| 73 |
original_chars_flat.append(char)
|
| 74 |
-
# 對單字再做一次 G2P
|
| 75 |
single_res = pycantonese.characters_to_jyutping(char)
|
| 76 |
if single_res and single_res[0][1]:
|
| 77 |
target_jyutping_groups.append(_tokenize_jyutping_smart(single_res[0][1]))
|
|
@@ -80,7 +64,7 @@ def _get_target_jyutping_by_char(sentence: str) -> (list, list):
|
|
| 80 |
|
| 81 |
return original_chars_flat, target_jyutping_groups
|
| 82 |
|
| 83 |
-
# --- 3. 核心分析函數
|
| 84 |
def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dict:
|
| 85 |
if "model" not in cache:
|
| 86 |
print(f"Cache miss (ASR_zh_hk). Loading model '{MODEL_NAME}'...")
|
|
@@ -119,66 +103,28 @@ def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dic
|
|
| 119 |
logits = model(input_values).logits
|
| 120 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 121 |
|
| 122 |
-
# 3. 獲取使用者輸出
|
| 123 |
-
# 模型輸出: "gwong2 dung1 waa2" (字串)
|
| 124 |
raw_output_str = processor.decode(predicted_ids[0])
|
| 125 |
|
| 126 |
-
#
|
| 127 |
-
#
|
| 128 |
-
# 這樣才能跟 Target 的結構對齊
|
| 129 |
-
|
| 130 |
-
# 步驟 A: 移除空格,變成連續字串 "gwong2dung1waa2"
|
| 131 |
-
# 注意:這一步假設模型輸出的拼音是標準的。如果模型輸出亂碼,tokenize 可能會切得不完美,
|
| 132 |
-
# 但 Needleman-Wunsch 算法會處理這些 mismatch,所以沒關係。
|
| 133 |
-
user_jyutping_clean = raw_output_str.replace(" ", "")
|
| 134 |
-
|
| 135 |
-
# 步驟 B: 使用相同的邏輯切分用戶輸入
|
| 136 |
-
# 因為用戶輸入是一長串,我們用正則表達式把 [a-z] 和 [0-9] 分開,或者嘗試 parse
|
| 137 |
-
# 這裡用一個簡單的策略:把它當作一連串的 components
|
| 138 |
-
# 為了最佳對齊,我們這裡還是用 "Character + Number" 的粒度比較好,
|
| 139 |
-
# 因為用戶可能讀錯導致無法形成合法的 onset/nucleus。
|
| 140 |
-
#
|
| 141 |
-
# ★ 關鍵決策:為了避免用戶讀錯導致 crash,用戶端我們使用較細的粒度 (Regex Split),
|
| 142 |
-
# 然後讓對齊算法去匹配 Target 的 "gw", "o", "ng"。
|
| 143 |
-
# 等等,如果 Target 是 "gw" (1個token),User 是 "g", "w" (2個 tokens),對齊會錯位。
|
| 144 |
-
#
|
| 145 |
-
# ★ 修正策略:
|
| 146 |
-
# 我們也嘗試用 pycantonese.parse_jyutping 去解析用戶的整句輸出。
|
| 147 |
-
# 如果解析成功,我們就用結構化 token。如果失敗(亂讀),回退到字母切分。
|
| 148 |
-
|
| 149 |
user_tokens = []
|
| 150 |
-
# 嘗試把用戶輸出拆成音節 (e.g. "gwong2", "dung1")
|
| 151 |
user_syllables = re.findall(r'[a-z]+[0-9]', raw_output_str)
|
| 152 |
|
| 153 |
if user_syllables:
|
| 154 |
-
# 如果能抓到音節,就用結構化切分
|
| 155 |
for syl in user_syllables:
|
| 156 |
user_tokens.extend(_tokenize_jyutping_smart(syl))
|
| 157 |
else:
|
| 158 |
-
#
|
| 159 |
-
# 但這會導致跟 Target (gw) 對不上。
|
| 160 |
-
# 為了保險,我們這裡對於 Target 也許應該退化成簡單切分?
|
| 161 |
-
# 不,Target 是 Ground Truth,應該保持結構。
|
| 162 |
-
#
|
| 163 |
-
# 最終方案:讓 User stream 盡量 "粘" 在一起。
|
| 164 |
-
# 實際上,Wav2Vec2 輸出的通常是標準拼音。我們直接用 smart parse。
|
| 165 |
user_tokens = _tokenize_jyutping_smart(raw_output_str)
|
| 166 |
|
| 167 |
-
|
| 168 |
# 4. 對齊 (Alignment)
|
| 169 |
word_alignments = _get_phoneme_alignments_by_word(user_tokens, target_jyutping_by_char)
|
| 170 |
|
| 171 |
return _format_to_json_structure(word_alignments, target_sentence, target_chars)
|
| 172 |
|
| 173 |
-
# --- 4. 對齊與格式化 (保持原樣或微調) ---
|
| 174 |
-
# 這裡的邏輯與之前相同,不需要大改,因為它只是比較兩個 list 的相似度。
|
| 175 |
-
# 只要 user_tokens 和 target_jyutping_by_char 的元素 (token) 粒度一致即可。
|
| 176 |
-
# ... ( _get_phoneme_alignments_by_word 與 _format_to_json_structure 代碼同上) ...
|
| 177 |
-
|
| 178 |
-
# 為了節省篇幅,請使用上一版提供的 _get_phoneme_alignments_by_word 和 _format_to_json_structure
|
| 179 |
-
# 只需要替換上面的 _tokenize_jyutping_smart 和 analyze 函數即可。
|
| 180 |
-
# 下面我會把完整的 _get_phoneme_alignments_by_word 貼上以確保完整性。
|
| 181 |
|
|
|
|
| 182 |
def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
|
| 183 |
target_phonemes_flat = []
|
| 184 |
word_boundaries_indices = []
|
|
@@ -189,27 +135,70 @@ def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
|
|
| 189 |
current_idx += len(word_ipa_tokens)
|
| 190 |
word_boundaries_indices.append(current_idx - 1)
|
| 191 |
|
| 192 |
-
# DP
|
| 193 |
dp = np.zeros((len(user_phonemes) + 1, len(target_phonemes_flat) + 1))
|
| 194 |
for i in range(1, len(user_phonemes) + 1): dp[i][0] = i
|
| 195 |
for j in range(1, len(target_phonemes_flat) + 1): dp[0][j] = j
|
| 196 |
|
|
|
|
| 197 |
for i in range(1, len(user_phonemes) + 1):
|
| 198 |
for j in range(1, len(target_phonemes_flat) + 1):
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
dp[i][j] = min(dp[i-1][j] + 1, dp[i][j-1] + 1, dp[i-1][j-1] + cost)
|
| 201 |
|
|
|
|
| 202 |
i, j = len(user_phonemes), len(target_phonemes_flat)
|
| 203 |
user_path, target_path = [], []
|
| 204 |
while i > 0 or j > 0:
|
| 205 |
-
|
| 206 |
-
if i > 0 and j > 0
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
|
| 209 |
-
user_path.insert(0, user_phonemes[i-1])
|
|
|
|
|
|
|
|
|
|
| 210 |
else:
|
| 211 |
-
user_path.insert(0, '-')
|
|
|
|
|
|
|
| 212 |
|
|
|
|
| 213 |
alignments_by_word = []
|
| 214 |
word_start_idx_in_path = 0
|
| 215 |
target_phoneme_counter_in_path = 0
|
|
@@ -238,6 +227,7 @@ def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
|
|
| 238 |
|
| 239 |
return alignments_by_word
|
| 240 |
|
|
|
|
| 241 |
def _format_to_json_structure(alignments, sentence, original_words) -> dict:
|
| 242 |
total_phonemes = 0
|
| 243 |
total_errors = 0
|
|
|
|
| 14 |
|
| 15 |
MODEL_NAME = "HK0712/Wav2Vec2_Cantonese"
|
| 16 |
|
| 17 |
+
# --- 1. 輔助函數:粵拼智慧切分器 ---
|
| 18 |
def _tokenize_jyutping_smart(jyutping_str: str) -> list:
|
| 19 |
"""
|
| 20 |
將單個粵拼音節 (如 'gwong2') 根據聲韻學結構切分為 token。
|
| 21 |
Target: 'gwong2' -> ['gw', 'o', 'ng', '2']
|
|
|
|
| 22 |
"""
|
| 23 |
try:
|
|
|
|
|
|
|
| 24 |
parsed = pycantonese.parse_jyutping(jyutping_str)
|
|
|
|
| 25 |
tokens = []
|
| 26 |
for jp in parsed:
|
| 27 |
if jp.onset: tokens.append(jp.onset)
|
| 28 |
if jp.nucleus: tokens.append(jp.nucleus)
|
| 29 |
if jp.coda: tokens.append(jp.coda)
|
| 30 |
if jp.tone: tokens.append(jp.tone)
|
|
|
|
| 31 |
return tokens
|
| 32 |
except:
|
|
|
|
|
|
|
| 33 |
return re.findall(r'[a-z]+|[0-9]', jyutping_str)
|
| 34 |
|
| 35 |
+
# --- 2. 智慧 G2P 歸屬邏輯 ---
|
| 36 |
def _get_target_jyutping_by_char(sentence: str) -> (list, list):
|
| 37 |
"""
|
| 38 |
將中文句子轉換為「字」級別的粵拼目標。
|
| 39 |
"""
|
|
|
|
|
|
|
| 40 |
segmented_result = pycantonese.characters_to_jyutping(sentence)
|
| 41 |
|
| 42 |
original_chars_flat = []
|
| 43 |
target_jyutping_groups = []
|
|
|
|
|
|
|
| 44 |
jyutping_syllable_pattern = re.compile(r'([a-z]+[1-6])')
|
| 45 |
|
| 46 |
for word_segment, jyutping_segment in segmented_result:
|
| 47 |
+
if not jyutping_segment: continue
|
|
|
|
| 48 |
|
| 49 |
syllables = jyutping_syllable_pattern.findall(jyutping_segment)
|
| 50 |
|
|
|
|
| 51 |
if len(word_segment) == len(syllables):
|
| 52 |
for char, syl in zip(word_segment, syllables):
|
| 53 |
original_chars_flat.append(char)
|
|
|
|
| 54 |
target_jyutping_groups.append(_tokenize_jyutping_smart(syl))
|
| 55 |
else:
|
|
|
|
| 56 |
print(f"WARNING: Mismatch length for {word_segment}. Fallback to char-by-char G2P.")
|
| 57 |
for char in word_segment:
|
| 58 |
original_chars_flat.append(char)
|
|
|
|
| 59 |
single_res = pycantonese.characters_to_jyutping(char)
|
| 60 |
if single_res and single_res[0][1]:
|
| 61 |
target_jyutping_groups.append(_tokenize_jyutping_smart(single_res[0][1]))
|
|
|
|
| 64 |
|
| 65 |
return original_chars_flat, target_jyutping_groups
|
| 66 |
|
| 67 |
+
# --- 3. 核心分析函數 ---
|
| 68 |
def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dict:
|
| 69 |
if "model" not in cache:
|
| 70 |
print(f"Cache miss (ASR_zh_hk). Loading model '{MODEL_NAME}'...")
|
|
|
|
| 103 |
logits = model(input_values).logits
|
| 104 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 105 |
|
| 106 |
+
# 3. 獲取使用者輸出
|
|
|
|
| 107 |
raw_output_str = processor.decode(predicted_ids[0])
|
| 108 |
|
| 109 |
+
# 處理 User Tokens
|
| 110 |
+
# 嘗試抓取標準音節,如果失敗則退化為 smart parse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
user_tokens = []
|
|
|
|
| 112 |
user_syllables = re.findall(r'[a-z]+[0-9]', raw_output_str)
|
| 113 |
|
| 114 |
if user_syllables:
|
|
|
|
| 115 |
for syl in user_syllables:
|
| 116 |
user_tokens.extend(_tokenize_jyutping_smart(syl))
|
| 117 |
else:
|
| 118 |
+
# 如果用戶完全沒讀出聲調,或者是亂碼
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
user_tokens = _tokenize_jyutping_smart(raw_output_str)
|
| 120 |
|
|
|
|
| 121 |
# 4. 對齊 (Alignment)
|
| 122 |
word_alignments = _get_phoneme_alignments_by_word(user_tokens, target_jyutping_by_char)
|
| 123 |
|
| 124 |
return _format_to_json_structure(word_alignments, target_sentence, target_chars)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
+
# --- 4. 對齊函數 (已強化:類型感知 Type-Aware) ---
|
| 128 |
def _get_phoneme_alignments_by_word(user_phonemes, target_words_ipa_tokenized):
|
| 129 |
target_phonemes_flat = []
|
| 130 |
word_boundaries_indices = []
|
|
|
|
| 135 |
current_idx += len(word_ipa_tokens)
|
| 136 |
word_boundaries_indices.append(current_idx - 1)
|
| 137 |
|
| 138 |
+
# DP Initialization
|
| 139 |
dp = np.zeros((len(user_phonemes) + 1, len(target_phonemes_flat) + 1))
|
| 140 |
for i in range(1, len(user_phonemes) + 1): dp[i][0] = i
|
| 141 |
for j in range(1, len(target_phonemes_flat) + 1): dp[0][j] = j
|
| 142 |
|
| 143 |
+
# 【【【 Type-Aware Cost Calculation 】】】
|
| 144 |
for i in range(1, len(user_phonemes) + 1):
|
| 145 |
for j in range(1, len(target_phonemes_flat) + 1):
|
| 146 |
+
u_char = user_phonemes[i-1]
|
| 147 |
+
t_char = target_phonemes_flat[j-1]
|
| 148 |
+
|
| 149 |
+
# 判斷是否為數字 (聲調)
|
| 150 |
+
u_is_digit = u_char.isdigit()
|
| 151 |
+
t_is_digit = t_char.isdigit()
|
| 152 |
+
|
| 153 |
+
if u_char == t_char:
|
| 154 |
+
cost = 0
|
| 155 |
+
elif u_is_digit != t_is_digit:
|
| 156 |
+
# 💥 關鍵修改:如果類型不同 (數字 vs 字母),給予超大懲罰
|
| 157 |
+
# 這會強制算法選擇 Insertion 或 Deletion,而不是 Substitution
|
| 158 |
+
cost = 100
|
| 159 |
+
else:
|
| 160 |
+
# 類型相同但字符不同 (e.g. '2' vs '3', 'a' vs 'o') -> 一般錯誤
|
| 161 |
+
cost = 1
|
| 162 |
+
|
| 163 |
dp[i][j] = min(dp[i-1][j] + 1, dp[i][j-1] + 1, dp[i-1][j-1] + cost)
|
| 164 |
|
| 165 |
+
# Backtracking (需要保持一致的 cost 邏輯)
|
| 166 |
i, j = len(user_phonemes), len(target_phonemes_flat)
|
| 167 |
user_path, target_path = [], []
|
| 168 |
while i > 0 or j > 0:
|
| 169 |
+
# 重算當前格子的 cost 以決定路徑
|
| 170 |
+
if i > 0 and j > 0:
|
| 171 |
+
u_char = user_phonemes[i-1]
|
| 172 |
+
t_char = target_phonemes_flat[j-1]
|
| 173 |
+
u_is_digit = u_char.isdigit()
|
| 174 |
+
t_is_digit = t_char.isdigit()
|
| 175 |
+
|
| 176 |
+
if u_char == t_char:
|
| 177 |
+
match_cost = 0
|
| 178 |
+
elif u_is_digit != t_is_digit:
|
| 179 |
+
match_cost = 100
|
| 180 |
+
else:
|
| 181 |
+
match_cost = 1
|
| 182 |
+
else:
|
| 183 |
+
match_cost = float('inf') # 邊界情況
|
| 184 |
+
|
| 185 |
+
# 檢查是否來自對角線 (Substitution/Match)
|
| 186 |
+
if i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + match_cost:
|
| 187 |
+
user_path.insert(0, user_phonemes[i-1])
|
| 188 |
+
target_path.insert(0, target_phonemes_flat[j-1])
|
| 189 |
+
i -= 1; j -= 1
|
| 190 |
+
# 檢查是否來自上方 (Deletion / Missing in User)
|
| 191 |
elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
|
| 192 |
+
user_path.insert(0, user_phonemes[i-1])
|
| 193 |
+
target_path.insert(0, '-')
|
| 194 |
+
i -= 1
|
| 195 |
+
# 檢查是否來自左方 (Insertion / Extra in User)
|
| 196 |
else:
|
| 197 |
+
user_path.insert(0, '-')
|
| 198 |
+
target_path.insert(0, target_phonemes_flat[j-1])
|
| 199 |
+
j -= 1
|
| 200 |
|
| 201 |
+
# --- 下面的切分邏輯保持不變 ---
|
| 202 |
alignments_by_word = []
|
| 203 |
word_start_idx_in_path = 0
|
| 204 |
target_phoneme_counter_in_path = 0
|
|
|
|
| 227 |
|
| 228 |
return alignments_by_word
|
| 229 |
|
| 230 |
+
# --- 5. 格式化函數 (保持與英文版一致) ---
|
| 231 |
def _format_to_json_structure(alignments, sentence, original_words) -> dict:
|
| 232 |
total_phonemes = 0
|
| 233 |
total_errors = 0
|