InfiniteTalk / streamlit_app.py
marcelo1126's picture
Update streamlit_app.py
1961fb5 verified
import os
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import streamlit as st
from gradio_client import Client
# Backward compat: gradio_client releases differ in which job-status enum they
# expose. Try a top-level ``JobStatus`` first, then the client's own ``Status``
# enum from ``gradio_client.utils`` (presumably the type of
# ``job.status().code`` — TODO confirm for the pinned version), so that
# comparisons against real status codes can match. Only when neither is
# importable fall back to a plain-string shim.
try:  # pragma: no cover
    from gradio_client import JobStatus  # type: ignore
except ImportError:  # pragma: no cover
    try:
        from gradio_client.utils import Status as JobStatus  # type: ignore
    except ImportError:
        class JobStatus:  # minimal shim
            """Plain-string stand-ins for the terminal job states."""

            FINISHED = "FINISHED"
            FAILED = "FAILED"
            CANCELLED = "CANCELLED"
# Page-level settings, applied at import time before any other st.* output
# in this file.
st.set_page_config(
    layout="wide",
    page_icon="🎬",
    page_title="InfiniteTalk · Remote Streamlit",
)

# Default Space ID; can be overridden with the HF_SPACE_ID environment variable.
DEFAULT_SPACE_ID = os.getenv(
    "HF_SPACE_ID",
    "your-username/InfiniteTalk",
)
@st.cache_resource(show_spinner=False)
def get_client(space_id: str, hf_token: Optional[str]) -> Client:
    """Return a gradio ``Client`` for *space_id*, cached across Streamlit reruns.

    When a token is given, the known authentication keywords are tried in turn
    (``hf_token``, then ``token``, then a Bearer ``Authorization`` header),
    because different gradio_client releases accept different ones.

    Args:
        space_id: Hugging Face Space identifier, e.g. ``"user/InfiniteTalk"``.
        hf_token: Optional access token; falsy values mean anonymous access.

    Returns:
        A connected ``Client``. If no auth keyword is accepted, the client is
        created without a token.
    """
    if hf_token:
        auth_variants = (
            {"hf_token": hf_token},
            {"token": hf_token},
            {"headers": {"Authorization": f"Bearer {hf_token}"}},
        )
        for auth_kwargs in auth_variants:
            try:
                return Client(space_id, **auth_kwargs)
            except TypeError:
                # This client build rejects the keyword; try the next spelling.
                # NOTE(review): a TypeError raised inside Client for any other
                # reason is swallowed here as well.
                pass
    # No token supplied, or no auth keyword was accepted: connect anonymously.
    return Client(space_id)
@dataclass
class InferencePayload:
    """All inputs for one remote generation request.

    Field order matters: ``_submit_job`` forwards these values positionally,
    so keep the declaration order aligned with the remote endpoint's inputs.
    """

    # Visual reference: exactly one of these is expected to be set,
    # depending on ``task_mode``.
    image_path: Optional[str]
    video_path: Optional[str]
    task_mode: str
    # Text conditioning.
    prompt: str
    negative_prompt: str
    # Local audio tracks (speaker 1 / speaker 2); either may be None.
    audio_path_1: Optional[str]
    audio_path_2: Optional[str]
    # Sampling controls.
    steps: int
    seed: int
    text_scale: float
    audio_scale: float
    # Audio source selection and TTS settings.
    mode_selector: str
    tts_text: str
    resolution: str
    voice_1: str
    voice_2: str
