# NOTE(review): "Spaces: Running" banner below is Hugging Face UI residue from
# the page scrape, not part of the app source — kept only as a comment.
| import os | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| import streamlit as st | |
| from gradio_client import Client | |
# Backward compat for gradio_client versions without JobStatus enum
try:  # pragma: no cover
    from gradio_client import JobStatus  # type: ignore
except Exception:  # pragma: no cover
    # Minimal shim: a bare namespace whose members stringify exactly like the
    # real enum's terminal-state names, which is all the polling loop compares.
    JobStatus = type(
        "JobStatus",
        (),
        {name: name for name in ("FINISHED", "FAILED", "CANCELLED")},
    )
# Page chrome must be configured before any other Streamlit call renders output.
st.set_page_config(
    page_title="InfiniteTalk · Remote Streamlit",
    page_icon="🎬",
    layout="wide",
)

# Default remote Space to target; override via the HF_SPACE_ID env variable.
DEFAULT_SPACE_ID = os.getenv("HF_SPACE_ID", "your-username/InfiniteTalk")
@st.cache_resource(show_spinner=False)
def get_client(space_id: str, hf_token: Optional[str]) -> Client:
    """Build and memoize a gradio_client.Client for the target Space.

    The docstring contract is "cache the client so we do not re-create the
    session for each run", but without ``st.cache_resource`` a fresh client
    (and HTTP session) was created on every Streamlit rerun; the decorator
    keys the cached client on (space_id, hf_token).

    Args:
        space_id: ``user/space`` identifier of the remote Gradio Space.
        hf_token: Optional Hugging Face token for private or gated Spaces.

    Returns:
        A connected ``gradio_client.Client``.
    """
    if not hf_token:
        return Client(space_id)
    # gradio_client renamed the token kwarg across versions; try each known
    # spelling and fall back to an anonymous client if none is accepted.
    token_kwargs = (
        {"hf_token": hf_token},
        {"token": hf_token},
        {"headers": {"Authorization": f"Bearer {hf_token}"}},
    )
    for kwargs in token_kwargs:
        try:
            return Client(space_id, **kwargs)
        except TypeError:
            continue
    return Client(space_id)
@dataclass
class InferencePayload:
    """All inputs for one remote inference request.

    Field order mirrors the positional input ordering of the Space's
    ``/predict`` endpoint (see ``_submit_job``). The class was annotated like
    a dataclass but missing the decorator, so the keyword construction in
    ``main()`` would raise TypeError; ``@dataclass`` (already imported) fixes
    that without changing the field contract.
    """

    image_path: Optional[str]   # reference image (SingleImageDriven mode)
    video_path: Optional[str]   # reference video (VideoDubbing mode)
    task_mode: str              # "VideoDubbing" | "SingleImageDriven"
    prompt: str
    negative_prompt: str
    audio_path_1: Optional[str]  # speaker-1 track (local-file modes)
    audio_path_2: Optional[str]  # speaker-2 track (multi-person modes)
    steps: int                  # diffusion steps
    seed: int                   # -1 requests a random seed remotely
    text_scale: float           # text guidance scale
    audio_scale: float          # audio guidance scale
    mode_selector: str          # audio mode label shown in the UI
    tts_text: str               # script for TTS modes (empty otherwise)
    resolution: str             # "infinitetalk-480" | "infinitetalk-720"
    voice_1: str                # TTS voice checkpoint, left channel
    voice_2: str                # TTS voice checkpoint, right channel
