HK0712 committed on
Commit
211b028
·
1 Parent(s): 8d0199e

feat: Implement core pronunciation analysis API

Browse files
Files changed (5) hide show
  1. .gitignore +15 -10
  2. ASR.py → analyzer/ASR_en_us.py +119 -112
  3. analyzer/__init__.py +0 -0
  4. main.py +127 -0
  5. requirements.txt +6 -4
.gitignore CHANGED
@@ -1,16 +1,21 @@
1
- # 忽略 Python 虛擬環境
 
 
 
 
 
2
  venv/
 
3
 
4
- # 忽略 VS Code 的設定
5
  .vscode/
 
6
 
7
- # 忽略 Python 的快取檔案
8
- __pycache__/
9
- *.pyc
10
-
11
- # 忽略下載的本地模型 (非常重要,因為它太大了!)
12
  ASRs/
13
 
14
- # 忽略音訊檔案 (如果它們只是測試用的話)
15
- TestAudio/
16
- *.wav
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .env
7
  venv/
8
+ env/
9
 
10
+ # IDE / Editor
11
  .vscode/
12
+ .idea/
13
 
14
+ # ASR Models (非常重要,模型檔案通常很大)
 
 
 
 
15
  ASRs/
16
 
17
+ # Temporary files
18
+ temp_audio/
19
+
20
+ # macOS
21
+ .DS_Store
ASR.py → analyzer/ASR_en_us.py RENAMED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import soundfile as sf
3
  import librosa
@@ -5,70 +7,94 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
  import os
6
  from phonemizer import phonemize
7
  import numpy as np
8
- from datetime import datetime
9
- from colorama import init, Fore, Style
10
-
11
- # 初始化 colorama
12
- init(autoreset=True)
13
 
14
- # --- 1. 全域設定 ---
15
- TARGET_SENTENCE = "how was your day"
16
- AUDIO_FILE_PATH = "./TestAudio/hello.wav"
17
  MODEL_NAME = "MultiBridge/wav2vec-LnNor-IPA-ft"
18
  MODEL_SAVE_PATH = "./ASRs/MultiBridge-wav2vec-LnNor-IPA-ft-local"
19
 
20
- # --- 2. 載入模型和處理器 (保持不變) ---
21
- print(f"正在準備模型 '{MODEL_NAME}'...")
22
- try:
23
- if not os.path.exists(MODEL_SAVE_PATH):
24
- print(f"本地找不到模型,正在從 Hugging Face 下載並儲存...")
25
- processor_to_save = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
26
- model_to_save = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
27
- processor_to_save.save_pretrained(MODEL_SAVE_PATH)
28
- model_to_save.save_pretrained(MODEL_SAVE_PATH)
29
- print("模型已成功下載並儲存。")
30
- else:
31
- print(f"在 '{MODEL_SAVE_PATH}' 中找到本地模型。")
32
- processor = Wav2Vec2Processor.from_pretrained(MODEL_SAVE_PATH)
33
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_SAVE_PATH)
34
- print("模型和處理器載入成功!")
35
- except Exception as e:
36
- print(f"處理或載入模型時發生錯誤: {e}")
37
- exit()
38
-
39
- # --- 3. 準備目標音標 (Target) - (已修改) ---
40
- print("正在準備目標音標...")
41
- # 在這一步就徹底移除重音符號,得到最乾淨的目標音標列表
42
- target_ipa_by_word = [
43
- word.replace('ˌ', '').replace('ˈ', '').replace('ː', '')
44
- for word in phonemize(TARGET_SENTENCE, language='en-us', backend='espeak', with_stress=True).split()
45
- ]
46
-
47
- # --- 4. 讀取音訊並進行辨識 (保持不變) ---
48
- print(f"正在讀取音訊檔案: {AUDIO_FILE_PATH}...")
49
- try:
50
- speech, sample_rate = sf.read(AUDIO_FILE_PATH)
51
- if sample_rate != 16000:
52
- speech = librosa.resample(y=speech, orig_sr=sample_rate, target_sr=16000)
53
- except Exception as e:
54
- print(f"讀取或處理音訊時發生錯誤: {e}")
55
- exit()
56
- print("正在辨識用戶的實際發音...")
57
- input_values = processor(speech, sampling_rate=16000, return_tensors="pt").input_values
58
- with torch.no_grad():
59
- logits = model(input_values).logits
60
- predicted_ids = torch.argmax(logits, dim=-1)
61
- user_ipa_full = processor.decode(predicted_ids[0])
62
-
63
-
64
- # --- 5. 核心函式:現在處理的都是乾淨的音標,邏輯保持不變 ---
65
- def get_phoneme_alignments_by_word(user_phoneme_str, target_words_ipa):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  user_phonemes = list(user_phoneme_str.replace(' ', ''))
67
  target_phonemes_flat = []