def _save_upload(upload, suffix_fallback: str) -> Optional[str]:
if upload is None:
return None
suffix = Path(upload.name).suffix or suffix_fallback
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(upload.read())
return tmp.name
def _resolve_media_paths(
    task_mode: str,
    image_upload,
    video_upload,
    audio_1_upload,
    audio_2_upload,
    use_sample: bool,
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Turn uploads (or bundled example files) into local paths for the API.

    For ``"SingleImageDriven"`` only the image slot is filled; for any other
    task only the video slot is. Speaker 1 audio may come from the bundled
    sample; speaker 2 audio always comes from its upload.

    Returns:
        ``(image_path, video_path, audio_1_path, audio_2_path)``; each entry
        is ``None`` when nothing was provided for it.
    """
    def sample_or_upload(sample: Path, upload, suffix: str) -> Optional[str]:
        # Use the bundled sample only when requested AND present on disk
        # (relative to the current working directory); otherwise save the upload.
        if use_sample and sample.exists():
            return str(sample)
        return _save_upload(upload, suffix)

    examples_dir = Path("examples/single")
    image_path: Optional[str] = None
    video_path: Optional[str] = None

    if task_mode == "SingleImageDriven":
        image_path = sample_or_upload(examples_dir / "ref_image.png", image_upload, ".png")
    else:
        video_path = sample_or_upload(examples_dir / "ref_video.mp4", video_upload, ".mp4")

    audio_1_path = sample_or_upload(examples_dir / "1.wav", audio_1_upload, ".wav")
    # No bundled sample exists for the second speaker.
    audio_2_path = _save_upload(audio_2_upload, ".wav")
    return image_path, video_path, audio_1_path, audio_2_path
def _submit_job(client: Client, payload: InferencePayload):
    """Queue a generation request on the remote Gradio Space.

    The arguments are passed positionally to the ``/predict`` endpoint, in the
    order the original docstring says mirrors the ``click()`` wiring in the
    Space's ``app.py``. Returns whatever ``client.submit`` returns (the job
    handle that the caller polls).
    """
    # Positional order must match the remote endpoint's input list.
    ordered_inputs = (
        payload.image_path,
        payload.video_path,
        payload.task_mode,
        payload.prompt,
        payload.negative_prompt,
        payload.audio_path_1,
        payload.audio_path_2,
        payload.steps,
        payload.seed,
        payload.text_scale,
        payload.audio_scale,
        payload.mode_selector,
        payload.tts_text,
        payload.resolution,
        payload.voice_1,
        payload.voice_2,
    )
    return client.submit(*ordered_inputs, api_name="/predict")
def _render_hero():
    """Render the page header: global CSS theme, intro card and logo.

    The logo at ``assets/logo2.jpg`` (relative to the working directory) is
    optional. ``st.image`` raises when given a local path it cannot open, so a
    missing asset is skipped rather than aborting the whole page render.
    """
    st.markdown(
        """
