# FireRedASR / app.py
# Last commit by luoqiu: 5c80dc5 ("Update app.py")
import sys
import os

# Make packages in the current working directory (notably the local
# `fireredasr` package) importable. This is similar to
# `export PYTHONPATH=$PWD:$PYTHONPATH`, except that the entry is *appended*,
# so it is searched after the existing sys.path entries.
sys.path.append(os.getcwd())

# Also prepend the fireredasr directories to PATH — presumably so helper
# scripts shipped there can be found as executables (TODO confirm need).
# os.pathsep keeps this portable (':' on POSIX, ';' on Windows). A missing
# PATH no longer raises KeyError, and an empty PATH no longer leaves a
# trailing empty entry (which POSIX treats as "current directory").
_extra_path_dirs = [
    os.path.join(os.getcwd(), "fireredasr"),
    os.path.join(os.getcwd(), "fireredasr", "utils"),
]
_previous_path = os.environ.get("PATH", "")
os.environ["PATH"] = os.pathsep.join(
    _extra_path_dirs + ([_previous_path] if _previous_path else [])
)
import gradio as gr
import librosa
import soundfile as sf
from fireredasr.models.fireredasr import FireRedAsr
import os
from huggingface_hub import snapshot_download
# Local directory holding the FireRedASR-AED-L checkpoint.
model_dir = "pretrained_models/FireRedASR-AED-L"

# Fetch the checkpoint from the Hugging Face Hub only when it is not already
# present locally. An existing-but-empty directory (e.g. left behind by an
# aborted first run) is treated as missing so the download is retried.
if not os.path.isdir(model_dir) or not os.listdir(model_dir):
    print("Downloading FireRedASR-AED-L from Hugging Face...")
    snapshot_download(
        repo_id="FireRedTeam/FireRedASR-AED-L",
        local_dir=model_dir,
        local_dir_use_symlinks=False,
    )

# 1. Load the model once at import time; every request reuses this instance.
#    Use model_dir (not a repeated literal) so download and load paths agree.
model = FireRedAsr.from_pretrained(
    "aed",
    model_dir,
)
def preprocess_audio(input_path):
    """
    Resample an audio file to 16 kHz mono and write the result as a WAV file.

    The output file is written next to the input, at ``input_path`` with
    ``"_16k.wav"`` appended. It is not removed here; the caller owns it.

    Args:
        input_path: Path of the source audio file (any format librosa reads).

    Returns:
        Path of the newly written 16 kHz mono WAV file.
    """
    target_rate = 16000
    samples, _loaded_rate = librosa.load(input_path, sr=target_rate, mono=True)
    resampled_path = input_path + "_16k.wav"
    sf.write(resampled_path, samples, target_rate)
    return resampled_path
def transcribe_fn(audio_file):
    """
    Run FireRedASR on one audio file and return the recognized text.

    Args:
        audio_file: Path to the audio file that Gradio hands over, or ``None``
            when nothing was recorded or uploaded.

    Returns:
        The recognized text. If ``audio_file`` is ``None``, a prompt asking
        the user to record/upload audio. If the model returns an empty list,
        an empty string.
    """
    if audio_file is None:
        return "请先录音或上传音频"

    # Resample to 16 kHz mono; this writes a new WAV file that we own.
    processed_path = preprocess_audio(audio_file)
    batch_uttid = ["audio_input"]
    batch_wav_path = [processed_path]
    try:
        results = model.transcribe(
            batch_uttid,
            batch_wav_path,
            {
                "use_gpu": 0,  # CPU Space = 0; set to 1 on a GPU Space
                "beam_size": 3,
                "nbest": 1,
                "decode_max_len": 0,
                "softmax_smoothing": 1.0,
                "aed_length_penalty": 0.0,
                "eos_penalty": 1.0,
            },
        )
    finally:
        # The resampled copy is only needed for this call; remove it so
        # repeated requests don't pile up files on disk. Best-effort: a
        # failed cleanup must not hide the transcription result/error.
        try:
            os.remove(processed_path)
        except OSError:
            pass

    if isinstance(results, list):
        if not results:
            # Nothing decoded; avoid IndexError on results[0].
            return ""
        first = results[0]
        if isinstance(first, dict):
            return first.get("text", str(first))
        return str(first)
    return str(results)
# 2. Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔊 FireRedASR 实时语音识别 Demo\n支持麦克风实时识别 + 文件上传")
    with gr.Tab("📁 文件上传识别或者实时麦克风录音"):
        gr.Markdown("上传音频文件或者实时录音,点击按钮进行识别。")
        # The page header and tab title promise microphone recording, so
        # enable the microphone source alongside file upload.
        # type="filepath" makes Gradio pass transcribe_fn a file path.
        file_audio = gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="上传音频文件(支持 WAV/MP3 等)",
        )
        file_output = gr.Textbox(label="识别结果")
        file_btn = gr.Button("开始识别")
        file_btn.click(fn=transcribe_fn, inputs=file_audio, outputs=file_output)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()