|
|
|
|
|
|
|
|
""" |
|
|
MOSS Transcribe Diarize Gradio Demo (Remote API) |
|
|
========================== |
|
|
Provides a web interface for audio/video upload, transcribing using a fixed prompt. |
|
|
""" |
|
|
import base64 |
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from typing import Any, Tuple |
|
|
|
|
|
import gradio as gr |
|
|
import requests |
|
|
|
|
|
|
|
|
# Remote inference endpoint and credential. Both may be overridden through
# environment variables; the command-line flags fall back to these values.
DEFAULT_API_URL = os.environ.get("MOSS_API_URL", "http://inference.volc.mosi.cn/asr-hf")
DEFAULT_AUTH_TOKEN = os.environ.get("MOSS_API_AUTH_TOKEN", "")

# Longest clip (in seconds) that is transcribed; longer uploads are truncated.
MAX_AUDIO_DURATION = 30

# Upload size cap in bytes (10 MiB), handed to Gradio's launch().
MAX_FILE_SIZE = 10 * 1024 * 1024
|
|
|
|
|
|
|
|
# Display name of the single prompt this demo uses.
FIXED_PROMPT_NAME = "Speaker + Text Labeling"

# Instruction sent verbatim with every request. It asks the model for
# [S1]/[S2] speaker labels, [event] markers, <emotion> spans and <ovl>/<ins>
# overlap/insertion tags. The three sentences are joined without separators.
FIXED_PROMPT = "".join(
    [
        "请将以下对话转录为文本,使用 [S1] [S2] 等说话人标签,对于音频中的事件,使用 [event] 标签表示。",
        "富有情感的文本用<emotion>对应文本</emotion> 表示,使用 <ovl> 标签表示音频有部分重叠,<ins></ins> 标签表示音频有插入。",
        "自动检测音频的语言,说话人标签和 <ovl> <ins> 始终用英文,event 和 emotion 跟随音频语言。",
    ]
)
|
|
|
|
|
|
|
|
# Accepted upload extensions (lower-case, with leading dot). Despite the name,
# the tuple covers video containers too, since video uploads are accepted and
# only their audio track is transcribed.
_AUDIO_CONTAINERS = (".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg", ".wma")
_VIDEO_CONTAINERS = (".mp4", ".mov", ".mkv", ".avi", ".wmv", ".webm")
AUDIO_SUFFIXES: Tuple[str, ...] = _AUDIO_CONTAINERS + _VIDEO_CONTAINERS

# Parsed command-line options; starts empty and is replaced by main().
APP_ARGS: argparse.Namespace = argparse.Namespace()
|
|
|
|
|
|
|
|
# Markdown link line that both locales share verbatim.
_REFERENCE_LINKS = (
    "[paper](https://arxiv.org/abs/2601.01554) · "
    "[model page](https://mosi.cn/models/moss-transcribe-diarize)"
)

# English UI strings, keyed by the identifiers that build_demo() looks up.
_EN_MESSAGES = {
    "header": "## 🎤 MOSS Transcribe Diarize: Accurate Transcription with Speaker Diarization",
    "tips": " \n".join(
        [
            f"> **💡 Note**: This demo currently supports ASR with speaker recognition for audio clips up to **{MAX_AUDIO_DURATION}s**.",
            "> Stay tuned! Full-featured long audio support (including timestamps) is **Coming Soon**.",
            "> **🔗 Links**: " + _REFERENCE_LINKS,
        ]
    ),
    "audio_tab": "🎵 Audio",
    "audio_label": "📥 Upload Audio",
    "video_tab": "🎬 Video",
    "video_tip": "💡 **Note**: Uploading a video will extract the audio for transcription.",
    "video_label": "📥 Upload Video",
    "run_btn": "🚀 Start Transcription",
    "output_label": "📝 Transcription Result",
}