68
  word_boundaries = []
69
  current_idx = 0
70
  for word_ipa in target_words_ipa:
71
- phonemes = list(word_ipa) # 已經是乾淨的音標
72
  target_phonemes_flat.extend(phonemes)
73
  current_idx += len(phonemes)
74
  word_boundaries.append(current_idx)
@@ -111,79 +137,60 @@ def get_phoneme_alignments_by_word(user_phoneme_str, target_words_ipa):
111
 
112
  return alignments_by_word
113
 
114
- # --- 6. 格式化輸出函式 (已簡化) ---
115
- def format_and_print_final_report(alignments):
 
116
  total_phonemes = 0
117
  total_errors = 0
118
- correct_words = 0
119
-
120
- target_line_parts = []
121
- user_line_parts = []
122
 
123
- for alignment in alignments:
124
  word_is_correct = True
 
125
 
126
- max_lens = [max(len(t), len(u)) for t, u in zip(alignment['target'], alignment['user'])]
127
-
128
- target_word_parts = [p.ljust(max_lens[j]) for j, p in enumerate(alignment['target'])]
129
- target_line_parts.append(f"[ {' '.join(target_word_parts)} ]")
130
-
131
- user_word_parts = []
132
- for j, user_phoneme in enumerate(alignment['user']):
133
  target_phoneme = alignment['target'][j]
 
134
  is_match = (user_phoneme == target_phoneme)
135
 
 
 
 
 
 
 
136
  if not is_match:
137
  word_is_correct = False
138
- if user_phoneme != '-' and target_phoneme != '-': # 替換
139
- total_errors += 1
140
- elif user_phoneme == '-': # 省略
141
- total_errors += 1
142
- else: # 插入
143
- total_errors += 1
144
-
145
- color = Fore.GREEN if is_match else Fore.RED
146
- user_word_parts.append(f"{color}{user_phoneme.ljust(max_lens[j])}{Style.RESET_ALL}")
147
-
148
- user_line_parts.append(f"[ {' '.join(user_word_parts)} ]")
149
 
150
  if word_is_correct:
151
- correct_words += 1
 
 
 
 
 
 
152
 
153
  total_phonemes += sum(1 for p in alignment['target'] if p != '-')
154
 
155
- # --- 計算統計資料 ---
156
  total_words = len(alignments)
157
- incorrect_words = total_words - correct_words
158
- overall_score = (correct_words / total_words) * 100 if total_words > 0 else 0
159
  phoneme_error_rate = (total_errors / total_phonemes) * 100 if total_phonemes > 0 else 0
160
 
161
- # --- 列印報告 ---
162
- separator = "="*70
163
- print("\n" + separator)
164
- print("Pronunciation Analysis".center(70))
165
- print(separator + "\n")
166
-
167
- print(f"Sentence: {TARGET_SENTENCE}\n")
168
- print(f"Target : {' '.join(target_line_parts)}")
169
- print(f"User : {' '.join(user_line_parts)}")
170
-
171
- print("\n" + "-" * 70)
172
- print("[ Summary ]")
173
- print("-" * 70)
174
- print(f"- Overall Score: {overall_score:.1f}%")
175
- print(f"- Total Words: {total_words}")
176
- print(f"- Correct Words: {correct_words}")
177
- print(f"- Incorrect Words: {incorrect_words}")
178
- print(f"- Phoneme Error Rate: {phoneme_error_rate:.2f}% ({total_errors} errors in {total_phonemes} target phonemes)")
179
- # (已修改) 使用 UTC 時間
180
- print(f"- Analysis Timestamp: {datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')} (UTC)")
181
 
