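# Page header captured along with this file from its Hugging Face file view:
# author avatar "STLooo's picture" · commit message "Update app.py" · commit 4436daf (verified)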
import os
import time
import base64
import hashlib
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import gradio as gr
from faster_whisper import WhisperModel
# Tencent Cloud SDK
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
# Tencent TMT (Translate)
from tencentcloud.tmt.v20180321 import tmt_client, models as tmt_models
# Tencent TTS (Text-to-Speech)
from tencentcloud.tts.v20190823 import tts_client, models as tts_models
# ======================
# Config
# ======================
def _env_int(name: str, default: str) -> int:
    """Read environment variable *name* (or *default* when unset) and parse it as an int."""
    return int(os.getenv(name, default))


# faster-whisper settings: "small" is the CPU default; fall back to "base" if it is too slow.
MODEL_NAME = os.getenv("WHISPER_MODEL", "small")
DEVICE, COMPUTE_TYPE = "cpu", "int8"

# Region used for both Tencent clients (TMT and TTS).
TENCENT_REGION = os.getenv("TENCENT_REGION", "ap-shanghai").strip()

# Tencent TTS voice types, overridable through Space Secrets / env vars:
# - ZH default 0 (often the "云小宁" default timbre).
# - EN default 101001, the example timbre ID used in Tencent's docs; if synthesis
#   fails with it, configure your own voice type in Secrets.
VOICE_EN = _env_int("TENCENT_TTS_VOICE_EN", "101001")
VOICE_ZH = _env_int("TENCENT_TTS_VOICE_ZH", "0")

# Only the most recently published line gets TTS audio (keeps the MVP's load low).
# NOTE(review): this constant is not read anywhere else in this file as shown.
TTS_GENERATE_MODE = "latest_only"
# ======================
# Helpers
# ======================
def _now_ms() -> int:
return int(time.time() * 1000)
def _session_id() -> str:
return str(_now_ms())
def _hash(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()[:12]
def _require_env(name: str) -> str:
v = os.getenv(name, "").strip()
if not v:
raise RuntimeError(f"Missing env: {name}. Set it in HF Space Settings → Secrets.")
return v
# ======================
# Tencent Clients
# ======================
# Module-level caches for the Tencent SDK clients. They start as None and are
# filled on first use by get_tmt_client() / get_tts_client().
_TMT_CLIENT: Optional[tmt_client.TmtClient] = None
_TTS_CLIENT: Optional[tts_client.TtsClient] = None
def _make_client(endpoint: str):
    """Build the credential and client profile for a Tencent Cloud API *endpoint*.

    Despite the name, no SDK client is created here. Callers pass the returned pair
    to a service-specific client constructor together with TENCENT_REGION.

    Returns:
        (credential, client_profile) tuple.

    Raises:
        RuntimeError: if TENCENT_SECRET_ID or TENCENT_SECRET_KEY is missing/blank.
    """
    cred = credential.Credential(
        _require_env("TENCENT_SECRET_ID"),
        _require_env("TENCENT_SECRET_KEY"),
    )

    http_profile = HttpProfile()
    http_profile.endpoint = endpoint

    client_profile = ClientProfile()
    client_profile.httpProfile = http_profile

    return cred, client_profile
def get_tmt_client() -> tmt_client.TmtClient:
    """Return the shared Tencent TMT (machine translation) client, building it on first call.

    The instance is cached in the module-level _TMT_CLIENT. Building it reads the
    TENCENT_SECRET_ID/TENCENT_SECRET_KEY secrets and raises RuntimeError if either
    is missing.
    """
    global _TMT_CLIENT
    if _TMT_CLIENT is None:
        cred, profile = _make_client("tmt.tencentcloudapi.com")
        _TMT_CLIENT = tmt_client.TmtClient(cred, TENCENT_REGION, profile)
    return _TMT_CLIENT
def get_tts_client() -> tts_client.TtsClient:
    """Return the shared Tencent TTS (text-to-speech) client, building it on first call.

    The instance is cached in the module-level _TTS_CLIENT. Building it reads the
    TENCENT_SECRET_ID/TENCENT_SECRET_KEY secrets and raises RuntimeError if either
    is missing.
    """
    global _TTS_CLIENT
    if _TTS_CLIENT is None:
        cred, profile = _make_client("tts.tencentcloudapi.com")
        _TTS_CLIENT = tts_client.TtsClient(cred, TENCENT_REGION, profile)
    return _TTS_CLIENT
# ======================
# Whisper Model
# ======================
# Load the faster-whisper model once, at import time (a module-level side effect).
# transcribe_to_chunks() reuses this single instance for every request.
whisper = WhisperModel(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)
# ======================
# In-memory State (MVP)
# ======================
@dataclass
class Chunk:
    """One transcribed segment together with its editing / publishing state."""

    chunk_id: int  # assigned consecutively from 0; publish_one uses it as the list index
    start_s: float  # segment start time, from the Whisper segment's `.start`
    end_s: float  # segment end time, from the Whisper segment's `.end`
    raw_text_en: str  # text exactly as transcribed (stripped)
    edited_text_en: str  # editor-corrected text; initialized to raw_text_en
    status: str = "raw"  # "raw" until published, then "published"
    rev: int = 0  # bumped when the chunk is published
    zh_text: str = ""  # Chinese translation, filled in on publish
    tts_en_path: str = ""  # cached EN mp3 file path ("" = not generated / invalidated)
    tts_zh_path: str = ""  # cached ZH mp3 file path ("" = not generated / invalidated)
# session_id -> that session's chunks, in transcription order.
# Plain in-memory dict: nothing is persisted, so a restart loses all sessions.
STATE: Dict[str, List[Chunk]] = {}
# Caches shared by all sessions (MVP): in-memory and unbounded (no eviction).
TRANS_CACHE: Dict[str, str] = {}  # "tmt:en->zh:<hash>" -> translated Chinese text
TTS_CACHE: Dict[str, str] = {}  # "tts:<voice_type>:<hash>" -> mp3 file path
# ======================
# Translation (EN -> ZH) with caching
# ======================
def translate_en_to_zh(text_en: str) -> str:
    """Translate English text to Chinese via Tencent TMT, memoized in TRANS_CACHE.

    Args:
        text_en: English source text; None or whitespace-only input yields "".

    Returns:
        The translated text, or "" if the API response carries no TargetText.
        Results (including "") are cached by a hash of the stripped input.

    Raises:
        Whatever the Tencent SDK or get_tmt_client() raises (e.g. RuntimeError for
        missing secrets); callers are expected to catch it.
    """
    source = (text_en or "").strip()
    if not source:
        return ""

    cache_key = f"tmt:en->zh:{_hash(source)}"
    if cache_key in TRANS_CACHE:
        return TRANS_CACHE[cache_key]

    api = get_tmt_client()
    request = tmt_models.TextTranslateRequest()
    request.SourceText = source
    request.Source, request.Target = "en", "zh"
    request.ProjectId = 0
    response = api.TextTranslate(request)

    translated = getattr(response, "TargetText", "") or ""
    TRANS_CACHE[cache_key] = translated
    return translated
# ======================
# TTS (Text -> mp3) with caching
# ======================
def tts_to_mp3(text: str, voice_type: int) -> str:
    """Synthesize *text* with Tencent TTS and return the path of the saved mp3.

    Results are cached in TTS_CACHE by (voice_type, hash of the stripped text). A
    cached path is reused only while the file still exists on disk; otherwise the
    audio is synthesized again.

    Args:
        text: text to speak; None or whitespace-only input yields "".
        voice_type: Tencent TTS VoiceType id.

    Returns:
        Path of the mp3 under ./outputs, or "" when the text is empty or the API
        returned no audio.

    Raises:
        Whatever the Tencent SDK / get_tts_client() raises, or binascii.Error if the
        returned audio is not valid base64; callers are expected to catch it.
    """
    text = (text or "").strip()
    if not text:
        return ""
    key = f"tts:{voice_type}:{_hash(text)}"

    cached = TTS_CACHE.get(key)
    # Guard against the outputs/ file having been removed since it was cached.
    if cached and os.path.exists(cached):
        return cached

    client = get_tts_client()
    req = tts_models.TextToVoiceRequest()
    req.Text = text
    req.SessionId = key
    req.ModelType = 1
    req.VoiceType = voice_type
    req.Volume = 5
    req.Speed = 0
    req.SampleRate = 16000
    req.Codec = "mp3"
    resp = client.TextToVoice(req)

    audio_b64 = getattr(resp, "Audio", "") or ""
    if not audio_b64:
        return ""
    audio_bytes = base64.b64decode(audio_b64)

    out_dir = "outputs"
    os.makedirs(out_dir, exist_ok=True)
    # The cache key contains ":" which is not allowed in Windows file names;
    # keep the key as-is for the cache / SessionId, but sanitize the file name.
    file_stem = key.replace(":", "_")
    path = os.path.join(out_dir, f"{file_stem}.mp3")
    with open(path, "wb") as f:
        f.write(audio_bytes)

    TTS_CACHE[key] = path
    return path
# ======================
# Core pipeline
# ======================
def transcribe_to_chunks(audio_path: str, session_id: str) -> str:
    """Transcribe *audio_path* with Whisper and store the segments as the session's chunks.

    Transcription runs with vad_filter=True. Segments whose text is empty after
    stripping are skipped; the remaining ones get consecutive chunk ids from 0, so a
    chunk's id equals its index in the stored list. Any chunks already stored for
    *session_id* are replaced.

    Returns:
        A status line with the chunk count, the detected language ("auto" when
        Whisper reports none) and the model / compute type in use.
    """
    segments, info = whisper.transcribe(audio_path, vad_filter=True)
    lang = getattr(info, "language", None) or "auto"

    built: List[Chunk] = []
    for segment in segments:
        text = (segment.text or "").strip()
        if text:
            built.append(
                Chunk(
                    chunk_id=len(built),  # next consecutive id == its index in `built`
                    start_s=float(segment.start),
                    end_s=float(segment.end),
                    raw_text_en=text,
                    edited_text_en=text,
                    status="raw",
                    rev=0,
                    zh_text="",
                    tts_en_path="",
                    tts_zh_path="",
                )
            )

    STATE[session_id] = built
    return f"OK: {len(built)} chunks · detected_lang={lang} · model={MODEL_NAME}/{COMPUTE_TYPE}"
def editor_table(session_id: str):
    """Return the Editor tab's Dataframe rows for *session_id*.

    One row per chunk, in order:
    [chunk_id, "start-end" (2 decimals), status, raw_en, edited_en, zh, rev].
    An unknown session yields an empty list.
    """
    return [
        [
            chunk.chunk_id,
            f"{chunk.start_s:.2f}-{chunk.end_s:.2f}",
            chunk.status,
            chunk.raw_text_en,
            chunk.edited_text_en,
            chunk.zh_text,
            chunk.rev,
        ]
        for chunk in STATE.get(session_id, [])
    ]
def publish_one(session_id: str, chunk_id: int, edited_text_en: str):
    """Publish a single chunk, optionally replacing its English text, then translate it.

    Args:
        session_id: key into STATE.
        chunk_id: index of the chunk in the session's list (transcribe_to_chunks
            assigns ids consecutively from 0).
        edited_text_en: replacement English text; if empty or whitespace-only, the
            chunk's current edited text is kept.

    Returns:
        (status_message, editor_table rows). The chunk is marked published and its
        rev bumped even if translation fails; the failure is reported in the message.
    """
    chunks = STATE.get(session_id, [])
    if not 0 <= chunk_id < len(chunks):
        return "Chunk ID out of range", editor_table(session_id)

    chunk = chunks[chunk_id]
    replacement = (edited_text_en or "").strip()
    if replacement:
        chunk.edited_text_en = replacement
    chunk.status = "published"
    chunk.rev += 1

    # Translation happens only at publish time: fewer API calls, and the
    # proofread text is what gets translated.
    prefix = f"Published #{chunk_id} rev={chunk.rev}"
    try:
        chunk.zh_text = translate_en_to_zh(chunk.edited_text_en)
    except Exception as exc:
        chunk.zh_text = ""
        msg = f"{prefix} · translation failed: {exc}"
    else:
        msg = f"{prefix} · translated"

    # Invalidate any cached audio unconditionally: it may belong to the old text.
    chunk.tts_en_path = ""
    chunk.tts_zh_path = ""
    return msg, editor_table(session_id)
def publish_all(session_id: str):
    """Publish every chunk of *session_id* and translate those still lacking Chinese text.

    Chunks that were not yet published are marked published and get their rev
    bumped; already-published chunks keep their rev. Translation is attempted for
    every chunk with edited English text but no zh_text. A failed translation leaves
    zh_text empty and is counted, not raised. All cached TTS paths are cleared.

    Returns:
        (status_message, editor_table rows). The message carries the ok/fail counts
        and, if anything failed, the first failing chunk id and error text.
    """
    chunks = STATE.get(session_id, [])
    ok, fail = 0, 0
    first_error = ""
    for c in chunks:
        if c.status != "published":
            c.status = "published"
            c.rev += 1
        if not c.zh_text and c.edited_text_en:
            try:
                c.zh_text = translate_en_to_zh(c.edited_text_en)
                ok += 1
            except Exception as e:  # was a bare `except:`, which also swallowed Ctrl-C / SystemExit
                fail += 1
                if not first_error:
                    first_error = f"#{c.chunk_id}: {e}"
        c.tts_en_path = ""
        c.tts_zh_path = ""

    msg = f"Published ALL · translated_ok={ok} fail={fail}"
    if first_error:
        msg += f" · first_error {first_error}"
    return msg, editor_table(session_id)
# ======================
# Audience rendering + TTS generation (stable MVP)
# ======================
def render_audience_html(chunks: List[Chunk], view_lang: str) -> str:
    """Render the most recent published chunks as HTML caption cards.

    Args:
        chunks: a session's chunks in order; only status == "published" ones are shown.
        view_lang: "zh" shows the Chinese translation. Any other value shows the
            English text (edited text, falling back to the raw transcript).

    Returns:
        HTML for at most the last 50 published chunks, or a placeholder message
        when nothing has been published.
    """
    from html import escape  # stdlib HTML escaping

    # show last 50 published
    published = [c for c in chunks if c.status == "published"][-50:]
    if not published:
        return "<i>No published captions yet.</i>"

    def one(c: Chunk) -> str:
        en = (c.edited_text_en or c.raw_text_en).strip()
        zh = (c.zh_text or "").strip()
        text = zh if view_lang == "zh" else en
        # Caption text comes from the transcript, the editor's free-form input and
        # the translation API. Escape it so a stray "<" / "&" (or injected markup)
        # cannot break or script the audience page.
        return (
            "<div style='padding:10px 12px;border:1px solid #ddd;border-radius:10px;margin:10px 0;'>"
            f"<div style='font-size:12px;color:#666'>#{c.chunk_id} · {c.start_s:.2f}-{c.end_s:.2f}</div>"
            f"<div style='font-size:16px;line-height:1.45'>{escape(text)}</div>"
            "</div>"
        )

    return "".join(one(c) for c in published)
def ensure_latest_tts(session_id: str, view_lang: str) -> Tuple[str, Optional[str]]:
    """Make sure the latest published chunk has an mp3 in the selected language.

    Only the newest published chunk is synthesized, which keeps API load low. The
    audio is served as a server-generated mp3 instead of relying on the browser's
    speechSynthesis.

    Args:
        session_id: key into STATE.
        view_lang: "en" speaks the edited English text. Any other value speaks the
            Chinese translation; for "zh" a missing translation is retried first.

    Returns:
        (status_message, mp3_path). mp3_path is None when nothing is published,
        translation or TTS fails, or TTS produced no audio.
    """
    chunks = STATE.get(session_id, [])
    published = [c for c in chunks if c.status == "published"]
    if not published:
        return "No published captions yet.", None
    latest = published[-1]

    # The ZH voice needs Chinese text; retry translation if publishing left it empty.
    if view_lang == "zh" and not latest.zh_text:
        try:
            latest.zh_text = translate_en_to_zh(latest.edited_text_en)
        except Exception as e:
            return f"ZH translation failed: {e}", None

    try:
        if view_lang == "en":
            lang_tag = "EN"
            if not latest.tts_en_path:
                latest.tts_en_path = tts_to_mp3(latest.edited_text_en, VOICE_EN)
            path = latest.tts_en_path
        else:
            lang_tag = "ZH"
            if not latest.tts_zh_path:
                latest.tts_zh_path = tts_to_mp3(latest.zh_text, VOICE_ZH)
            path = latest.tts_zh_path
    except Exception as e:
        return f"TTS failed: {e}", None

    # tts_to_mp3 returns "" for empty text or when the API sends back no audio;
    # say so instead of claiming the audio is ready.
    if not path:
        return f"TTS produced no audio ({lang_tag}) for chunk #{latest.chunk_id}", None
    return f"TTS ready ({lang_tag}) for chunk #{latest.chunk_id}", path
def refresh_audience(session_id: str, view_lang: str):
    """Rebuild the Audience tab: caption HTML, TTS status and latest-line audio.

    ensure_latest_tts runs first because it may fill in a missing Chinese
    translation for the latest chunk. Rendering afterwards lets that line appear
    right away instead of only on the next refresh.

    Returns:
        (html, tts_status_message, mp3_path_or_None), matching the
        [aud_html, tts_status, aud_audio] outputs wired up in the UI.
    """
    tts_msg, audio_path = ensure_latest_tts(session_id, view_lang)
    html = render_audience_html(STATE.get(session_id, []), view_lang)
    return html, tts_msg, audio_path
# ======================
# Gradio UI
# ======================
with gr.Blocks(title="Live Caption MVP (HF)") as demo:
    # Flow: English transcription -> proofread (EN) -> auto-translate (ZH) -> publish
    # -> audience EN/ZH captions plus server-side TTS mp3 playback.
    gr.Markdown(
        "# Live Caption MVP (HF)\n"
        "全英文轉寫 → 校對(EN)→ 自動翻譯(ZH)→ 發佈 → 觀眾端 EN/ZH 字幕 + 後端 TTS 生成 mp3 播放(不依賴手機瀏覽器 TTS)"
    )
    # _session_id() is evaluated once, while the Blocks are being built, so every
    # visitor's State starts from the same id and they all read/write one STATE entry.
    # NOTE(review): this looks deliberate, since it is what lets the Audience tab
    # (e.g. on a phone) see chunks published from the Editor tab on another device.
    # It also means anyone who can open the Space can edit and publish. Confirm
    # before switching to per-visitor ids.
    sid = gr.State(_session_id())

    with gr.Tab("1) Ingest"):
        # Upload a recording, then transcribe it into Whisper segments.
        gr.Markdown("上傳 iPhone 錄音檔(m4a/wav/mp3)→ 轉寫切段(Whisper segments)")
        audio = gr.Audio(type="filepath", label="Upload audio")
        btn_run = gr.Button("Transcribe & Build Chunks")
        ingest_status = gr.Textbox(label="Status", interactive=False)

    with gr.Tab("2) Editor"):
        # Proofread the English text and publish; Chinese translation happens on publish.
        gr.Markdown("校對台:修改英文後 Publish,系統自動翻譯成中文(只對 Publish 後內容翻譯,省錢且更準)。")
        table = gr.Dataframe(
            headers=["chunk_id", "time", "status", "raw_en", "edited_en", "zh", "rev"],
            datatype=["number", "str", "str", "str", "str", "str", "number"],
            interactive=False
        )
        chunk_id_in = gr.Number(label="chunk_id", value=0, precision=0)
        edited_in = gr.Textbox(label="edited_en (paste here)", lines=3)
        btn_pub_one = gr.Button("Publish One (translate)")
        btn_pub_all = gr.Button("Publish All (translate missing)")
        editor_status = gr.Textbox(label="Editor Status", interactive=False)

    with gr.Tab("3) Audience"):
        # Show published captions; Refresh also generates audio for the latest line.
        gr.Markdown(
            "觀眾端:顯示已發佈字幕。按 Refresh 會同時產生「最新一句」的音檔(EN 或 ZH 取決於選擇),用播放器播放。"
        )
        view_lang = gr.Radio(choices=["en", "zh"], value="zh", label="View language")
        btn_refresh = gr.Button("Refresh Audience View")
        aud_html = gr.HTML(label="Captions")
        tts_status = gr.Textbox(label="TTS Status", interactive=False)
        aud_audio = gr.Audio(label="Play latest line", type="filepath")

    # ---- Actions ----
    def _do_ingest(audio_path, session_id):
        if not audio_path:
            # Keep showing the existing (shared) chunks instead of blanking the table.
            return "Please upload an audio file first.", editor_table(session_id)
        msg = transcribe_to_chunks(audio_path, session_id)
        return msg, editor_table(session_id)

    btn_run.click(_do_ingest, inputs=[audio, sid], outputs=[ingest_status, table])

    def _pub_one(session_id, cid, txt):
        # The Number field may deliver None when it is cleared; int(None) would raise.
        if cid is None:
            return "Please enter a chunk_id.", editor_table(session_id)
        return publish_one(session_id, int(cid), txt)

    btn_pub_one.click(_pub_one, inputs=[sid, chunk_id_in, edited_in], outputs=[editor_status, table])
    btn_pub_all.click(lambda session_id: publish_all(session_id), inputs=[sid], outputs=[editor_status, table])
    btn_refresh.click(refresh_audience, inputs=[sid, view_lang], outputs=[aud_html, tts_status, aud_audio])
# Start the server only when run as a script (presumably how the Space runs
# `python app.py`), so importing this module (e.g. from tests or Gradio's reload
# mode) does not launch a second server.
if __name__ == "__main__":
    demo.launch()