# Simplified-Chinese UI strings; same keys as the English table.
_ZH_CN_MESSAGES = {
    "header": "## 🎤 MOSS Transcribe Diarize: 精准转写与说话人识别",
    "tips": " \n".join(
        [
            f"> **💡 说明**:本演示版本仅支持短音频(最长 **{MAX_AUDIO_DURATION}s**)的文本转写及说话人识别。",
            "> 敬请期待!支持完整时间戳的长音频 API **即将上线**。",
            "> **🔗 链接**:" + _REFERENCE_LINKS,
        ]
    ),
    "audio_tab": "🎵 音频",
    "audio_label": "📥 上传音频",
    "video_tab": "🎬 视频",
    "video_tip": "💡 **提示**:上传视频将提取其中的音频进行转录。",
    "video_label": "📥 上传视频",
    "run_btn": "🚀 开始转写",
    "output_label": "📝 转写结果",
}

# "zh-CN" is not a valid Python identifier, so the locale tables are passed
# through dict unpacking rather than as literal keyword arguments.
i18n = gr.I18n(**{"en": _EN_MESSAGES, "zh-CN": _ZH_CN_MESSAGES})
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace holding the remote-service settings, sampling parameters,
        audio transcoding options and Gradio server options.
    """
    cli = argparse.ArgumentParser(description="MOSS Transcribe Diarize Gradio Demo (Remote API)")

    # Remote service connection.
    cli.add_argument("--api_url", default=DEFAULT_API_URL, help="Remote inference service URL")
    cli.add_argument(
        "--auth_token",
        default=DEFAULT_AUTH_TOKEN,
        help="HTTP Authorization header (or use MOSS_API_AUTH_TOKEN env)",
    )

    # Typed options: (flag, converter, default, help text). Registered in this
    # order so the generated --help listing stays the same.
    typed_options = (
        ("--timeout", int, 120, "HTTP request timeout (seconds)"),
        ("--max_new_tokens", int, 1024, "Max new tokens"),
        ("--temperature", float, 0.0, "Sampling temperature"),
        ("--top_k", int, 20, "Sampling top_k"),
        ("--top_p", float, 1.0, "Sampling top_p"),
        ("--target_sample_rate", int, 16000, "Resample to this rate (0 to disable)"),
    )
    for flag, converter, fallback, description in typed_options:
        cli.add_argument(flag, type=converter, default=fallback, help=description)

    # Boolean switches.
    switches = (
        ("--keep_channels", "Keep multiple channels (default: downmix to mono)"),
        ("--share", "Whether to generate a public link"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    # Gradio server binding; the port default honours GRADIO_SERVER_PORT.
    cli.add_argument("--server_name", default="0.0.0.0", help="Gradio server name")
    env_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    cli.add_argument("--server_port", type=int, default=env_port, help="Gradio server port")

    return cli.parse_args()
|
|
|
|
|
|
|
|
def _get_duration(file_path: str) -> float:
    """Return the duration of an audio/video file in seconds.

    The value is read from ``ffprobe``'s ``format=duration`` entry.

    Args:
        file_path: Path of the media file to probe.

    Returns:
        The duration in seconds, or ``0.0`` when it cannot be determined
        (ffprobe missing or failing, unparsable output such as ``N/A``, or a
        non-finite/negative value).
    """
    cmd = [
        "ffprobe", "-v", "error", "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1", file_path,
    ]
    try:
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
        duration = float(proc.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError):
        # Best effort only: OSError covers a missing ffprobe binary,
        # SubprocessError a non-zero exit, ValueError unparsable output.
        return 0.0

    # NaN fails every comparison, so this also rejects NaN and +/-inf; such a
    # value would otherwise slip past the caller's length check.
    if not (0.0 <= duration < float("inf")):
        return 0.0
    return duration
|
|
|
|
|
|
|
|
def _ffmpeg_to_wav_bytes(file_path: str, target_sample_rate: int, keep_channels: bool, duration_limit: float = 0.0) -> bytes:
    """Transcode a media file to WAV with ffmpeg and return the bytes.

    Args:
        file_path: Path of the input audio/video file.
        target_sample_rate: Output sample rate; a falsy or non-positive value
            leaves the rate untouched.
        keep_channels: When false, the output is downmixed to one channel.
        duration_limit: When positive, only this many seconds of the input
            are read (``-t`` is placed before ``-i``).

    Returns:
        The WAV data ffmpeg wrote to its standard output.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    argv = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-nostdin"]
    if duration_limit > 0:
        argv.extend(("-t", str(duration_limit)))
    argv.extend(("-i", file_path))
    if not keep_channels:
        argv.extend(("-ac", "1"))
    if target_sample_rate and int(target_sample_rate) > 0:
        argv.extend(("-ar", str(int(target_sample_rate))))
    argv.extend(("-f", "wav", "pipe:1"))

    result = subprocess.run(argv, capture_output=True)
    if result.returncode == 0:
        return result.stdout

    detail = result.stderr.decode(errors="ignore").strip()
    raise RuntimeError(f"ffmpeg transcoding failed: {detail or 'unknown error'}")