182
- print("\n" + separator)
183
-
184
-
185
- # --- 主流程 ---
186
- print("正在進行音素級對齊...")
187
- word_alignments = get_phoneme_alignments_by_word(user_ipa_full, target_ipa_by_word)
188
-
189
- format_and_print_final_report(word_alignments)
 
1
+ # analyzer/ASR_en_us.py
2
+
3
  import torch
4
  import soundfile as sf
5
  import librosa
 
7
  import os
8
  from phonemizer import phonemize
9
  import numpy as np
10
+ from datetime import datetime, timezone
 
 
 
 
11
 
12
# --- 1. Global settings and model-loading function ---
# Hugging Face model id, and the local directory load_model() saves the
# downloaded copy to (and loads it from on later runs).
MODEL_NAME = "MultiBridge/wav2vec-LnNor-IPA-ft"
MODEL_SAVE_PATH = "./ASRs/MultiBridge-wav2vec-LnNor-IPA-ft-local"

# Module-level cache for the processor and model: load_model() assigns them,
# analyze() reads them. None means "not loaded yet".
processor = None
model = None
20
+
21
def load_model():
    """Load the en-us wav2vec2 processor and model into the module cache.

    Meant to be called once at application start-up. If both objects are
    already cached the call is a no-op. When ``MODEL_SAVE_PATH`` does not
    exist yet, the model is first fetched from Hugging Face under
    ``MODEL_NAME`` and saved there; the processor and model are then always
    loaded from ``MODEL_SAVE_PATH``.

    Returns:
        True once the processor and model are available.

    Raises:
        RuntimeError: if downloading, saving or loading fails. The original
            exception is chained as ``__cause__`` so the traceback is kept.
    """
    global processor, model
    # Compare against None explicitly; truthiness of arbitrary objects can
    # depend on __len__/__bool__ and is not a reliable "is loaded" signal.
    if processor is not None and model is not None:
        print("英文模型已載入,跳過。")
        return True

    print(f"正在準備英文 (en-us) ASR 模型 '{MODEL_NAME}'...")
    try:
        if not os.path.exists(MODEL_SAVE_PATH):
            print("本地找不到模型,正在從 Hugging Face 下載並儲存...")
            Wav2Vec2Processor.from_pretrained(MODEL_NAME).save_pretrained(MODEL_SAVE_PATH)
            Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).save_pretrained(MODEL_SAVE_PATH)
            print("模型已成功下載並儲存。")
        else:
            print(f"在 '{MODEL_SAVE_PATH}' 中找到本地模型。")

        # Load into locals first so a failure half-way through never leaves
        # the cache with only one of the two objects set.
        loaded_processor = Wav2Vec2Processor.from_pretrained(MODEL_SAVE_PATH)
        loaded_model = Wav2Vec2ForCTC.from_pretrained(MODEL_SAVE_PATH)
    except Exception as e:
        print(f"處理或載入 en-us 模型時發生錯誤: {e}")
        # Re-raise so the application knows start-up failed; keep the cause.
        raise RuntimeError(f"Failed to load en-us model: {e}") from e

    processor, model = loaded_processor, loaded_model
    print("英文 (en-us) 模型和處理器載入成功!")
    return True
