Spaces:
Running
Running
| # app.py | |
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import ( | |
| pipeline, | |
| WhisperProcessor, | |
| WhisperForConditionalGeneration, | |
| ) | |
# === Model configuration ===
HF_REPO = "aciang/ATD-TriPyr-ASR"          # HF model repo holding the fine-tuned checkpoints
SUBROOT = "refined_20251014_013556"        # training-run subfolder inside the repo
CKPTS = ["checkpoint-200", "checkpoint-400", "checkpoint-600", "checkpoint-776"]
DEFAULT = "checkpoint-400"                 # checkpoint preselected in the UI

# Read the access token from a Space secret (recommended for private Spaces).
HF_TOKEN = os.getenv("HF_TOKEN", None)

# Auto-detect CPU / GPU: pipeline() device index (0 = first GPU, -1 = CPU).
DEVICE = 0 if torch.cuda.is_available() else -1

# Per-checkpoint pipeline cache so we don't rebuild on every request.
_CACHE = {}
def build_asr(ckpt: str):
    """Create (or fetch from cache) the ASR pipeline for one checkpoint.

    Only the selected checkpoint's weights are downloaded. Results are
    memoized in the module-level ``_CACHE`` keyed by checkpoint + device.
    """
    cache_key = f"{ckpt}_{DEVICE}"
    cached = _CACHE.get(cache_key)
    if cached is not None:
        return cached

    # 1) Processor (tokenizer + feature extractor): try the repo root first;
    #    if preprocessor_config.json is missing there, fall back to the
    #    stock openai/whisper-small processor.
    try:
        processor = WhisperProcessor.from_pretrained(HF_REPO, token=HF_TOKEN)
    except Exception:
        processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    # 2) Model weights live in a per-checkpoint subfolder of the repo.
    model = WhisperForConditionalGeneration.from_pretrained(
        HF_REPO,
        subfolder=f"{SUBROOT}/{ckpt}",
        torch_dtype="auto",       # let transformers pick the dtype (usually fp16 on GPU)
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )

    asr = pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        chunk_length_s=30,        # split long audio into 30-second chunks
        return_timestamps=False,
    )
    _CACHE[cache_key] = asr
    return asr
def transcribe(audio_path, ckpt, lang_choice):
    """Run speech recognition on a single audio file.

    Args:
        audio_path: file path produced by gr.Audio(type='filepath').
        ckpt: which checkpoint to load.
        lang_choice: 'auto', or an explicit language code
            (e.g. 'tay', 'pwn', 'ami').

    Returns:
        The transcription text, or a user-facing error message string.
    """
    if not audio_path:
        return "請先上傳或錄製一段音檔"

    try:
        asr = build_asr(ckpt)
    except Exception as e:
        return f"(初始化模型時出錯){e}"

    # Whisper: when a language is specified, force it via generate_kwargs;
    # otherwise leave auto-detection to the model.
    extra = (
        {"generate_kwargs": {"language": lang_choice, "task": "transcribe"}}
        if lang_choice and lang_choice != "auto"
        else {}
    )

    try:
        result = asr(audio_path, **extra)
    except Exception as e:
        return f"(轉錄時出錯){e}"

    if isinstance(result, dict) and "text" in result:
        return result["text"]
    return str(result)
# === Gradio UI ===
# Builds the Blocks app: checkpoint/language dropdowns, an audio input
# (mic or upload), a transcription textbox, and a button wired to transcribe().
with gr.Blocks(title="ATD TriPyr ASR (HF Space)") as demo:
    gr.Markdown("## ATD TriPyr ASR\n選擇 checkpoint、上傳或錄音,一鍵轉文字。")
    with gr.Row():
        ckpt = gr.Dropdown(choices=CKPTS, value=DEFAULT, label="Checkpoint")
        # Extend as needed; auto = auto-detect, tay = Atayal (Squliq),
        # pwn = Paiwan, ami = Amis.
        lang = gr.Dropdown(choices=["auto", "tay", "pwn", "ami"], value="auto", label="Language")
    with gr.Row():
        audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio(16~48kHz 皆可)")
        text_out = gr.Textbox(label="Transcription", lines=8)
    btn = gr.Button("Transcribe", variant="primary")
    btn.click(fn=transcribe, inputs=[audio, ckpt, lang], outputs=text_out)
    gr.Markdown(
        """
**小提醒:**
- 第一次選擇某個 checkpoint 會下載權重;看到 Busy 請稍等。
- 若 Space 出現 storage limit exceeded,請到模型庫刪除訓練用大檔(optimizer/scheduler 等),只保留 `model.safetensors` 與必要設定檔。
- 若你的模型庫根目錄暫時沒有 `preprocessor_config.json`,本程式會自動改用 `openai/whisper-small` 的 processor。
        """
    )
if __name__ == "__main__":
    # SSR is enabled by default on HF Spaces; this also runs on ZeroGPU/CPU.
    demo.launch()