|
|
|
|
|
|
|
|
def _file_to_wav_bytes(file_path: str, duration_limit: float = 0.0) -> bytes:
    """Transcode ``file_path`` to WAV using the sample-rate and channel options in ``APP_ARGS``."""
    sample_rate = APP_ARGS.target_sample_rate
    keep_channels = APP_ARGS.keep_channels
    return _ffmpeg_to_wav_bytes(file_path, sample_rate, keep_channels, duration_limit)
|
|
|
|
|
|
|
|
def _file_to_data_uri(file_path: str, duration_limit: float = 0.0) -> str:
    """Transcode a media file to WAV and return it as a base64 ``data:audio/wav`` URI."""
    encoded = base64.b64encode(_file_to_wav_bytes(file_path, duration_limit))
    return "data:audio/wav;base64," + encoded.decode("utf-8")
|
|
|
|
|
|
|
|
def _call_remote_asr(
    prompt: str,
    audio_data_uri: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> Any:
    """POST one transcription request to the remote service.

    The endpoint, optional ``Authorization`` header value and timeout are read
    from ``APP_ARGS``.

    Args:
        prompt: Instruction text sent as the ``text`` field.
        audio_data_uri: Audio payload sent as the ``audio_data`` field.
        max_new_tokens: Sampling parameter, coerced to ``int``.
        temperature: Sampling parameter, coerced to ``float``.
        top_k: Sampling parameter, coerced to ``int``.
        top_p: Sampling parameter, coerced to ``float``.

    Returns:
        The decoded JSON body, or ``{"raw_text": <body>}`` when the body is
        not valid JSON.

    Raises:
        RuntimeError: If the request cannot be made or the status is not 200.
    """
    sampling = {
        "max_new_tokens": int(max_new_tokens),
        "temperature": float(temperature),
        "top_k": int(top_k),
        "top_p": float(top_p),
    }
    body = {"text": prompt, "audio_data": audio_data_uri, "sampling_params": sampling}

    request_headers = {"Content-Type": "application/json"}
    if APP_ARGS.auth_token:
        request_headers["Authorization"] = APP_ARGS.auth_token

    try:
        response = requests.post(
            APP_ARGS.api_url,
            headers=request_headers,
            json=body,
            timeout=int(APP_ARGS.timeout),
        )
    except Exception as exc:
        raise RuntimeError(f"Request failed: {exc}") from exc

    if response.status_code != 200:
        detail = (response.text or "").strip()
        # Keep error messages bounded so a large error page does not flood the UI.
        if len(detail) > 2000:
            detail = detail[:2000] + " ... (truncated)"
        raise RuntimeError(f"HTTP {response.status_code}: {detail}")

    try:
        decoded = response.json()
    except Exception:
        return {"raw_text": response.text}
    return decoded
|
|
|
|
|
|
|
|
def _post_process_transcription(text: str) -> str:
    """Strip ``<...>`` tags and lay the transcript out one speaker turn per line.

    Consecutive segments carrying the same ``[S<n>]`` tag are merged into one
    turn, and turns whose text is empty after stripping are dropped. Text that
    appears before the first speaker tag is kept as its own untagged line
    rather than being discarded.

    Args:
        text: Raw transcription text.

    Returns:
        The turns joined with newlines, each formatted as ``"[S<n>] <text>"``.
        If the input has no speaker tags, the stripped text is returned as is.
    """
    # Drop every angle-bracket tag (<ovl>, <ins>, <emotion>, ...) first.
    text = re.sub(r"<[^>]+>", "", text)

    # The capture group makes split() return the tags along with the text.
    speaker_tag = re.compile(r"(\[S\d+\])")

    def _render(speaker, chunks):
        """Format one turn; return '' when it holds no visible text."""
        content = "".join(chunks).strip()
        if not content:
            return ""
        return content if speaker is None else f"{speaker} {content}"

    turns = []
    current_speaker = None  # None until the first tag: text is an untagged preamble
    pending = []

    for part in speaker_tag.split(text):
        if not part:
            continue
        if not speaker_tag.fullmatch(part):
            pending.append(part)
            continue
        if part == current_speaker:
            # Same speaker continues: keep accumulating into the open turn.
            continue
        rendered = _render(current_speaker, pending)
        if rendered:
            turns.append(rendered)
        current_speaker = part
        pending = []

    rendered = _render(current_speaker, pending)
    if rendered:
        turns.append(rendered)

    return "\n".join(turns)
|
|
|
|
|
|
|
|
def _format_api_response(resp_obj: Any) -> str:
    """Extract the transcript text from an API response and post-process it.

    Args:
        resp_obj: The response: a string, a dict, or any other object.

    Returns:
        The result of ``_post_process_transcription`` applied to:
        the string itself; for a dict, the first non-blank string found under
        the known text keys, else the dict pretty-printed as JSON; for any
        other object, its ``str()`` form.
    """
    # "raw_text" comes last: it is the key _call_remote_asr uses to wrap a
    # non-JSON body, which would otherwise be shown as a JSON-escaped blob.
    text_keys = ("text", "result", "transcription", "output", "generated_text", "raw_text")

    if isinstance(resp_obj, str):
        raw_text = resp_obj
    elif isinstance(resp_obj, dict):
        candidates = (resp_obj.get(key) for key in text_keys)
        raw_text = next(
            (value for value in candidates if isinstance(value, str) and value.strip()),
            None,
        )
        if raw_text is None:
            # No recognisable text field: show the whole payload instead.
            raw_text = json.dumps(resp_obj, ensure_ascii=False, indent=2)
    else:
        raw_text = str(resp_obj)

    return _post_process_transcription(raw_text)
|
|
|
|
|
|
|
|
def _normalize_path(file_obj) -> str:
    """Return the filesystem path carried by an uploaded-file value.

    Accepts a dict with a string ``"name"`` entry, a plain string path, or any
    object exposing a string ``name`` attribute.

    Raises:
        gr.Error: If none of these forms yields a string path.
    """
    if isinstance(file_obj, dict) and isinstance(file_obj.get("name"), str):
        return file_obj["name"]
    if isinstance(file_obj, str):
        return file_obj

    candidate = getattr(file_obj, "name", None)
    if not isinstance(candidate, str):
        raise gr.Error("Unrecognized file object.")
    return candidate
|
|
|
|
|
|
|
|
def preprocess_file(audio_obj, video_obj) -> str:
    """Validate the upload and return the path of the file to transcribe.

    Exactly one of the two inputs must be truthy, and its extension
    (case-insensitive) must be listed in ``AUDIO_SUFFIXES``.

    Raises:
        gr.Error: If nothing was uploaded, both inputs were supplied, or the
            extension is not supported.
    """
    uploads = [candidate for candidate in (audio_obj, video_obj) if candidate]
    if not uploads:
        raise gr.Error("Please upload an audio or video file.")
    if len(uploads) > 1:
        raise gr.Error("Please select either audio or video, not both.")

    (upload,) = uploads
    path = _normalize_path(upload)
    if Path(path).suffix.lower() not in AUDIO_SUFFIXES:
        raise gr.Error("Unsupported file format.")
    return path
|
|
|
|
|
|
|
|
def run_transcription(
    audio_obj,
    video_obj,
    progress=gr.Progress(track_tqdm=False),
) -> str:
    """Transcribe the uploaded audio or video through the remote service.

    Steps: validate the upload, probe its duration, transcode to WAV (capped
    at ``MAX_AUDIO_DURATION`` seconds), call the remote API with the fixed
    prompt, and format the response.

    Args:
        audio_obj: Value of the audio input (falsy when unused).
        video_obj: Value of the video input (falsy when unused).
        progress: Gradio progress tracker.

    Returns:
        The formatted transcription text.

    Raises:
        gr.Error: If validation, transcoding or the remote request fails.
    """
    progress(0.15, "Processing file...")
    audio_path = preprocess_file(audio_obj, video_obj)

    duration = _get_duration(audio_path)
    too_long = duration > MAX_AUDIO_DURATION + 0.1
    # A probe result that is not a positive number (0.0 on failure, or NaN)
    # means the length is unknown; the cap must still apply, otherwise a long
    # file could bypass the limit.
    duration_unknown = not (duration > 0)

    actual_limit = 0.0
    if too_long or duration_unknown:
        actual_limit = float(MAX_AUDIO_DURATION)
    if too_long:
        gr.Warning(f"File is too long ({duration:.1f}s). It has been truncated to the first {MAX_AUDIO_DURATION}s.")

    progress(0.3, "Decoding and transcoding to WAV (ffmpeg)...")
    try:
        audio_data_uri = _file_to_data_uri(audio_path, duration_limit=actual_limit)
    except Exception as e:
        # Surface the failure in the UI while keeping the cause for server logs.
        raise gr.Error(str(e)) from e

    progress(0.6, "Requesting remote inference service...")
    try:
        resp_obj = _call_remote_asr(
            FIXED_PROMPT,
            audio_data_uri,
            int(APP_ARGS.max_new_tokens),
            float(APP_ARGS.temperature),
            int(APP_ARGS.top_k),
            float(APP_ARGS.top_p),
        )
    except Exception as e:
        raise gr.Error(str(e)) from e

    progress(1.0, "Done")
    return _format_api_response(resp_obj)
|
|
|
|
|
|
|
|
def build_demo() -> gr.Blocks:
    """Assemble the Gradio interface.

    Layout: a header and a tips banner, then a row of two columns. The left
    column holds the Audio/Video upload tabs and the run button; the right
    column holds the transcription text box.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="MOSS Transcribe Diarize", theme=gr.themes.Soft()) as app:
        gr.Markdown(i18n("header"))
        gr.Markdown(i18n("tips"))

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                with gr.Tabs():
                    with gr.Tab(i18n("audio_tab"), id="audio") as audio_pane:
                        audio_upload = gr.Audio(
                            label=i18n("audio_label"),
                            sources=["upload"],
                            type="filepath",
                            interactive=True,
                        )
                    with gr.Tab(i18n("video_tab"), id="video") as video_pane:
                        gr.Markdown(i18n("video_tip"))
                        video_upload = gr.Video(label=i18n("video_label"), interactive=True)
                start_button = gr.Button(i18n("run_btn"), variant="primary")

            # Right column: output.
            with gr.Column(scale=1):
                transcript_box = gr.Textbox(label=i18n("output_label"), lines=18)

        # Selecting a tab clears the other tab's upload, so only one input is
        # populated at a time (run_transcription rejects both being set).
        audio_pane.select(fn=lambda: None, outputs=video_upload)
        video_pane.select(fn=lambda: None, outputs=audio_upload)

        start_button.click(
            run_transcription,
            inputs=[audio_upload, video_upload],
            outputs=transcript_box,
        )

    return app
|
|
|
|
|
|
|
|
def main() -> None:
    """Parse the command line, build the interface and start the Gradio server."""
    global APP_ARGS
    # Published at module level because the request handlers read APP_ARGS.
    APP_ARGS = parse_args()

    app = build_demo().queue()
    launch_options = {
        "share": APP_ARGS.share,
        "server_name": APP_ARGS.server_name,
        "server_port": APP_ARGS.server_port,
        "max_file_size": MAX_FILE_SIZE,
        "i18n": i18n,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()
|
|
|