<style>
/* Keep layout stable in fullscreen (HF embed sometimes toggles scrollbars) */
html, body {
width: 100%;
overflow-x: hidden;
overflow-y: scroll; /* always show vertical scrollbar to avoid horizontal shift */
scrollbar-gutter: stable both-edges;
overscroll-behavior: contain;
}
[data-testid="stAppViewContainer"] {
overflow: hidden;
}
.main {
background: radial-gradient(circle at 20% 20%, rgba(101, 80, 255, 0.08), transparent 35%),
radial-gradient(circle at 80% 0%, rgba(0, 186, 255, 0.12), transparent 40%),
linear-gradient(120deg, #0c0f1a, #0a0d18 50%, #0b0f1a);
color: #e8edf7;
}
.block-container {
padding-top: 2rem;
padding-bottom: 3rem;
max-width: 1400px;
margin: 0 auto;
}
.glass {
border-radius: 18px;
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(255, 255, 255, 0.05);
box-shadow: 0 12px 60px rgba(0, 0, 0, 0.45);
backdrop-filter: blur(12px);
padding: 18px 18px 8px 18px;
}
.pill {
display: inline-flex;
padding: 6px 12px;
border-radius: 999px;
border: 1px solid rgba(255,255,255,0.1);
background: rgba(255,255,255,0.06);
font-size: 12px;
color: #b4c2ff;
margin-right: 8px;
}
h1 {
font-weight: 800;
letter-spacing: -0.5px;
margin-bottom: 0.2rem;
}
</style>
""",
        unsafe_allow_html=True,
    )
    col1, col2 = st.columns([1.4, 1], vertical_alignment="center")
    with col1:
        st.markdown(
            """
<div class="glass">
<div class="pill">Remote · GPU free (via Hugging Face Space)</div>
<h1>InfiniteTalk Remote Control</h1>
<p style="color:#d3dcff;font-size:16px;line-height:1.6;">
Upload a video or a single image, add voice tracks or TTS, and stream
the heavy lifting to a Hugging Face Space instead of your local GPU.
</p>
</div>
""",
            unsafe_allow_html=True,
        )
    with col2:
        logo_path = Path("assets/logo2.jpg")
        # Leave the column empty instead of crashing when the asset is absent.
        if logo_path.is_file():
            st.image(str(logo_path), use_container_width=True)
def _read_file_bytes(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
def main():
_render_hero()
with st.sidebar:
st.subheader("Remote Backend")
space_id = st.text_input(
"Hugging Face Space ID",
value=DEFAULT_SPACE_ID,
help="Any running Space that uses this repo's gradio app (e.g. username/InfiniteTalk).",
)
hf_token = st.text_input(
"HF Token (optional)",
type="password",
help="Needed if the Space is private or gated.",
)
st.caption(
"提示: 当前公开 InfiniteTalk Space 偶尔会休眠,如果请求失败,请换一个 Space ID "
"(可以在 Hugging Face 直接 Duplicate 官方仓库后获得免费 GPU 时段)。"
)
st.markdown("---")
st.subheader("Output")
default_steps = st.slider("Diffusion steps", min_value=4, max_value=100, value=12)
default_seed = st.number_input("Seed (-1 for random)", value=-1, step=1)
text_scale = st.slider("Text guide scale", 0.0, 20.0, 1.5, step=0.5)
audio_scale = st.slider("Audio guide scale", 0.0, 20.0, 2.0, step=0.5)
resolution = st.radio(
"Resolution budget",
options=["infinitetalk-480", "infinitetalk-720"],
horizontal=True,
)
st.markdown("---")
st.markdown(
"💡 推荐流程:如果你还没有在线 Space,可以先勾选“使用示例素材”检查前端,再把 Space ID 换成自己的 Hugging Face Space。"
)
st.markdown("### 任务配置")
task_mode = st.radio(
"任务",
options=["VideoDubbing", "SingleImageDriven"],
horizontal=True,
index=0,
help="VideoDubbing: 视频+音频对口型;SingleImageDriven: 单张图+音频生成视频。",
)
col_input, col_audio = st.columns([1.35, 1])
with col_input:
st.markdown("#### 视觉输入")
use_sample = st.checkbox("使用仓库自带示例素材", value=False)
video_upload = None
image_upload = None
if task_mode == "VideoDubbing":
video_upload = st.file_uploader(
"上传参考视频 (mp4)",
type=["mp4", "mov", "mkv"],
accept_multiple_files=False,
)
else:
image_upload = st.file_uploader(
"上传参考图片",
type=["png", "jpg", "jpeg"],
accept_multiple_files=False,
)
prompt = st.text_area(
"正向提示词",
value="A cinematic talking head shot, natural lighting, film look",
help="描述你希望视频呈现的氛围、镜头、风格等。",
)
negative_prompt = st.text_area(
"反向提示词",
value=(
"bright tones, overexposed, static, blurred details, subtitles, style, paintings, "
"JPEG artifacts, ugly, distorted hands or faces, messy background"
),
)
with col_audio:
st.markdown("#### 音频 & 声音")
mode_selector = st.selectbox(
"音频模式",
options=[
"Single Person(Local File)",
"Single Person(TTS)",
"Multi Person(Local File, audio add)",
"Multi Person(Local File, audio parallel)",
"Multi Person(TTS)",
],
index=0,
)
audio_1_upload = None
audio_2_upload = None
tts_text = ""
if "Local File" in mode_selector:
audio_1_upload = st.file_uploader(
"说话人 1 音频 (wav/mp3)",
type=["wav", "mp3", "flac", "m4a"],
accept_multiple_files=False,
)
if "Multi Person" in mode_selector:
audio_2_upload = st.file_uploader(
"说话人 2 音频 (wav/mp3)",
type=["wav", "mp3", "flac", "m4a"],
accept_multiple_files=False,
)
else:
tts_text = st.text_area(
"TTS 文本",
value="Hello, welcome to InfiniteTalk remote generation demo!",
)
voice_1 = st.text_input(
"Voice ID (左声道)",
value="weights/Kokoro-82M/voices/am_adam.pt",
)
voice_2 = st.text_input(
"Voice ID (右声道)",
value="weights/Kokoro-82M/voices/af_heart.pt",
help="双人对话时需要第二个声音;单人模式可忽略。",
)
st.markdown("---")
generate = st.button("🚀 开始生成 (运行在远端 Space)", type="primary")
if generate:
if not space_id:
st.error("请先填写可用的 Hugging Face Space ID。")
return
image_path, video_path, audio_1_path, audio_2_path = _resolve_media_paths(
task_mode,
image_upload,
video_upload,
audio_1_upload,
audio_2_upload,
use_sample,
)
if task_mode == "VideoDubbing" and not video_path:
st.error("请上传视频或勾选示例素材。")
return
if task_mode == "SingleImageDriven" and not image_path:
st.error("请上传图片或勾选示例素材。")
return
if "Local File" in mode_selector and not audio_1_path:
st.error("请提供至少一段音频,或切换到 TTS。")
return
if "Multi Person" in mode_selector and "Local File" in mode_selector and not audio_2_path:
st.error("多说话人模式需要第二段音频,或者改用 TTS。")
return
payload = InferencePayload(
image_path=image_path,
video_path=video_path,
task_mode=task_mode,
prompt=prompt,
negative_prompt=negative_prompt,
audio_path_1=audio_1_path,
audio_path_2=audio_2_path,
steps=int(default_steps),
seed=int(default_seed),
text_scale=float(text_scale),
audio_scale=float(audio_scale),
mode_selector=mode_selector,
tts_text=tts_text,
resolution=resolution,
voice_1=voice_1,
voice_2=voice_2,
)
status_area = st.status("连接远端空间...", state="running")
try:
client = get_client(space_id, hf_token)
status_area.update(label="排队 & 处理请求...", state="running")
job = _submit_job(client, payload)
info_placeholder = st.empty()
while True:
current_status = job.status()
code = getattr(current_status, "code", current_status)
eta = getattr(current_status, "eta_seconds", None)
code_name = code.name if hasattr(code, "name") else str(code)
info_placeholder.info(
f"队列状态: {code_name} | 预计剩余 {eta or '?'}s",
icon="⏱️",
)
if str(code) in (
str(getattr(JobStatus, "FINISHED", "FINISHED")),
str(getattr(JobStatus, "CANCELLED", "CANCELLED")),
str(getattr(JobStatus, "FAILED", "FAILED")),
):
break
time.sleep(3)
if str(code) == str(getattr(JobStatus, "FINISHED", "FINISHED")):
result = job.result()
output_path = None
if isinstance(result, (list, tuple)) and result:
output_path = result[0]
elif isinstance(result, dict) and "video" in result:
output_path = result["video"]
elif isinstance(result, str):
output_path = result
if not output_path or not Path(output_path).exists():
status_area.update(
label="远端已完成,但未拿到视频路径,请检查 Space 配置。",
state="error",
)
return
status_area.update(label="生成完成 🎉", state="complete")
st.success("远端生成完成,下面可以直接预览或下载。")
st.video(output_path)
st.download_button(
"下载视频",
data=_read_file_bytes(output_path),
file_name=Path(output_path).name,
mime="video/mp4",
)
else:
msg = getattr(current_status, "message", None)
status_area.update(
label=f"任务失败: {msg or code_name}",
state="error",
)
except Exception as exc: # noqa: BLE001
status_area.update(label="请求失败", state="error")
st.error(
f"无法连接到 Hugging Face Space({space_id})。请确认 Space 正在运行,或更换 Space ID。\n\n详情: {exc}"
)
# Entry point: build and render the app when this file runs as the main script.
if __name__ == "__main__":
    main()