Spaces:
Runtime error
Runtime error
| import os, tempfile, queue, threading, time, numpy as np, soundfile as sf | |
| import gradio as gr | |
| from stream_pipeline_offline import StreamSDK | |
| import torch | |
| from PIL import Image | |
| from pathlib import Path | |
| import cv2 | |
| # モデル設定 | |
| CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
| DATA_ROOT = "checkpoints/ditto_pytorch" | |
| # サンプルファイルのディレクトリ | |
| EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
| OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
| # 出力ディレクトリの作成 | |
| OUTPUT_DIR.mkdir(exist_ok=True) | |
| # グローバルで一度だけロード(concurrency_count=1 前提) | |
| sdk: StreamSDK | None = None | |
| def init_sdk(): | |
| global sdk | |
| if sdk is None: | |
| sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
| return sdk | |
| # 音声チャンクサイズ(秒) | |
| CHUNK_SEC = 0.20 # 16000*0.20 = 3200 sample ≒ 5 フレーム | |
| def generator(mic, src_img): | |
| """ | |
| Gradio 生成関数 | |
| mic : (sr, np.ndarray) 形式 (Gradio Audio streaming=True) | |
| src_img : 画像ファイルパス | |
| Yields : PIL.Image (現在フレーム) または (最後に mp4) | |
| """ | |
| if mic is None: | |
| yield None, None, "マイク入力を開始してください" | |
| return | |
| if src_img is None: | |
| yield None, None, "ソース画像をアップロードしてください" | |
| return | |
| try: | |
| sr, wav_full = mic | |
| sdk = init_sdk() | |
| # setup: online_mode=True でストリーミング | |
| import uuid | |
| output_filename = f"{uuid.uuid4()}.mp4" | |
| tmp_out = str(OUTPUT_DIR / output_filename) | |
| sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024) | |
| N_total = int(np.ceil(len(wav_full) / sr * 20)) # 概算フレーム数 | |
| sdk.setup_Nd(N_total) | |
| # 処理開始時刻 | |
| start_time = time.time() | |
| frame_count = 0 | |
| # 音声を CHUNK_SEC ごとに送り込む | |
| hop = int(sr * CHUNK_SEC) | |
| for start_idx in range(0, len(wav_full), hop): | |
| chunk = wav_full[start_idx : start_idx + hop] | |
| if len(chunk) < hop: | |
| chunk = np.pad(chunk, (0, hop - len(chunk))) | |
| sdk.run_chunk(chunk) | |
| # 直近で書き込まれたフレームをキューから取得 | |
| frames_processed = 0 | |
| while sdk.writer_queue.qsize() > 0 and frames_processed < 5: | |
| try: | |
| frame = sdk.writer_queue.get_nowait() | |
| if frame is not None: | |
| # numpy array (H, W, 3) を PIL Image に変換 | |
| pil_frame = Image.fromarray(frame) | |
| frame_count += 1 | |
| elapsed = time.time() - start_time | |
| fps = frame_count / elapsed if elapsed > 0 else 0 | |
| yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" | |
| frames_processed += 1 | |
| except queue.Empty: | |
| break | |
| # 少し待機(CPU負荷調整) | |
| time.sleep(0.01) | |
| # 残りのフレームを処理 | |
| print("音声チャンクの送信完了、残りフレームを処理中...") | |
| timeout_count = 0 | |
| while timeout_count < 50: # 最大5秒待機 | |
| if sdk.writer_queue.qsize() > 0: | |
| try: | |
| frame = sdk.writer_queue.get_nowait() | |
| if frame is not None: | |
| pil_frame = Image.fromarray(frame) | |
| frame_count += 1 | |
| elapsed = time.time() - start_time | |
| fps = frame_count / elapsed if elapsed > 0 else 0 | |
| yield pil_frame, None, f"処理中... フレーム: {frame_count}, FPS: {fps:.1f}" | |
| timeout_count = 0 | |
| except queue.Empty: | |
| time.sleep(0.1) | |
| timeout_count += 1 | |
| else: | |
| time.sleep(0.1) | |
| timeout_count += 1 | |
| # SDKを閉じて最終的なMP4を生成 | |
| print("SDKを閉じて最終的なMP4を生成中...") | |
| sdk.close() # ワーカー join & mp4 結合 | |
| # 処理完了 | |
| elapsed_total = time.time() - start_time | |
| yield None, gr.Video(tmp_out), f"✅ 完了! 総フレーム数: {frame_count}, 処理時間: {elapsed_total:.1f}秒" | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ エラー: {str(e)}\n{traceback.format_exc()}" | |
| print(error_msg) | |
| yield None, None, error_msg | |
| # Gradio UI | |
| with gr.Blocks(title="DittoTalkingHead Streaming") as demo: | |
| gr.Markdown(""" | |
| # DittoTalkingHead - ストリーミング版 | |
| 音声をリアルタイムで処理し、生成されたフレームを逐次表示します。 | |
| ## 使い方 | |
| 1. **ソース画像**(PNG/JPG形式)をアップロード | |
| 2. **Start**ボタンをクリックしてマイク録音開始 | |
| 3. 録音中、ライブフレームが更新されます | |
| 4. 録音停止後、最終的なMP4が生成されます | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| img_in = gr.Image( | |
| type="filepath", | |
| label="ソース画像 / Source Image", | |
| value=str(EXAMPLES_DIR / "reference.png") if (EXAMPLES_DIR / "reference.png").exists() else None | |
| ) | |
| mic_in = gr.Audio( | |
| sources=["microphone"], | |
| streaming=True, | |
| label="マイク入力 (16 kHz)", | |
| format="wav" | |
| ) | |
| with gr.Column(): | |
| live_img = gr.Image(label="ライブフレーム", type="pil") | |
| final_mp4 = gr.Video(label="最終結果 (MP4)") | |
| status_text = gr.Textbox(label="ステータス", value="待機中...") | |
| btn = gr.Button("Start Streaming", variant="primary") | |
| # ストリーミング処理を開始 | |
| btn.click( | |
| fn=generator, | |
| inputs=[mic_in, img_in], | |
| outputs=[live_img, final_mp4, status_text], | |
| stream_every=0.1 # 100msごとに更新 | |
| ) | |
| # サンプル | |
| if EXAMPLES_DIR.exists(): | |
| gr.Examples( | |
| examples=[ | |
| [str(EXAMPLES_DIR / "reference.png")] | |
| ], | |
| inputs=[img_in], | |
| label="サンプル画像" | |
| ) | |
| # 起動設定 | |
| if __name__ == "__main__": | |
| # GPU最適化設定 | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.backends.cudnn.benchmark = True | |
| # 環境変数設定 | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | |
| print("=== DittoTalkingHead ストリーミング版 起動 ===") | |
| print(f"- チャンクサイズ: {CHUNK_SEC}秒") | |
| print(f"- 最大解像度: 1024px") | |
| print(f"- GPU: {'利用可能' if torch.cuda.is_available() else '利用不可'}") | |
| # モデルの事前ロード | |
| print("モデルを事前ロード中...") | |
| init_sdk() | |
| print("✅ モデルロード完了") | |
| demo.queue(concurrency_count=1, max_size=8).launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| allowed_paths=[str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
| ) |