#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
MOSS Transcribe Diarize Gradio Demo (Remote API)
================================================
Provides a web interface for uploading audio/video and transcribing it via a
remote API using a fixed prompt.
"""
import base64
import argparse
import json
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Tuple
import uuid
import shutil
import gradio as gr
import requests
DEFAULT_API_URL = os.getenv("MOSS_API_URL", "https://studio.mosi.cn/v1/audio/transcriptions")  # alternate endpoint: http://inference.volc.mosi.cn/asr-hf
DEFAULT_AUTH_TOKEN = os.getenv("MOSS_API_KEY", os.getenv("MOSS_API_AUTH_TOKEN", ""))
DEFAULT_MODEL = os.getenv("MOSS_MODEL", "moss-transcribe-diarize")
MAX_AUDIO_DURATION = 30  # seconds; longer audio is truncated before upload
MAX_FILE_SIZE = 1024 * 1024 * 1024  # 1 GB
FIXED_PROMPT_NAME = "Speaker + Text Labeling"
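# English gloss of the fixed Chinese prompt below: transcribe the dialogue using
# [S1] [S2] speaker labels; mark audio events with [event]; wrap emotional text in
# <emotion></emotion>; use <ovl> for overlapping speech and <ins></ins> for
# insertions; auto-detect the language; speaker labels and <ovl>/<ins> are always
# English, while event/emotion follow the audio language.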
FIXED_PROMPT = (
"请将以下对话转录为文本,使用 [S1] [S2] 等说话人标签,对于音频中的事件,使用 [event] 标签表示。"
"富有情感的文本用<emotion>对应文本</emotion> 表示,使用 <ovl> 标签表示音频有部分重叠,<ins></ins> 标签表示音频有插入。"
"自动检测音频的语言,说话人标签和 <ovl> <ins> 始终用英文,event 和 emotion 跟随音频语言。"
)
AUDIO_SUFFIXES: Tuple[str, ...] = (
".wav",
".mp3",
".flac",
".aac",
".m4a",
".ogg",
".wma",
".mp4",
".mov",
".mkv",
".avi",
".wmv",
".webm",
)
APP_ARGS: argparse.Namespace = argparse.Namespace()
# --- Time Formatting Helper ---
def _sec_to_hhmmss_cs(sec: float) -> str:
"""Convert seconds to compact HH:MM:SS.ss format."""
if sec < 0:
sec = 0.0
total = float(sec)
hh = int(total // 3600)
mm = int((total % 3600) // 60)
ss = total % 60.0
if hh > 0:
return f"{hh:02d}:{mm:02d}:{ss:05.2f}"
if mm > 0:
return f"{mm:02d}:{ss:05.2f}"
return f"{ss:05.2f}"
# --- I18N Configuration (For UI Elements) ---
i18n = gr.I18n(
en={
"header": "## 🎤 Transcribe Diarize Model:(90%) Accurate Transcription with Speaker Diarization",
"audio_tab": "🎵 Audio",
"audio_label": "📥 Upload / Record Audio",
"video_tab": "🎬 Video",
"video_tip": "💡 **Note**: Uploading a video will extract the audio for transcription.",
"video_label": "📥 Upload Video",
"run_btn": "🚀 Start Transcription",
"output_label": "📝 Transcription Result",
},
**{"zh-CN": {
"header": "## 🎤 Transcribe Diarize: 精准转写与说话人识别",
"audio_tab": "🎵 音频",
"audio_label": "📥 上传/录制音频",
"video_tab": "🎬 视频",
"video_tip": "💡 **提示**:上传视频将提取其中的音频进行转录。",
"video_label": "📥 上传视频",
"run_btn": "🚀 开始转写",
"output_label": "📝 转写结果",
}}
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Transcribe Diarize Gradio Demo (Remote API)")
parser.add_argument("--api_url", default=DEFAULT_API_URL, help="Remote inference service URL")
parser.add_argument(
"--auth_token",
default=DEFAULT_AUTH_TOKEN,
help="HTTP Authorization header (or use MOSS_API_AUTH_TOKEN env)",
)
parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or snapshot")
parser.add_argument("--timeout", type=int, default=120, help="HTTP request timeout (seconds)")
parser.add_argument("--max_new_tokens", type=int, default=1024, help="Max new tokens")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--top_k", type=int, default=20, help="Sampling top_k")
parser.add_argument("--top_p", type=float, default=1.0, help="Sampling top_p")
parser.add_argument("--target_sample_rate", type=int, default=16000, help="Resample to this rate (0 to disable)")
parser.add_argument("--keep_channels", action="store_true", help="Keep multiple channels (default: downmix to mono)")
parser.add_argument("--share", action="store_true", help="Whether to generate a public link")
parser.add_argument("--server_name", default="0.0.0.0", help="Gradio server name")
parser.add_argument("--server_port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", "7860")), help="Gradio server port")
return parser.parse_args()
def _get_duration(file_path: str) -> float:
"""Get the duration of an audio/video file in seconds."""
cmd = [
"ffprobe", "-v", "error", "-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1", file_path
]
try:
proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
return float(proc.stdout.strip())
except Exception:
return 0.0
def _ffmpeg_to_wav_bytes(file_path: str, target_sample_rate: int, keep_channels: bool, duration_limit: float = 0.0) -> bytes:
    """Transcode an audio/video file to WAV bytes via ffmpeg, optionally capped in duration."""
    if not os.path.exists(file_path):
        raise RuntimeError(f"File not found before ffmpeg: {file_path}")
    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-nostdin"]
    if duration_limit > 0:
        cmd += ["-t", str(duration_limit)]
    cmd += ["-i", file_path]
    if not keep_channels:
        cmd += ["-ac", "1"]  # downmix to mono
    if target_sample_rate and int(target_sample_rate) > 0:
        cmd += ["-ar", str(int(target_sample_rate))]
    cmd += ["-f", "wav", "pipe:1"]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        err = proc.stderr.decode(errors="ignore").strip()
        raise RuntimeError(f"ffmpeg transcoding failed: {err or 'unknown error'}")
    return proc.stdout
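# Example assembled command (mono, 16 kHz, 30 s cap; file names are illustrative):
#   ffmpeg -hide_banner -loglevel error -nostdin -t 30 -i input.mp4 -ac 1 -ar 16000 -f wav pipe:1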
def _file_to_wav_bytes(file_path: str, duration_limit: float = 0.0) -> bytes:
return _ffmpeg_to_wav_bytes(file_path, APP_ARGS.target_sample_rate, APP_ARGS.keep_channels, duration_limit)
def _file_to_data_uri(file_path: str, duration_limit: float = 0.0) -> str:
wav_bytes = _file_to_wav_bytes(file_path, duration_limit)
b64 = base64.b64encode(wav_bytes).decode("utf-8")
return f"data:audio/wav;base64,{b64}"
def _call_remote_asr(
audio_data_uri: str,
model: str,
max_new_tokens: int,
temperature: float,
top_k: int,
top_p: float,
) -> Any:
    token = (APP_ARGS.auth_token or "").strip()
    if not token:
        raise RuntimeError("Missing API key.")
    # Accept either a bare key or a pre-formatted "Bearer ..." value.
    auth_header = token if token.lower().startswith("bearer ") else f"Bearer {token}"
payload = {
"model": model,
"audio_data": audio_data_uri,
"sampling_params": {
"max_new_tokens": int(max_new_tokens),
"temperature": float(temperature),
"top_k": int(top_k),
"top_p": float(top_p),
},
"meta_info": True,
}
headers = {
"Content-Type": "application/json",
"Authorization": auth_header,
}
try:
resp = requests.post(APP_ARGS.api_url, headers=headers, json=payload, timeout=int(APP_ARGS.timeout))
except Exception as e:
raise RuntimeError(f"Request failed: {e}") from e
if resp.status_code != 200:
text = (resp.text or "").strip()
        if len(text) > 2000:
            text = text[:2000] + " ... (truncated)"
raise RuntimeError(f"HTTP {resp.status_code}: {text}")
try:
return resp.json()
except Exception:
return {"raw_text": resp.text}
def _safe_float(value: Any) -> float | None:
try:
return float(value)
except Exception:
return None
def _format_segments(segments: Any) -> str:
if not isinstance(segments, list):
return ""
lines: list[str] = []
for seg in segments:
if not isinstance(seg, dict):
continue
text = seg.get("text")
if not isinstance(text, str):
continue
text = re.sub(r"<[^>]+>", "", text).strip()
if not text:
continue
start_sec = _safe_float(seg.get("start_s"))
end_sec = _safe_float(seg.get("end_s"))
speaker = seg.get("speaker")
speaker_tag = f"[{speaker.strip()}]" if isinstance(speaker, str) and speaker.strip() else ""
if start_sec is not None and end_sec is not None:
start_fmt = _sec_to_hhmmss_cs(start_sec)
end_fmt = _sec_to_hhmmss_cs(end_sec)
line = f"[{start_fmt}-{end_fmt}] {speaker_tag} {text}".strip()
elif speaker_tag:
line = f"{speaker_tag} {text}"
else:
line = text
lines.append(line)
return "\n".join(lines).strip()
def _post_process_transcription(text: str) -> str:
"""Post process model output.
- Remove <xxx> tags (like <emotion>, <ovl>, etc.)
- If text contains timestamped segments like:
[0.00][S01]... [4.20][8.38][S02]...
format them into:
[start-end] [Sxx] 内容
where start/end are converted from seconds to HH:MM:SS.ss
- If timestamp parsing fails, strip numeric timestamps and fallback to speaker-only formatting.
"""
def _is_time_token(tok: str) -> bool:
# seconds like 0.00 / 12 / 12.3 / 12.34
return re.fullmatch(r"\d+(?:\.\d+)?", tok) is not None
def _is_speaker_token(tok: str) -> bool:
# S01 / S1 / S001 ...
return re.fullmatch(r"S\d{1,3}", tok) is not None
def _strip_numeric_timestamps(s: str) -> str:
return re.sub(r"\[(?:\d+(?:\.\d+)?)\]", "", s)
# 1) Remove <xxx> and </xxx> tags
text = re.sub(r"<[^>]+>", "", text)
# 2) Try timestamped parsing first
try:
bracket_pat = re.compile(r"\[([^\]]+)\]")
segments: list[tuple[float, float, str, str]] = []
times_buffer: list[float] = []
cur_speaker: str | None = None
cur_start: float | None = None
cur_text: list[str] = []
idx = 0
for m in bracket_pat.finditer(text):
between = text[idx:m.start()]
if cur_speaker is not None and between:
cur_text.append(between)
idx = m.end()
tok = (m.group(1) or "").strip()
if not tok:
continue
            # time token: buffer it; times are assigned at speaker boundaries below
if _is_time_token(tok):
times_buffer.append(float(tok))
continue
            # speaker token
if _is_speaker_token(tok):
speaker = f"[{tok}]"
if cur_speaker is None:
                    # first segment: use the most recent time as its start
cur_start = times_buffer[-1] if times_buffer else None
cur_speaker = speaker
cur_text = []
times_buffer = []
else:
                    # same speaker repeated with no timestamp in between: skip the duplicate tag
if not times_buffer and speaker == cur_speaker:
continue
                    # close the previous segment and set the start of the new one
if not times_buffer:
raise ValueError("no timestamp between speakers")
prev_end = times_buffer[0]
if cur_start is None:
raise ValueError("segment without start time")
txt = "".join(cur_text).strip()
segments.append((cur_start, prev_end, cur_speaker, txt))
                    # new speaker: use the last buffered time as its start (the second value when paired)
next_start = times_buffer[-1]
cur_start = next_start
cur_speaker = speaker
cur_text = []
times_buffer = []
continue
            # other bracketed content (e.g. [event]) is kept in the text
if cur_speaker is not None:
cur_text.append(f"[{tok}]")
        # append any trailing text after the last bracket
tail = text[idx:]
if cur_speaker is not None and tail:
cur_text.append(tail)
        # wrap up: the last segment ends at the final buffered timestamp
if cur_speaker is not None:
if not times_buffer or cur_start is None:
raise ValueError("last segment missing timestamp")
last_end = times_buffer[-1]
txt = "".join(cur_text).strip()
segments.append((cur_start, last_end, cur_speaker, txt))
if not segments:
raise ValueError("no valid segments parsed")
formatted_lines: list[str] = []
for s_sec, e_sec, spk, txt in segments:
s_fmt = _sec_to_hhmmss_cs(s_sec)
e_fmt = _sec_to_hhmmss_cs(e_sec)
formatted_lines.append(f"[{s_fmt}-{e_fmt}] {spk} {txt}")
return "\n".join(formatted_lines).strip()
except Exception:
# fallback: strip numeric timestamps then apply speaker-only formatting
text_wo_ts = _strip_numeric_timestamps(text)
# 3) Speaker-only formatting (merge consecutive identical speakers)
speaker_pat = re.compile(r"(\[S\d{1,3}\])")
parts = speaker_pat.split(text_wo_ts)
processed_turns: list[str] = []
current_speaker: str | None = None
current_text: list[str] = []
for part in parts:
if not part:
continue
if speaker_pat.fullmatch(part):
speaker = part
if speaker == current_speaker:
continue
if current_speaker is not None:
txt = "".join(current_text).strip()
if txt:
processed_turns.append(f"{current_speaker} {txt}")
current_speaker = speaker
current_text = []
else:
current_text.append(part)
if current_speaker is not None:
txt = "".join(current_text).strip()
if txt:
processed_turns.append(f"{current_speaker} {txt}")
elif current_text:
return "".join(current_text).strip()
return "\n".join(processed_turns).strip()
def _format_api_response(resp_obj: Any) -> str:
raw_text = ""
if isinstance(resp_obj, str):
raw_text = resp_obj
elif isinstance(resp_obj, dict):
asr_result = resp_obj.get("asr_transcription_result")
if isinstance(asr_result, dict):
segments_text = _format_segments(asr_result.get("segments"))
if segments_text:
return segments_text
full_text = asr_result.get("full_text")
if isinstance(full_text, str) and full_text.strip():
return _post_process_transcription(full_text)
for k in ("text", "result", "transcription", "output", "generated_text"):
v = resp_obj.get(k)
if isinstance(v, str) and v.strip():
raw_text = v
break
else:
raw_text = json.dumps(resp_obj, ensure_ascii=False, indent=2)
else:
raw_text = str(resp_obj)
return _post_process_transcription(raw_text)
def _normalize_path(file_obj) -> str:
if isinstance(file_obj, dict):
name = file_obj.get("name")
if isinstance(name, str):
return name
if isinstance(file_obj, str):
return file_obj
name = getattr(file_obj, "name", None)
if isinstance(name, str):
return name
raise gr.Error("Unrecognized file object.")
def preprocess_file(audio_obj, video_obj) -> str:
provided = [obj for obj in (audio_obj, video_obj) if obj]
if len(provided) == 0:
raise gr.Error("Please upload an audio or video file.")
if len(provided) > 1:
raise gr.Error("Please select either audio or video, not both.")
file_path = _normalize_path(provided[0])
suffix = Path(file_path).suffix.lower()
if suffix not in AUDIO_SUFFIXES:
raise gr.Error("Unsupported file format.")
    # Copy the upload to a stable location so Gradio's temp file cannot be cleaned up mid-request.
safe_dir = Path("/tmp/safe_audio")
safe_dir.mkdir(exist_ok=True)
safe_path = safe_dir / f"{uuid.uuid4()}{suffix}"
shutil.copy(file_path, safe_path)
return str(safe_path)
def run_transcription(
audio_obj,
video_obj,
progress=gr.Progress(track_tqdm=False),
) -> str:
progress(0.15, "Processing file...")
audio_path = preprocess_file(audio_obj, video_obj)
    duration = _get_duration(audio_path)
    # Cap the processed duration at MAX_AUDIO_DURATION (no cap if ffprobe failed).
    actual_limit = min(duration, MAX_AUDIO_DURATION) if duration > 0 else 0.0
progress(0.3, "Decoding and transcoding to WAV (ffmpeg)...")
try:
audio_data_uri = _file_to_data_uri(audio_path, duration_limit=actual_limit)
except Exception as e:
raise gr.Error(str(e))
progress(0.6, "Requesting remote inference service...")
try:
resp_obj = _call_remote_asr(
audio_data_uri,
APP_ARGS.model,
int(APP_ARGS.max_new_tokens),
float(APP_ARGS.temperature),
int(APP_ARGS.top_k),
float(APP_ARGS.top_p),
)
result = _format_api_response(resp_obj)
if not result.strip():
result = json.dumps(resp_obj, ensure_ascii=False, indent=2)
progress(1.0, "Done")
return result
except Exception as e:
raise gr.Error(str(e))
def build_demo() -> gr.Blocks:
with gr.Blocks(title="Transcribe Diarize", theme=gr.themes.Soft()) as demo:
with gr.Row():
with gr.Column(scale=1):
with gr.Tabs() as tabs:
with gr.Tab(i18n("audio_tab"), id="audio") as audio_tab:
audio_input = gr.Audio(
label=i18n("audio_label"),
sources=["upload", "microphone"],
type="filepath",
interactive=True,
)
with gr.Tab(i18n("video_tab"), id="video") as video_tab:
gr.Markdown(i18n("video_tip"))
video_input = gr.Video(
label=i18n("video_label"),
interactive=True,
)
run_button = gr.Button(i18n("run_btn"), variant="primary")
with gr.Column(scale=1):
output_box = gr.Textbox(label=i18n("output_label"), lines=100)
audio_tab.select(fn=lambda: None, outputs=video_input)
video_tab.select(fn=lambda: None, outputs=audio_input)
run_button.click(
run_transcription,
inputs=[
audio_input,
video_input,
],
outputs=output_box,
)
return demo
def main() -> None:
global APP_ARGS
APP_ARGS = parse_args()
demo = build_demo()
demo.queue().launch(
share=APP_ARGS.share,
server_name=APP_ARGS.server_name,
server_port=APP_ARGS.server_port,
max_file_size=MAX_FILE_SIZE,
i18n=i18n,
)
if __name__ == "__main__":
main()