# NOTE(review): the three lines below were HuggingFace Space page artifacts
# (avatar caption, commit message, commit hash) accidentally captured into the
# source file; commented out so the file is valid Python.
# yhzx233's picture
# feat: update app
# bab68f6
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MOSS Transcribe Diarize Gradio Demo (Remote API)
==========================
Provides a web interface for audio/video upload, transcribing using a fixed prompt.
"""
import base64
import argparse
import json
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Tuple
import gradio as gr
import requests
# Remote ASR endpoint and auth token; both overridable via environment variables.
DEFAULT_API_URL = os.getenv("MOSS_API_URL", "http://inference.volc.mosi.cn/asr-hf")
DEFAULT_AUTH_TOKEN = os.getenv("MOSS_API_AUTH_TOKEN", "")
# Demo limits enforced in run_transcription() and at Gradio upload time.
MAX_AUDIO_DURATION = 30 # seconds
MAX_FILE_SIZE = 1024 * 1024 * 10 # 10MB
# Fixed transcription prompt sent with every request (Chinese instructions asking
# for [S1]/[S2] speaker labels, [event] tags, <emotion>/<ovl>/<ins> markup);
# it is not user-editable in this demo. Must stay in Chinese — it is model input.
FIXED_PROMPT_NAME = "Speaker + Text Labeling"
FIXED_PROMPT = (
"请将以下对话转录为文本,使用 [S1] [S2] 等说话人标签,对于音频中的事件,使用 [event] 标签表示。"
"富有情感的文本用<emotion>对应文本</emotion> 表示,使用 <ovl> 标签表示音频有部分重叠,<ins></ins> 标签表示音频有插入。"
"自动检测音频的语言,说话人标签和 <ovl> <ins> 始终用英文,event 和 emotion 跟随音频语言。"
)
# Accepted upload suffixes: audio formats first, then video containers whose
# audio track is extracted by ffmpeg. Checked (lowercased) in preprocess_file().
AUDIO_SUFFIXES: Tuple[str, ...] = (
".wav",
".mp3",
".flac",
".aac",
".m4a",
".ogg",
".wma",
".mp4",
".mov",
".mkv",
".avi",
".wmv",
".webm",
)
# Populated in main() from parse_args(); read as module-level config by the helpers.
APP_ARGS: argparse.Namespace = argparse.Namespace()
# --- I18N Configuration (For UI Elements) ---
# UI strings for English and Simplified Chinese. "zh-CN" is supplied through
# dict unpacking because it contains a hyphen and is not a valid keyword name.
# Keys here are looked up via i18n("key") in build_demo().
i18n = gr.I18n(
    en={
        "header": "## 🎤 MOSS Transcribe Diarize: Accurate Transcription with Speaker Diarization",
        "tips": (
            f"> **💡 Note**: This demo currently supports ASR with speaker recognition for audio clips up to **{MAX_AUDIO_DURATION}s**. \n"
            "> Stay tuned! Full-featured long audio support (including timestamps) is **Coming Soon**. \n"
            "> **🔗 Links**: [paper](https://arxiv.org/abs/2601.01554) · [model page](https://mosi.cn/models/moss-transcribe-diarize)"
        ),
        "audio_tab": "🎵 Audio",
        "audio_label": "📥 Upload Audio",
        "video_tab": "🎬 Video",
        "video_tip": "💡 **Note**: Uploading a video will extract the audio for transcription.",
        "video_label": "📥 Upload Video",
        "run_btn": "🚀 Start Transcription",
        "output_label": "📝 Transcription Result",
    },
    **{"zh-CN": {
        "header": "## 🎤 MOSS Transcribe Diarize: 精准转写与说话人识别",
        "tips": (
            f"> **💡 说明**:本演示版本仅支持短音频(最长 **{MAX_AUDIO_DURATION}s**)的文本转写及说话人识别。 \n"
            "> 敬请期待!支持完整时间戳的长音频 API **即将上线**。 \n"
            "> **🔗 链接**:[paper](https://arxiv.org/abs/2601.01554) · [model page](https://mosi.cn/models/moss-transcribe-diarize)"
        ),
        "audio_tab": "🎵 音频",
        "audio_label": "📥 上传音频",
        "video_tab": "🎬 视频",
        "video_tip": "💡 **提示**:上传视频将提取其中的音频进行转录。",
        "video_label": "📥 上传视频",
        "run_btn": "🚀 开始转写",
        "output_label": "📝 转写结果",
    }}
)
def parse_args() -> argparse.Namespace:
    """Build and parse the demo's command-line options.

    Returns:
        argparse.Namespace carrying API connection, sampling, audio
        preprocessing, and Gradio server settings.
    """
    parser = argparse.ArgumentParser(description="MOSS Transcribe Diarize Gradio Demo (Remote API)")

    # Remote API connection.
    parser.add_argument("--api_url", default=DEFAULT_API_URL, help="Remote inference service URL")
    parser.add_argument(
        "--auth_token",
        default=DEFAULT_AUTH_TOKEN,
        help="HTTP Authorization header (or use MOSS_API_AUTH_TOKEN env)",
    )
    parser.add_argument("--timeout", type=int, default=120, help="HTTP request timeout (seconds)")

    # Sampling parameters forwarded verbatim to the remote model.
    parser.add_argument("--max_new_tokens", type=int, default=1024, help="Max new tokens")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--top_k", type=int, default=20, help="Sampling top_k")
    parser.add_argument("--top_p", type=float, default=1.0, help="Sampling top_p")

    # Local ffmpeg preprocessing.
    parser.add_argument("--target_sample_rate", type=int, default=16000, help="Resample to this rate (0 to disable)")
    parser.add_argument("--keep_channels", action="store_true", help="Keep multiple channels (default: downmix to mono)")

    # Gradio server options.
    parser.add_argument("--share", action="store_true", help="Whether to generate a public link")
    parser.add_argument("--server_name", default="0.0.0.0", help="Gradio server name")
    parser.add_argument("--server_port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")), help="Gradio server port")

    return parser.parse_args()
def _get_duration(file_path: str) -> float:
"""Get the duration of an audio/video file in seconds."""
cmd = [
"ffprobe", "-v", "error", "-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1", file_path
]
try:
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
return float(proc.stdout.strip())
except Exception:
return 0.0
def _ffmpeg_to_wav_bytes(file_path: str, target_sample_rate: int, keep_channels: bool, duration_limit: float = 0.0) -> bytes:
cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-nostdin"]
if duration_limit > 0:
cmd += ["-t", str(duration_limit)]
cmd += ["-i", file_path]
if not keep_channels:
cmd += ["-ac", "1"]
if target_sample_rate and int(target_sample_rate) > 0:
cmd += ["-ar", str(int(target_sample_rate))]
cmd += ["-f", "wav", "pipe:1"]
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if proc.returncode != 0:
err = proc.stderr.decode(errors="ignore").strip()
raise RuntimeError(f"ffmpeg transcoding failed: {err or 'unknown error'}")
return proc.stdout
def _file_to_wav_bytes(file_path: str, duration_limit: float = 0.0) -> bytes:
    """Transcode *file_path* to WAV bytes using the CLI settings stored in the APP_ARGS module global."""
    return _ffmpeg_to_wav_bytes(file_path, APP_ARGS.target_sample_rate, APP_ARGS.keep_channels, duration_limit)
def _file_to_data_uri(file_path: str, duration_limit: float = 0.0) -> str:
    """Return the file's audio encoded as a base64 ``data:audio/wav`` URI for the API payload."""
    wav_payload = _file_to_wav_bytes(file_path, duration_limit)
    encoded = base64.b64encode(wav_payload).decode("utf-8")
    return "data:audio/wav;base64," + encoded