| def _save_upload(upload, suffix_fallback: str) -> Optional[str]: | |
| if upload is None: | |
| return None | |
| suffix = Path(upload.name).suffix or suffix_fallback | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: | |
| tmp.write(upload.read()) | |
| return tmp.name | |
def _resolve_media_paths(
    task_mode: str,
    image_upload,
    video_upload,
    audio_1_upload,
    audio_2_upload,
    use_sample: bool,
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Turn uploads (or bundled example assets) into local file paths.

    Returns (image_path, video_path, audio_1_path, audio_2_path); the entry
    that does not apply to the selected task mode stays None.
    """
    samples = {
        "image": Path("examples/single/ref_image.png"),
        "video": Path("examples/single/ref_video.mp4"),
        "audio": Path("examples/single/1.wav"),
    }

    def pick(kind: str, upload, fallback_suffix: str) -> Optional[str]:
        # Prefer the bundled sample when requested AND present on disk;
        # otherwise persist the user's upload (which may be None).
        sample = samples[kind]
        if use_sample and sample.exists():
            return str(sample)
        return _save_upload(upload, fallback_suffix)

    image_path: Optional[str] = None
    video_path: Optional[str] = None
    if task_mode == "SingleImageDriven":
        image_path = pick("image", image_upload, ".png")
    else:
        video_path = pick("video", video_upload, ".mp4")
    audio_1_path = pick("audio", audio_1_upload, ".wav")
    # The second speaker track has no bundled sample; upload only.
    audio_2_path = _save_upload(audio_2_upload, ".wav")
    return image_path, video_path, audio_1_path, audio_2_path
def _submit_job(client: Client, payload: InferencePayload):
    """Send one generation request to the remote Gradio Space.

    The positional ordering of the inputs must mirror the click() wiring in
    app.py — do not reorder.
    """
    ordered_inputs = (
        payload.image_path,
        payload.video_path,
        payload.task_mode,
        payload.prompt,
        payload.negative_prompt,
        payload.audio_path_1,
        payload.audio_path_2,
        payload.steps,
        payload.seed,
        payload.text_scale,
        payload.audio_scale,
        payload.mode_selector,
        payload.tts_text,
        payload.resolution,
        payload.voice_1,
        payload.voice_2,
    )
    return client.submit(*ordered_inputs, api_name="/predict")
def _render_hero():
    """Inject the page-wide CSS theme and render the two-column hero header.

    The CSS is injected once via st.markdown(unsafe_allow_html=True); the
    header itself is a glassmorphism intro card next to the project logo.
    """
    # Global CSS: scrollbar stabilization for the HF embed, dark gradient
    # background, and the .glass / .pill utility classes used below.
    st.markdown(
        """
        <style>
        /* Keep layout stable in fullscreen (HF embed sometimes toggles scrollbars) */
        html, body {
            width: 100%;
            overflow-x: hidden;
            overflow-y: scroll; /* always show vertical scrollbar to avoid horizontal shift */
            scrollbar-gutter: stable both-edges;
            overscroll-behavior: contain;
        }
        [data-testid="stAppViewContainer"] {
            overflow: hidden;
        }
        .main {
            background: radial-gradient(circle at 20% 20%, rgba(101, 80, 255, 0.08), transparent 35%),
                        radial-gradient(circle at 80% 0%, rgba(0, 186, 255, 0.12), transparent 40%),
                        linear-gradient(120deg, #0c0f1a, #0a0d18 50%, #0b0f1a);
            color: #e8edf7;
        }
        .block-container {
            padding-top: 2rem;
            padding-bottom: 3rem;
            max-width: 1400px;
            margin: 0 auto;
        }
        .glass {
            border-radius: 18px;
            background: rgba(255, 255, 255, 0.03);
            border: 1px solid rgba(255, 255, 255, 0.05);
            box-shadow: 0 12px 60px rgba(0, 0, 0, 0.45);
            backdrop-filter: blur(12px);
            padding: 18px 18px 8px 18px;
        }
        .pill {
            display: inline-flex;
            padding: 6px 12px;
            border-radius: 999px;
            border: 1px solid rgba(255,255,255,0.1);
            background: rgba(255,255,255,0.06);
            font-size: 12px;
            color: #b4c2ff;
            margin-right: 8px;
        }
        h1 {
            font-weight: 800;
            letter-spacing: -0.5px;
            margin-bottom: 0.2rem;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    # Hero layout: intro card (wider) on the left, logo on the right.
    col1, col2 = st.columns([1.4, 1], vertical_alignment="center")
    with col1:
        st.markdown(
            """
            <div class="glass">
                <div class="pill">Remote · GPU free (via Hugging Face Space)</div>
                <h1>InfiniteTalk Remote Control</h1>
                <p style="color:#d3dcff;font-size:16px;line-height:1.6;">
                    Upload a video or a single image, add voice tracks or TTS, and stream
                    the heavy lifting to a Hugging Face Space instead of your local GPU.
                </p>
            </div>
            """,
            unsafe_allow_html=True,
        )
    with col2:
        # NOTE(review): assumes assets/logo2.jpg ships with the repo — confirm.
        st.image("assets/logo2.jpg", use_container_width=True)
| def _read_file_bytes(path: str) -> bytes: | |
| with open(path, "rb") as f: | |
| return f.read() | |
def main():
    """Build the Streamlit UI and proxy generation jobs to a remote Space.

    Flow: render hero → collect backend/sampling config in the sidebar →
    collect task + media inputs → on "generate", validate inputs, submit the
    job to the remote Gradio Space, poll its queue, then preview/offer the
    resulting video for download.
    """
    _render_hero()

    # ---- Sidebar: remote backend + sampling configuration ----
    with st.sidebar:
        st.subheader("Remote Backend")
        space_id = st.text_input(
            "Hugging Face Space ID",
            value=DEFAULT_SPACE_ID,
            help="Any running Space that uses this repo's gradio app (e.g. username/InfiniteTalk).",
        )
        hf_token = st.text_input(
            "HF Token (optional)",
            type="password",
            help="Needed if the Space is private or gated.",
        )
        st.caption(
            "提示: 当前公开 InfiniteTalk Space 偶尔会休眠,如果请求失败,请换一个 Space ID "
            "(可以在 Hugging Face 直接 Duplicate 官方仓库后获得免费 GPU 时段)。"
        )
        st.markdown("---")
        st.subheader("Output")
        default_steps = st.slider("Diffusion steps", min_value=4, max_value=100, value=12)
        default_seed = st.number_input("Seed (-1 for random)", value=-1, step=1)
        text_scale = st.slider("Text guide scale", 0.0, 20.0, 1.5, step=0.5)
        audio_scale = st.slider("Audio guide scale", 0.0, 20.0, 2.0, step=0.5)
        resolution = st.radio(
            "Resolution budget",
            options=["infinitetalk-480", "infinitetalk-720"],
            horizontal=True,
        )
        st.markdown("---")
        st.markdown(
            "💡 推荐流程:如果你还没有在线 Space,可以先勾选“使用示例素材”检查前端,再把 Space ID 换成自己的 Hugging Face Space。"
        )

    # ---- Main panel: task selection + visual/audio inputs ----
    st.markdown("### 任务配置")
    task_mode = st.radio(
        "任务",
        options=["VideoDubbing", "SingleImageDriven"],
        horizontal=True,
        index=0,
        help="VideoDubbing: 视频+音频对口型;SingleImageDriven: 单张图+音频生成视频。",
    )
    col_input, col_audio = st.columns([1.35, 1])
    with col_input:
        st.markdown("#### 视觉输入")
        use_sample = st.checkbox("使用仓库自带示例素材", value=False)
        video_upload = None
        image_upload = None
        # Only one visual uploader is shown, depending on the task mode.
        if task_mode == "VideoDubbing":
            video_upload = st.file_uploader(
                "上传参考视频 (mp4)",
                type=["mp4", "mov", "mkv"],
                accept_multiple_files=False,
            )
        else:
            image_upload = st.file_uploader(
                "上传参考图片",
                type=["png", "jpg", "jpeg"],
                accept_multiple_files=False,
            )
        prompt = st.text_area(
            "正向提示词",
            value="A cinematic talking head shot, natural lighting, film look",
            help="描述你希望视频呈现的氛围、镜头、风格等。",
        )
        negative_prompt = st.text_area(
            "反向提示词",
            value=(
                "bright tones, overexposed, static, blurred details, subtitles, style, paintings, "
                "JPEG artifacts, ugly, distorted hands or faces, messy background"
            ),
        )
    with col_audio:
        st.markdown("#### 音频 & 声音")
        mode_selector = st.selectbox(
            "音频模式",
            options=[
                "Single Person(Local File)",
                "Single Person(TTS)",
                "Multi Person(Local File, audio add)",
                "Multi Person(Local File, audio parallel)",
                "Multi Person(TTS)",
            ],
            index=0,
        )
        audio_1_upload = None
        audio_2_upload = None
        tts_text = ""
        # Local-file modes take uploaded tracks; TTS modes take a script instead.
        if "Local File" in mode_selector:
            audio_1_upload = st.file_uploader(
                "说话人 1 音频 (wav/mp3)",
                type=["wav", "mp3", "flac", "m4a"],
                accept_multiple_files=False,
            )
            if "Multi Person" in mode_selector:
                audio_2_upload = st.file_uploader(
                    "说话人 2 音频 (wav/mp3)",
                    type=["wav", "mp3", "flac", "m4a"],
                    accept_multiple_files=False,
                )
        else:
            tts_text = st.text_area(
                "TTS 文本",
                value="Hello, welcome to InfiniteTalk remote generation demo!",
            )
        # NOTE(review): reconstructed as always-rendered (not only in the TTS
        # branch) because the payload below unconditionally references both —
        # confirm against the original layout.
        voice_1 = st.text_input(
            "Voice ID (左声道)",
            value="weights/Kokoro-82M/voices/am_adam.pt",
        )
        voice_2 = st.text_input(
            "Voice ID (右声道)",
            value="weights/Kokoro-82M/voices/af_heart.pt",
            help="双人对话时需要第二个声音;单人模式可忽略。",
        )

    st.markdown("---")
    generate = st.button("🚀 开始生成 (运行在远端 Space)", type="primary")

    if generate:
        # Validate inputs before any network traffic; each failure short-circuits.
        if not space_id:
            st.error("请先填写可用的 Hugging Face Space ID。")
            return
        image_path, video_path, audio_1_path, audio_2_path = _resolve_media_paths(
            task_mode,
            image_upload,
            video_upload,
            audio_1_upload,
            audio_2_upload,
            use_sample,
        )
        if task_mode == "VideoDubbing" and not video_path:
            st.error("请上传视频或勾选示例素材。")
            return
        if task_mode == "SingleImageDriven" and not image_path:
            st.error("请上传图片或勾选示例素材。")
            return
        if "Local File" in mode_selector and not audio_1_path:
            st.error("请提供至少一段音频,或切换到 TTS。")
            return
        if "Multi Person" in mode_selector and "Local File" in mode_selector and not audio_2_path:
            st.error("多说话人模式需要第二段音频,或者改用 TTS。")
            return

        # Bundle everything in the positional order expected by /predict.
        payload = InferencePayload(
            image_path=image_path,
            video_path=video_path,
            task_mode=task_mode,
            prompt=prompt,
            negative_prompt=negative_prompt,
            audio_path_1=audio_1_path,
            audio_path_2=audio_2_path,
            steps=int(default_steps),
            seed=int(default_seed),
            text_scale=float(text_scale),
            audio_scale=float(audio_scale),
            mode_selector=mode_selector,
            tts_text=tts_text,
            resolution=resolution,
            voice_1=voice_1,
            voice_2=voice_2,
        )

        status_area = st.status("连接远端空间...", state="running")
        try:
            client = get_client(space_id, hf_token)
            status_area.update(label="排队 & 处理请求...", state="running")
            job = _submit_job(client, payload)
            info_placeholder = st.empty()
            # Poll the remote queue every 3s until a terminal state is reached.
            while True:
                current_status = job.status()
                # Older gradio_client versions return a plain value; newer ones
                # return a StatusUpdate object — getattr() covers both shapes.
                code = getattr(current_status, "code", current_status)
                eta = getattr(current_status, "eta_seconds", None)
                code_name = code.name if hasattr(code, "name") else str(code)
                info_placeholder.info(
                    f"队列状态: {code_name} | 预计剩余 {eta or '?'}s",
                    icon="⏱️",
                )
                # str() both sides so real enum members and the string shim compare.
                if str(code) in (
                    str(getattr(JobStatus, "FINISHED", "FINISHED")),
                    str(getattr(JobStatus, "CANCELLED", "CANCELLED")),
                    str(getattr(JobStatus, "FAILED", "FAILED")),
                ):
                    break
                time.sleep(3)
            if str(code) == str(getattr(JobStatus, "FINISHED", "FINISHED")):
                result = job.result()
                # The Space may return a bare path, a (video, ...) sequence,
                # or a dict with a "video" key — normalize to one path.
                output_path = None
                if isinstance(result, (list, tuple)) and result:
                    output_path = result[0]
                elif isinstance(result, dict) and "video" in result:
                    output_path = result["video"]
                elif isinstance(result, str):
                    output_path = result
                if not output_path or not Path(output_path).exists():
                    status_area.update(
                        label="远端已完成,但未拿到视频路径,请检查 Space 配置。",
                        state="error",
                    )
                    return
                status_area.update(label="生成完成 🎉", state="complete")
                st.success("远端生成完成,下面可以直接预览或下载。")
                st.video(output_path)
                st.download_button(
                    "下载视频",
                    data=_read_file_bytes(output_path),
                    file_name=Path(output_path).name,
                    mime="video/mp4",
                )
            else:
                # CANCELLED / FAILED: surface the remote message when present.
                msg = getattr(current_status, "message", None)
                status_area.update(
                    label=f"任务失败: {msg or code_name}",
                    state="error",
                )
        except Exception as exc:  # noqa: BLE001
            # Boundary catch: connection errors, sleeping Spaces, bad Space IDs.
            status_area.update(label="请求失败", state="error")
            st.error(
                f"无法连接到 Hugging Face Space({space_id})。请确认 Space 正在运行,或更换 Space ID。\n\n详情: {exc}"
            )


if __name__ == "__main__":
    main()