51
+
52
+ # --- 2. 核心分析函數 (主入口) ---
53
def analyze(audio_file_path: str, target_sentence: str) -> dict:
    """Score a recording against a target sentence at phoneme level.

    Main entry point of this module. The target sentence is converted to
    per-word IPA with phonemizer/espeak, the audio is decoded to an IPA string
    by the cached wav2vec2 CTC model, and both are handed to the module's
    private alignment and formatting helpers.

    Args:
        audio_file_path: path of an audio file readable by ``soundfile``.
        target_sentence: the sentence the speaker was asked to say.

    Returns:
        The report dict produced by ``_format_to_json_structure``.

    Raises:
        RuntimeError: if ``load_model()`` has not filled the model cache.
        IOError: if the audio cannot be read, down-mixed or resampled; the
            underlying exception is chained as ``__cause__``.
    """
    if processor is None or model is None:
        raise RuntimeError("模型尚未載入。請確保在呼叫 analyze 之前已成功執行 load_model()。")

    # --- Target IPA, one string per word ---
    # Secondary stress, primary stress and length marks are removed so only
    # the bare phoneme symbols take part in the comparison.
    strip_marks = str.maketrans('', '', 'ˌˈː')
    phonemized = phonemize(target_sentence, language='en-us', backend='espeak', with_stress=True)
    target_ipa_by_word = [word.translate(strip_marks) for word in phonemized.split()]
    target_words_original = target_sentence.split()

    # --- Read the audio and bring it to mono / 16 kHz ---
    target_rate = 16000
    try:
        speech, sample_rate = sf.read(audio_file_path)
        if speech.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel files.
            # Down-mix before resampling: librosa resamples along the last
            # axis, which would otherwise be the channel axis.
            speech = speech.mean(axis=1)
        if sample_rate != target_rate:
            speech = librosa.resample(y=speech, orig_sr=sample_rate, target_sr=target_rate)
    except Exception as e:
        raise IOError(f"讀取或處理音訊時發生錯誤: {e}") from e

    # --- Recognise the phonemes actually spoken (greedy CTC decoding) ---
    input_values = processor(speech, sampling_rate=target_rate, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    user_ipa_full = processor.decode(predicted_ids[0])

    # --- Align per word, then build the response structure ---
    word_alignments = _get_phoneme_alignments_by_word(user_ipa_full, target_ipa_by_word)
    return _format_to_json_structure(word_alignments, target_sentence, target_words_original)
87
+
88
+
89
+ # --- 3. 您的原始對齊函數 (設為內部函數,未修改邏輯) ---
90
+ def _get_phoneme_alignments_by_word(user_phoneme_str, target_words_ipa):
91
+ # ... 您的程式碼完全不變 ...
92
  user_phonemes = list(user_phoneme_str.replace(' ', ''))
93
  target_phonemes_flat = []
94
  word_boundaries = []
95
  current_idx = 0
96
  for word_ipa in target_words_ipa:
97
+ phonemes = list(word_ipa)
98
  target_phonemes_flat.extend(phonemes)
99
  current_idx += len(phonemes)
100
  word_boundaries.append(current_idx)
 
137
 
138
  return alignments_by_word
139
 
140
+ # --- 4. 新增的格式化函數 (設為內部函數) ---
141
def _format_to_json_structure(alignments, sentence, original_words) -> dict:
    """Turn per-word phoneme alignments into the JSON-serialisable report.

    Args:
        alignments: one mapping per target word holding parallel ``'target'``
            and ``'user'`` phoneme lists. ``'-'`` is the gap placeholder: in
            ``'user'`` it marks an omitted phoneme, in ``'target'`` an
            inserted one.
        sentence: echoed back unchanged under ``"sentence"``.
        original_words: labels for the words, matched to ``alignments`` by
            position; a word with no label at its index is reported as
            ``"N/A"``.

    Returns:
        A dict with ``"sentence"``, ``"analysisTimestampUTC"`` (ISO 8601 with
        a ``Z`` suffix), ``"summary"`` (``overallScore`` = percentage of fully
        correct words, ``totalWords``, ``correctWords``, ``phonemeErrorRate``
        = mismatched positions per non-gap target phoneme, in percent) and
        ``"words"`` (per-word ``word`` / ``isCorrect`` / ``phonemes``).
    """
    total_phonemes = 0
    total_errors = 0
    correct_words_count = 0
    words_data = []

    for i, alignment in enumerate(alignments):
        user_phonemes = alignment['user']
        phonemes_data = []
        for j, target_phoneme in enumerate(alignment['target']):
            user_phoneme = user_phonemes[j]
            phonemes_data.append({
                "target": target_phoneme,
                "user": user_phoneme,
                "isMatch": user_phoneme == target_phoneme
            })

        # Substitution, omission and insertion all weigh the same, so every
        # mismatched position is exactly one error.
        mismatches = sum(1 for entry in phonemes_data if not entry["isMatch"])
        total_errors += mismatches
        word_is_correct = mismatches == 0
        if word_is_correct:
            correct_words_count += 1

        words_data.append({
            "word": original_words[i] if i < len(original_words) else "N/A",
            "isCorrect": word_is_correct,
            "phonemes": phonemes_data
        })

        # Gap placeholders in the target are insertions, not real phonemes.
        total_phonemes += sum(1 for p in alignment['target'] if p != '-')

    total_words = len(alignments)
    overall_score = (correct_words_count / total_words) * 100 if total_words > 0 else 0
    phoneme_error_rate = (total_errors / total_phonemes) * 100 if total_phonemes > 0 else 0

    return {
        "sentence": sentence,
        "analysisTimestampUTC": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
        "summary": {
            "overallScore": round(overall_score, 1),
            "totalWords": total_words,
            "correctWords": correct_words_count,
            "phonemeErrorRate": round(phoneme_error_rate, 2)
        },
        "words": words_data
    }
 
 
 
 
 
 
 
analyzer/__init__.py ADDED
File without changes
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py (Final Corrected Version)
2
+
3
+ import uvicorn
4
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
5
+ from fastapi.responses import JSONResponse
6
+ import os
7
+ import shutil
8
+ from contextlib import asynccontextmanager
9
+ import asyncio
10
+ import importlib.util
11
+ import sys
12
+ from datetime import datetime # The required import statement
13
+
14
+ # Ngrok is optional, so we handle its potential absence
15
+ try:
16
+ from pyngrok import ngrok, conf
17
+ PYNGROK_INSTALLED = True
18
+ except ImportError:
19
+ PYNGROK_INSTALLED = False
20
+
21
# --- Analyzer Loading Logic ---
# Language code -> loaded analyzer module. Filled by load_analyzers() and
# looked up by the /api/v1/recognize handler.
ANALYZER_MODULES = {}
# Language codes to preload; code "xx" maps to the module "analyzer.ASR_xx".
SUPPORTED_LANGUAGES = ["en_us"]
24
+
25
async def load_analyzers():
    """Import every supported analyzer module and preload its model.

    For each code in ``SUPPORTED_LANGUAGES`` the module
    ``analyzer.ASR_<code>`` is imported and its ``load_model()`` is run in a
    worker thread. A module is registered in ``ANALYZER_MODULES`` only after
    ``load_model()`` has returned. Problems are reported on stdout and the
    language is skipped, so one broken analyzer does not prevent the others
    from loading.
    """
    print("正在預載入所有支援的分析器模型...")
    for lang in SUPPORTED_LANGUAGES:
        module_name = f"analyzer.ASR_{lang}"
        try:
            if importlib.util.find_spec(module_name) is None:
                print(f"警告:找不到 {lang} 的分析器模組: {module_name}")
                continue

            # Use the regular import machinery instead of executing the spec
            # by hand: a module that raises while importing is removed from
            # sys.modules again rather than left there half-initialised.
            analyzer_module = importlib.import_module(module_name)

            load_model = getattr(analyzer_module, 'load_model', None)
            if load_model is None:
                print(f"警告:'{lang}' 模組中沒有找到 load_model 函數。")
                continue

            # Model loading blocks; keep it off the event loop.
            await asyncio.to_thread(load_model)
            ANALYZER_MODULES[lang] = analyzer_module
            print(f"'{lang}' 分析器載入成功。")
        except Exception as e:
            # Best effort by design: report and move on to the next language.
            print(f"錯誤:載入 '{lang}' 分析器時失敗: {e}")
47
+
48
+ # --- FastAPI Lifespan ---
49
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: preload analyzers and optionally open ngrok.

    Start-up: loads all analyzers, then, if pyngrok is importable and the
    ``NGROK_AUTHTOKEN`` environment variable is set, opens an ngrok tunnel to
    port 8000 and prints its public URL.

    Shutdown: closes the tunnel opened during start-up, if there is one. The
    clean-up runs in a ``finally`` so it also happens when the application
    exits through an exception.
    """
    print("應用程式啟動中...")
    await load_analyzers()

    tunnel = None
    if PYNGROK_INSTALLED:
        NGROK_AUTHTOKEN = os.environ.get("NGROK_AUTHTOKEN")
        if NGROK_AUTHTOKEN:
            conf.get_default().auth_token = NGROK_AUTHTOKEN
            print("正在啟動 ngrok 通道...")
            tunnel = await asyncio.to_thread(ngrok.connect, 8000, name="pronunciation-api")
            print(f"Ngrok 通道已建立,公開 URL: {tunnel}")
        else:
            print("警告:未設定 NGROK_AUTHTOKEN,Ngrok 將不會啟動。")
    else:
        print("警告: pyngrok 套件未安裝,Ngrok 將不會啟動。")

    try:
        yield
    finally:
        print("應用程式關閉中...")
        # Only touch ngrok when this process actually opened a tunnel;
        # querying tunnels otherwise may itself start an ngrok process.
        if tunnel is not None:
            # disconnect() needs the public URL of the tunnel to close.
            # NOTE(review): connect() is assumed to return a tunnel object
            # exposing .public_url; the fallback covers a plain URL string.
            ngrok.disconnect(getattr(tunnel, "public_url", tunnel))
            print("Ngrok 通道已關閉。")
72
+
73
# --- FastAPI App Initialization ---
# The lifespan hook preloads the analyzers before requests are served.
app = FastAPI(lifespan=lifespan)
# Directory where uploads are written while a request is being processed;
# created at import time if it does not exist.
TEMP_DIR = "temp_audio"
os.makedirs(TEMP_DIR, exist_ok=True)
77
+
78
+ # --- API Endpoint ---
79
@app.post("/api/v1/recognize")
async def recognize_speech_api(
    language: str = Form(...),
    target_sentence: str = Form(...),
    file: UploadFile = File(...)
):
    """Analyze an uploaded .wav recording against a target sentence.

    Form fields:
        language: key of a loaded analyzer in ``ANALYZER_MODULES``.
        target_sentence: the sentence the speaker was asked to say.
        file: the uploaded recording; its file name must end in ``.wav``.

    Returns:
        A ``JSONResponse`` wrapping the dict returned by the analyzer's
        ``analyze()``.

    Raises:
        HTTPException: 400 for an unknown language or a missing/non-.wav
            file name; 500 when saving or analysing the upload fails.
    """
    import uuid  # function-scope import: only this handler needs it

    if language not in ANALYZER_MODULES:
        raise HTTPException(status_code=400, detail=f"不支援的語言: '{language}'。支援的語言: {list(ANALYZER_MODULES.keys())}")

    if not file.filename or not file.filename.lower().endswith('.wav'):
        raise HTTPException(status_code=400, detail="檔案格式錯誤或檔名無效,請上傳 .wav 檔案。")

    # Drop any directory components supplied by the client.
    safe_filename = os.path.basename(file.filename)
    # The timestamp alone has one-second resolution, so two concurrent uploads
    # of the same file name would share a path and delete each other's file
    # during clean-up. A random token keeps every request's path unique.
    unique_prefix = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex}"
    temp_file_path = os.path.join(TEMP_DIR, f"{unique_prefix}-{safe_filename}")

    try:
        with open(temp_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        analyzer_module = ANALYZER_MODULES[language]
        print(f"使用 '{language}' 分析器處理檔案: {file.filename}")

        # The analysis is blocking work; run it in a worker thread.
        analysis_result = await asyncio.to_thread(
            analyzer_module.analyze, temp_file_path, target_sentence
        )

        return JSONResponse(content=analysis_result)
    except Exception as e:
        print(f"處理請求時發生未預期的錯誤: {e}")
        raise HTTPException(status_code=500, detail=f"伺服器內部錯誤: {str(e)}") from e
    finally:
        # Always remove the temporary copy and release the upload.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        if file:
            await file.close()
114
+
115
@app.get("/")
def read_root():
    """Landing route: report that the API is up and name the endpoint to use."""
    payload = {"message": "發音分析 API 已啟動。請使用 POST /api/v1/recognize 端點。"}
    return payload
118
+
119
+ # --- Server Execution ---
120
if __name__ == "__main__":
    # Start-up banner: one notice about ngrok framed by two rules, then the
    # uvicorn server on all interfaces, port 8000, with auto-reload enabled.
    rule = "=" * 60
    if PYNGROK_INSTALLED:
        notice = "請確保已設定 NGROK_AUTHTOKEN 環境變數以便 ngrok 正常運作。"
    else:
        notice = "pyngrok 未安裝,服務僅在本地運行。"
    for line in (rule, notice, rule):
        print(line)
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 
 
 
 
1
  torch
2
  soundfile
3
  librosa
4
  transformers
5
- phonemizer
6
- fastapi
7
- uvicorn[standard]
8
- colorama
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pyngrok
4
+ python-multipart
5
  torch
6
  soundfile
7
  librosa
8
  transformers
9
+ phonemizer[espeak]
10
+ numpy