# app.py — ATD TriPyr ASR demo Space (repo: aciang/ATD-TriPyr-ASR, commit 798a875)
# app.py
import os
import torch
import gradio as gr
from transformers import (
pipeline,
WhisperProcessor,
WhisperForConditionalGeneration,
)
# === Model configuration ===
HF_REPO = "aciang/ATD-TriPyr-ASR"  # HF model repo that holds the fine-tuned checkpoints
SUBROOT = "refined_20251014_013556"  # training-run folder inside the repo
CKPTS = ["checkpoint-200", "checkpoint-400", "checkpoint-600", "checkpoint-776"]
DEFAULT = "checkpoint-400"
# Read the access token from a Space secret (recommended for a private Space).
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Auto-detect CPU / GPU (transformers pipeline convention: -1 = CPU, 0 = first GPU).
DEVICE = 0 if torch.cuda.is_available() else -1
# Cache one pipeline per checkpoint so it is not rebuilt on every request.
_CACHE = {}
def build_asr(ckpt: str):
    """Return a cached ASR pipeline for *ckpt*, building it on first use.

    Only the selected checkpoint's weights are downloaded from the hub.
    """
    cache_key = f"{ckpt}_{DEVICE}"
    cached = _CACHE.get(cache_key)
    if cached is not None:
        return cached

    # Processor (tokenizer + feature extractor): try the repo root first;
    # if it has no preprocessor_config.json, fall back to openai/whisper-small.
    try:
        processor = WhisperProcessor.from_pretrained(HF_REPO, token=HF_TOKEN)
    except Exception:
        processor = WhisperProcessor.from_pretrained("openai/whisper-small")

    # Model weights live in a per-checkpoint subfolder of the repo.
    model = WhisperForConditionalGeneration.from_pretrained(
        HF_REPO,
        subfolder=f"{SUBROOT}/{ckpt}",
        torch_dtype="auto",  # let transformers choose the dtype (usually fp16 on GPU)
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )

    pipe = pipeline(
        task="automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        chunk_length_s=30,  # split long audio into 30 s chunks
        return_timestamps=False,
    )
    _CACHE[cache_key] = pipe
    return pipe
def transcribe(audio_path, ckpt, lang_choice):
    """Transcribe an audio file with the selected checkpoint.

    audio_path: file path coming from gr.Audio(type='filepath')
    ckpt: checkpoint name to load
    lang_choice: 'auto', or a language code (e.g. 'tay', 'pwn', 'ami')
    """
    if not audio_path:
        return "請先上傳或錄製一段音檔"

    try:
        asr = build_asr(ckpt)
    except Exception as e:
        return f"(初始化模型時出錯){e}"

    # With an explicit language, steer Whisper via generate_kwargs;
    # otherwise leave it to automatic language detection.
    call_kwargs = (
        {"generate_kwargs": {"language": lang_choice, "task": "transcribe"}}
        if lang_choice and lang_choice != "auto"
        else {}
    )

    try:
        result = asr(audio_path, **call_kwargs)
    except Exception as e:
        return f"(轉錄時出錯){e}"

    # Pipelines normally return {"text": ...}; stringify anything unexpected.
    if isinstance(result, dict) and "text" in result:
        return result["text"]
    return str(result)
# === Gradio UI ===
with gr.Blocks(title="ATD TriPyr ASR (HF Space)") as demo:
    gr.Markdown("## ATD TriPyr ASR\n選擇 checkpoint、上傳或錄音,一鍵轉文字。")
    with gr.Row():
        # Checkpoint selector; choices come from the CKPTS constant above.
        ckpt = gr.Dropdown(choices=CKPTS, value=DEFAULT, label="Checkpoint")
        # Extend as needed; auto = auto-detect, tay = Atayal (Squliq), pwn = Paiwan, ami = Amis
        lang = gr.Dropdown(choices=["auto", "tay", "pwn", "ami"], value="auto", label="Language")
    with gr.Row():
        # 'filepath' mode hands transcribe() a temp-file path, not raw audio data.
        audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio(16~48kHz 皆可)")
        text_out = gr.Textbox(label="Transcription", lines=8)
    btn = gr.Button("Transcribe", variant="primary")
    btn.click(fn=transcribe, inputs=[audio, ckpt, lang], outputs=text_out)
    gr.Markdown(
        """
**小提醒:**
- 第一次選擇某個 checkpoint 會下載權重;看到 Busy 請稍等。
- 若 Space 出現 storage limit exceeded,請到模型庫刪除訓練用大檔(optimizer/scheduler 等),只保留 `model.safetensors` 與必要設定檔。
- 若你的模型庫根目錄暫時沒有 `preprocessor_config.json`,本程式會自動改用 `openai/whisper-small` 的 processor。
        """
    )
if __name__ == "__main__":
# 開啟 SSR(HF Spaces 預設可),ZeroGPU/CPU 也能跑
demo.launch()