def _call_remote_asr(
    prompt: str,
    audio_data_uri: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> Any:
    """POST one transcription request to the remote ASR service.

    Endpoint, auth token, and timeout come from the APP_ARGS module global.

    Returns:
        The parsed JSON body when possible; otherwise a dict holding the raw
        response text under ``raw_text``.

    Raises:
        RuntimeError: On a transport failure or a non-200 HTTP status.
    """
    request_body = {
        "text": prompt,
        "audio_data": audio_data_uri,
        "sampling_params": {
            "max_new_tokens": int(max_new_tokens),
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
        },
    }
    request_headers = {"Content-Type": "application/json"}
    if APP_ARGS.auth_token:
        request_headers["Authorization"] = APP_ARGS.auth_token

    try:
        response = requests.post(
            APP_ARGS.api_url,
            headers=request_headers,
            json=request_body,
            timeout=int(APP_ARGS.timeout),
        )
    except Exception as exc:
        raise RuntimeError(f"Request failed: {exc}") from exc

    if response.status_code != 200:
        body = (response.text or "").strip()
        if len(body) > 2000:
            body = body[:2000] + " ... (truncated)"
        raise RuntimeError(f"HTTP {response.status_code}: {body}")

    try:
        return response.json()
    except Exception:
        # Non-JSON body: hand the raw text back for best-effort display.
        return {"raw_text": response.text}
def _post_process_transcription(text: str) -> str:
"""Merge consecutive identical speaker tags and wrap lines by speaker, and remove <xxx> tags."""
# 1. Remove <xxx> and </xxx> tags (like <emotion>, <ovl>, etc.)
text = re.sub(r'<[^>]+>', '', text)
# 2. Split by speaker tags [S\d+]
pattern = re.compile(r'(\[S\d+\])')
parts = pattern.split(text)
processed_turns = []
current_speaker = None
current_text = []
for i in range(len(parts)):
part = parts[i]
if not part:
continue
if pattern.match(part):
speaker = part
if speaker == current_speaker:
continue
else:
if current_speaker is not None:
txt = "".join(current_text).strip()
if txt:
processed_turns.append(f"{current_speaker} {txt}")
current_speaker = speaker
current_text = []
else:
current_text.append(part)
if current_speaker is not None:
txt = "".join(current_text).strip()
if txt:
processed_turns.append(f"{current_speaker} {txt}")
elif current_text:
return "".join(current_text).strip()
return "\n".join(processed_turns)
def _format_api_response(resp_obj: Any) -> str:
    """Extract the transcription text from an API response and post-process it.

    Accepts a plain string, a dict (probed for the usual payload keys, falling
    back to pretty-printed JSON), or anything else (stringified).
    """
    if isinstance(resp_obj, str):
        raw_text = resp_obj
    elif isinstance(resp_obj, dict):
        raw_text = ""
        # Probe the common response keys for a non-empty string payload.
        for key in ("text", "result", "transcription", "output", "generated_text"):
            candidate = resp_obj.get(key)
            if isinstance(candidate, str) and candidate.strip():
                raw_text = candidate
                break
        if not raw_text:
            # Unknown schema: show the whole response as formatted JSON.
            raw_text = json.dumps(resp_obj, ensure_ascii=False, indent=2)
    else:
        raw_text = str(resp_obj)
    return _post_process_transcription(raw_text)
def _normalize_path(file_obj) -> str:
if isinstance(file_obj, dict):
name = file_obj.get("name")
if isinstance(name, str):
return name
if isinstance(file_obj, str):
return file_obj
name = getattr(file_obj, "name", None)
if isinstance(name, str):
return name
raise gr.Error("Unrecognized file object.")
def preprocess_file(audio_obj, video_obj) -> str:
    """Validate the upload pair and return the path of the single provided file.

    Exactly one of *audio_obj* / *video_obj* must be set, and its (lowercased)
    suffix must be in AUDIO_SUFFIXES.

    Raises:
        gr.Error: If zero or both inputs were given, or the format is unsupported.
    """
    uploads = [item for item in (audio_obj, video_obj) if item]
    if not uploads:
        raise gr.Error("Please upload an audio or video file.")
    if len(uploads) > 1:
        raise gr.Error("Please select either audio or video, not both.")
    path = _normalize_path(uploads[0])
    if Path(path).suffix.lower() not in AUDIO_SUFFIXES:
        raise gr.Error("Unsupported file format.")
    return path
def run_transcription(
    audio_obj,
    video_obj,
    progress=gr.Progress(track_tqdm=False),
) -> str:
    """Full demo pipeline: validate the upload, transcode, call the API, format.

    Args:
        audio_obj: Value of the audio upload component (or falsy).
        video_obj: Value of the video upload component (or falsy).
        progress: Gradio progress reporter (injected by Gradio at call time).

    Returns:
        The post-processed transcription text.

    Raises:
        gr.Error: On validation, transcoding, or remote-inference failure.
    """
    progress(0.15, "Processing file...")
    audio_path = preprocess_file(audio_obj, video_obj)

    # Enforce the demo's duration cap by truncating overly long inputs
    # (small epsilon avoids warning on files that are borderline-long).
    duration = _get_duration(audio_path)
    truncate_to = 0.0
    if duration > MAX_AUDIO_DURATION + 0.1:
        truncate_to = float(MAX_AUDIO_DURATION)
        gr.Warning(f"File is too long ({duration:.1f}s). It has been truncated to the first {MAX_AUDIO_DURATION}s.")

    progress(0.3, "Decoding and transcoding to WAV (ffmpeg)...")
    try:
        audio_data_uri = _file_to_data_uri(audio_path, duration_limit=truncate_to)
    except Exception as exc:
        raise gr.Error(str(exc))

    progress(0.6, "Requesting remote inference service...")
    try:
        response = _call_remote_asr(
            FIXED_PROMPT,
            audio_data_uri,
            int(APP_ARGS.max_new_tokens),
            float(APP_ARGS.temperature),
            int(APP_ARGS.top_k),
            float(APP_ARGS.top_p),
        )
    except Exception as exc:
        raise gr.Error(str(exc))

    progress(1.0, "Done")
    return _format_api_response(response)
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI: header, audio/video tabs, run button, output box.

    All visible labels go through i18n(...) so the UI follows the browser locale.
    Returns the (not yet launched) Blocks app.
    """
    with gr.Blocks(title="MOSS Transcribe Diarize", theme=gr.themes.Soft()) as demo:
        gr.Markdown(i18n("header"))
        gr.Markdown(i18n("tips"))
        with gr.Row():
            # Left column: input tabs (audio upload OR video upload) + run button.
            with gr.Column(scale=1):
                with gr.Tabs() as tabs:
                    with gr.Tab(i18n("audio_tab"), id="audio") as audio_tab:
                        audio_input = gr.Audio(
                            label=i18n("audio_label"),
                            sources=["upload"],
                            type="filepath",
                            interactive=True,
                        )
                    with gr.Tab(i18n("video_tab"), id="video") as video_tab:
                        gr.Markdown(i18n("video_tip"))
                        video_input = gr.Video(
                            label=i18n("video_label"),
                            interactive=True,
                        )
                run_button = gr.Button(i18n("run_btn"), variant="primary")
            # Right column: transcription result.
            with gr.Column(scale=1):
                output_box = gr.Textbox(label=i18n("output_label"), lines=18)
        # Switching tabs clears the other tab's input, so preprocess_file()
        # never sees both an audio and a video value at once.
        audio_tab.select(fn=lambda: None, outputs=video_input)
        video_tab.select(fn=lambda: None, outputs=audio_input)
        run_button.click(
            run_transcription,
            inputs=[
                audio_input,
                video_input,
            ],
            outputs=output_box,
        )
    return demo
def main() -> None:
    """Parse CLI args into the APP_ARGS module global, build the UI, and serve it."""
    global APP_ARGS
    APP_ARGS = parse_args()

    launch_options = {
        "share": APP_ARGS.share,
        "server_name": APP_ARGS.server_name,
        "server_port": APP_ARGS.server_port,
        "max_file_size": MAX_FILE_SIZE,  # reject oversized uploads at the HTTP layer
        "i18n": i18n,
    }
    build_demo().queue().launch(**launch_options)
if __name__ == "__main__":
    main()