Spaces:
Runtime error
Runtime error
| """ | |
| Optimized DittoTalkingHead App with Phase 3 Performance Improvements | |
| """ | |
| import gradio as gr | |
| import os | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| import torch | |
| import time | |
| from typing import Optional, Dict, Any | |
| import io | |
| from model_manager import ModelManager | |
| from core.optimization import ( | |
| FixedResolutionProcessor, | |
| GPUOptimizer, | |
| AvatarCache, | |
| AvatarTokenManager, | |
| ColdStartOptimizer, | |
| InferenceCache, | |
| CachedInference, | |
| ParallelProcessor, | |
| ParallelInference, | |
| OptimizedInferenceWrapper | |
| ) | |
| from cleanup_old_files import initialize_cleanup, get_cleanup_status | |
| # サンプルファイルのディレクトリを定義 | |
| EXAMPLES_DIR = (Path(__file__).parent / "example").resolve() | |
| OUTPUT_DIR = (Path(__file__).parent / "output").resolve() | |
| # 出力ディレクトリの作成 | |
| OUTPUT_DIR.mkdir(exist_ok=True) | |
| # ファイルクリーンアップの初期化(24時間後に自動削除) | |
| initialize_cleanup(OUTPUT_DIR, max_age_hours=24) | |
| # 初期化フラグ | |
| print("=== Phase 3 最適化版 - 初期化開始 ===") | |
| # 1. 解像度最適化の初期化 | |
| resolution_optimizer = FixedResolutionProcessor() | |
| FIXED_RESOLUTION = resolution_optimizer.get_max_dim() # 320 | |
| print(f"✅ 解像度固定: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}") | |
| # 2. GPU最適化の初期化 | |
| gpu_optimizer = GPUOptimizer() | |
| print(gpu_optimizer.get_optimization_summary()) | |
| # 3. Cold Start最適化の初期化 | |
| cold_start_optimizer = ColdStartOptimizer() | |
| # 4. アバターキャッシュの初期化 | |
| avatar_cache = AvatarCache(cache_dir="/tmp/avatar_cache", ttl_days=14) | |
| token_manager = AvatarTokenManager(avatar_cache) | |
| print(f"✅ アバターキャッシュ初期化: {avatar_cache.get_cache_info()}") | |
| # 5. 推論キャッシュの初期化 | |
| inference_cache = InferenceCache( | |
| cache_dir="/tmp/inference_cache", | |
| memory_cache_size=50, | |
| file_cache_size_gb=5.0, | |
| ttl_hours=24 | |
| ) | |
| cached_inference = CachedInference(inference_cache) | |
| print(f"✅ 推論キャッシュ初期化: {inference_cache.get_cache_stats()}") | |
| # 6. 並列処理の初期化(SDK初期化後に移動) | |
| # モデルの初期化(最適化版) | |
| USE_PYTORCH = True | |
| model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH) | |
| # Cold start最適化: 永続ストレージのセットアップ | |
| if not cold_start_optimizer.setup_persistent_model_cache("./checkpoints"): | |
| print("⚠️ 永続ストレージのセットアップに失敗") | |
| if not model_manager.setup_models(): | |
| raise RuntimeError("モデルのセットアップに失敗しました。") | |
| # SDKの初期化 | |
| if USE_PYTORCH: | |
| data_root = "./checkpoints/ditto_pytorch" | |
| cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
| else: | |
| data_root = "./checkpoints/ditto_trt_Ampere_Plus" | |
| cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" | |
| # SDK初期化 | |
| SDK = None | |
| try: | |
| from stream_pipeline_offline import StreamSDK | |
| from inference import run, seed_everything | |
| # SDKを最適化設定で初期化 | |
| SDK = StreamSDK(cfg_pkl, data_root) | |
| print("✅ SDK初期化成功(最適化版)") | |
| # GPU最適化を適用(torch.nn.Moduleの場合のみ) | |
| if hasattr(SDK, 'decode_f3d') and hasattr(SDK.decode_f3d, 'decoder'): | |
| try: | |
| import torch.nn as nn | |
| if isinstance(SDK.decode_f3d.decoder, nn.Module): | |
| SDK.decode_f3d.decoder = gpu_optimizer.optimize_model(SDK.decode_f3d.decoder) | |
| print("✅ デコーダーモデルに最適化を適用") | |
| else: | |
| print("ℹ️ デコーダーはnn.Moduleではないため、最適化をスキップ") | |
| except Exception as e: | |
| print(f"⚠️ GPU最適化の適用をスキップ: {e}") | |
| except Exception as e: | |
| print(f"❌ SDK初期化エラー: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| raise | |
| # 並列処理の初期化(SDK初期化成功後) | |
| parallel_processor = ParallelProcessor(num_threads=4, num_processes=2) | |
| parallel_inference = ParallelInference(SDK, parallel_processor) | |
| optimized_wrapper = OptimizedInferenceWrapper( | |
| SDK, | |
| use_parallel=True, | |
| use_cache=True, | |
| use_gpu_opt=True | |
| ) | |
| print(f"✅ 並列処理初期化: {parallel_inference.get_performance_stats()}") | |
| def prepare_avatar(image_file) -> Dict[str, Any]: | |
| """ | |
| 画像を事前処理してアバタートークンを生成 | |
| Args: | |
| image_file: アップロードされた画像ファイル | |
| Returns: | |
| アバタートークン情報 | |
| """ | |
| if image_file is None: | |
| return {"error": "画像ファイルをアップロードしてください。"} | |
| try: | |
| # 画像データを読み込む | |
| with open(image_file, 'rb') as f: | |
| image_data = f.read() | |
| # 外観エンコーダーで埋め込みを生成 | |
| def encode_appearance(img_data): | |
| # ここでは簡略化のため、SDKの外観抽出を使用 | |
| # 実際の実装では appearance_extractor を直接呼び出す | |
| import numpy as np | |
| from PIL import Image | |
| # 画像を読み込んで処理 | |
| img = Image.open(io.BytesIO(img_data)) | |
| img = img.convert('RGB') | |
| img = img.resize((FIXED_RESOLUTION, FIXED_RESOLUTION)) | |
| # 仮の埋め込みベクトル(実際はモデルで生成) | |
| # TODO: 実際の appearance_extractor を使用 | |
| embedding = np.random.randn(512).astype(np.float32) | |
| return embedding | |
| # トークンを生成 | |
| result = token_manager.prepare_avatar( | |
| image_data, | |
| encode_appearance | |
| ) | |
| return { | |
| "status": "✅ アバター準備完了", | |
| "avatar_token": result['avatar_token'], | |
| "expires": result['expires'], | |
| "cached": "キャッシュ済み" if result['cached'] else "新規生成" | |
| } | |
| except Exception as e: | |
| import traceback | |
| return { | |
| "error": f"❌ エラー: {str(e)}\n{traceback.format_exc()}" | |
| } | |
| def process_talking_head_optimized( | |
| audio_file, | |
| source_image, | |
| avatar_token: Optional[str] = None, | |
| use_resolution_optimization: bool = True, | |
| use_inference_cache: bool = True, | |
| use_parallel_processing: bool = True | |
| ): | |
| """ | |
| 最適化されたTalking Head生成処理(キャッシュ対応) | |
| Args: | |
| audio_file: 音声ファイル | |
| source_image: ソース画像(avatar_tokenがない場合に使用) | |
| avatar_token: 事前生成されたアバタートークン | |
| use_resolution_optimization: 解像度最適化を使用するか | |
| use_inference_cache: 推論キャッシュを使用するか | |
| """ | |
| if audio_file is None: | |
| return None, "音声ファイルをアップロードしてください。" | |
| if avatar_token is None and source_image is None: | |
| return None, "ソース画像またはアバタートークンが必要です。" | |
| try: | |
| start_time = time.time() | |
| # 出力ファイルの作成(出力ディレクトリ内) | |
| import uuid | |
| output_filename = f"{uuid.uuid4()}.mp4" | |
| output_path = str(OUTPUT_DIR / output_filename) | |
| # アバタートークンから埋め込みを取得 | |
| if avatar_token: | |
| embedding = avatar_cache.load_embedding(avatar_token) | |
| if embedding is None: | |
| return None, "❌ 無効または期限切れのアバタートークンです。" | |
| print(f"✅ キャッシュから埋め込みを取得: {avatar_token[:8]}...") | |
| # 解像度最適化設定を適用 | |
| if use_resolution_optimization: | |
| setup_kwargs = { | |
| "max_size": FIXED_RESOLUTION, # 320固定 | |
| "sampling_timesteps": resolution_optimizer.get_diffusion_steps() # 25 | |
| } | |
| print(f"✅ 解像度最適化適用: {FIXED_RESOLUTION}×{FIXED_RESOLUTION}, ステップ数: {setup_kwargs['sampling_timesteps']}") | |
| else: | |
| setup_kwargs = {} | |
| # 処理方法の選択 | |
| if use_parallel_processing and source_image: | |
| # 並列処理を使用 | |
| print("🔄 並列処理モードで実行...") | |
| if use_inference_cache: | |
| # キャッシュ + 並列処理 | |
| def inference_func(audio_path, image_path, out_path, **kwargs): | |
| # 並列処理ラッパーを使用 | |
| optimized_wrapper.process( | |
| audio_path, image_path, out_path, | |
| seed=1024, | |
| more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})} | |
| ) | |
| # キャッシュシステムを通じて処理 | |
| result_path, cache_hit, process_time = cached_inference.process_with_cache( | |
| inference_func, | |
| audio_file, | |
| source_image, | |
| output_path, | |
| resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default", | |
| steps=setup_kwargs.get('sampling_timesteps', 50), | |
| setup_kwargs=setup_kwargs | |
| ) | |
| cache_status = "キャッシュヒット(並列)" if cache_hit else "新規生成(並列)" | |
| else: | |
| # 並列処理のみ | |
| _, process_time, stats = optimized_wrapper.process( | |
| audio_file, source_image, output_path, | |
| seed=1024, | |
| more_kwargs={"setup_kwargs": setup_kwargs} | |
| ) | |
| cache_hit = False | |
| cache_status = "並列処理(キャッシュ未使用)" | |
| elif use_inference_cache and source_image: | |
| # キャッシュのみ(並列処理なし) | |
| def inference_func(audio_path, image_path, out_path, **kwargs): | |
| seed_everything(1024) | |
| run(SDK, audio_path, image_path, out_path, | |
| more_kwargs={"setup_kwargs": kwargs.get('setup_kwargs', {})}) | |
| # キャッシュシステムを通じて処理 | |
| result_path, cache_hit, process_time = cached_inference.process_with_cache( | |
| inference_func, | |
| audio_file, | |
| source_image, | |
| output_path, | |
| resolution=f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}" if use_resolution_optimization else "default", | |
| steps=setup_kwargs.get('sampling_timesteps', 50), | |
| setup_kwargs=setup_kwargs | |
| ) | |
| cache_status = "キャッシュヒット" if cache_hit else "新規生成" | |
| else: | |
| # 通常処理(並列処理もキャッシュもなし) | |
| print(f"処理開始: audio={audio_file}, image={source_image}, token={avatar_token is not None}") | |
| seed_everything(1024) | |
| run(SDK, audio_file, source_image, output_path, more_kwargs={"setup_kwargs": setup_kwargs}) | |
| process_time = time.time() - start_time | |
| cache_hit = False | |
| cache_status = "通常処理" | |
| # 結果の確認 | |
| if os.path.exists(output_path) and os.path.getsize(output_path) > 0: | |
| # パフォーマンス統計 | |
| perf_info = f""" | |
| ✅ 処理完了! | |
| 処理時間: {process_time:.2f}秒 | |
| 解像度: {FIXED_RESOLUTION}×{FIXED_RESOLUTION} | |
| 最適化設定: | |
| - 解像度最適化: {'有効' if use_resolution_optimization else '無効'} | |
| - 並列処理: {'有効' if use_parallel_processing else '無効'} | |
| - アバターキャッシュ: {'使用' if avatar_token else '未使用'} | |
| - 推論キャッシュ: {cache_status} | |
| キャッシュ統計: {inference_cache.get_cache_stats()['memory_cache_entries']}件(メモリ), {inference_cache.get_cache_stats()['file_cache_entries']}件(ファイル) | |
| """ | |
| return output_path, perf_info | |
| else: | |
| return None, "❌ 処理に失敗しました。出力ファイルが生成されませんでした。" | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ エラーが発生しました: {str(e)}\n{traceback.format_exc()}" | |
| print(error_msg) | |
| return None, error_msg | |
| # Gradio UI(最適化版) | |
| with gr.Blocks(title="DittoTalkingHead - Phase 3 最適化版") as demo: | |
| gr.Markdown(""" | |
| # DittoTalkingHead - Phase 3 高速化実装 | |
| **🚀 最適化機能:** | |
| - 📐 解像度320×320固定による高速化 | |
| - 🎯 画像事前アップロード&キャッシュ機能 | |
| - ⚡ GPU最適化(Mixed Precision, torch.compile) | |
| - 💾 Cold Start最適化 | |
| - 🔄 推論キャッシュ(同じ入力で即座に結果を返す) | |
| - 🚀 並列処理(音声・画像の前処理を並列化) | |
| ## 使い方 | |
| ### 方法1: 通常の使用 | |
| 1. 音声ファイル(WAV)と画像をアップロード | |
| 2. 「生成」ボタンをクリック | |
| ### 方法2: 高速化(推奨) | |
| 1. 「アバター準備」タブで画像を事前アップロード | |
| 2. 生成されたトークンをコピー | |
| 3. 「動画生成」タブで音声とトークンを使用 | |
| """) | |
| with gr.Tabs(): | |
| # タブ1: 通常の動画生成 | |
| with gr.TabItem("🎬 動画生成"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| audio_input = gr.Audio( | |
| label="音声ファイル (WAV)", | |
| type="filepath" | |
| ) | |
| with gr.Row(): | |
| image_input = gr.Image( | |
| label="ソース画像(オプション)", | |
| type="filepath" | |
| ) | |
| token_input = gr.Textbox( | |
| label="アバタートークン(オプション)", | |
| placeholder="事前準備したトークンを入力", | |
| lines=1 | |
| ) | |
| use_optimization = gr.Checkbox( | |
| label="解像度最適化を使用(320×320)", | |
| value=True | |
| ) | |
| use_cache = gr.Checkbox( | |
| label="推論キャッシュを使用(同じ入力で高速化)", | |
| value=True | |
| ) | |
| use_parallel = gr.Checkbox( | |
| label="並列処理を使用(前処理を高速化)", | |
| value=True | |
| ) | |
| generate_btn = gr.Button("🎬 生成", variant="primary") | |
| with gr.Column(): | |
| video_output = gr.Video( | |
| label="生成されたビデオ" | |
| ) | |
| status_output = gr.Textbox( | |
| label="ステータス", | |
| lines=6 | |
| ) | |
| # タブ2: アバター準備 | |
| with gr.TabItem("👤 アバター準備"): | |
| gr.Markdown(""" | |
| ### 画像を事前にアップロードして高速化 | |
| 画像の埋め込みベクトルを事前計算し、トークンとして保存します。 | |
| このトークンを使用することで、動画生成時の処理時間を短縮できます。 | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| avatar_image_input = gr.Image( | |
| label="アバター画像", | |
| type="filepath" | |
| ) | |
| prepare_btn = gr.Button("📤 アバター準備", variant="primary") | |
| with gr.Column(): | |
| prepare_output = gr.JSON( | |
| label="準備結果" | |
| ) | |
| # タブ3: 最適化情報 | |
| with gr.TabItem("📊 最適化情報"): | |
| with gr.Row(): | |
| refresh_btn = gr.Button("🔄 情報を更新", scale=1) | |
| info_display = gr.Markdown(f""" | |
| ### 現在の最適化設定 | |
| {resolution_optimizer.get_optimization_summary()} | |
| {gpu_optimizer.get_optimization_summary()} | |
| ### アバターキャッシュ情報 | |
| {avatar_cache.get_cache_info()} | |
| ### 推論キャッシュ情報 | |
| {inference_cache.get_cache_stats()} | |
| """) | |
| # キャッシュ管理ボタン | |
| with gr.Row(): | |
| clear_inference_cache_btn = gr.Button("🗑️ 推論キャッシュをクリア", variant="secondary") | |
| clear_avatar_cache_btn = gr.Button("🗑️ アバターキャッシュをクリア", variant="secondary") | |
| cleanup_status_btn = gr.Button("📊 クリーンアップ状態", variant="secondary") | |
| cache_status = gr.Textbox(label="キャッシュ操作ステータス", lines=2) | |
| # サンプル | |
| example_audio = EXAMPLES_DIR / "audio.wav" | |
| example_image = EXAMPLES_DIR / "image.png" | |
| if example_audio.exists() and example_image.exists(): | |
| gr.Examples( | |
| examples=[ | |
| [str(example_audio), str(example_image), None, True, True, True] | |
| ], | |
| inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel], | |
| outputs=[video_output, status_output], | |
| fn=process_talking_head_optimized | |
| ) | |
| # イベントハンドラ | |
| generate_btn.click( | |
| fn=process_talking_head_optimized, | |
| inputs=[audio_input, image_input, token_input, use_optimization, use_cache, use_parallel], | |
| outputs=[video_output, status_output] | |
| ) | |
| prepare_btn.click( | |
| fn=prepare_avatar, | |
| inputs=[avatar_image_input], | |
| outputs=[prepare_output] | |
| ) | |
| # キャッシュ管理関数 | |
| def refresh_info(): | |
| return f""" | |
| ### 現在の最適化設定 | |
| {resolution_optimizer.get_optimization_summary()} | |
| {gpu_optimizer.get_optimization_summary()} | |
| ### アバターキャッシュ情報 | |
| {avatar_cache.get_cache_info()} | |
| ### 推論キャッシュ情報 | |
| {inference_cache.get_cache_stats()} | |
| ### 並列処理情報 | |
| {parallel_inference.get_performance_stats()} | |
| """ | |
| def clear_inference_cache(): | |
| inference_cache.clear_cache() | |
| return "✅ 推論キャッシュをクリアしました" | |
| def clear_avatar_cache(): | |
| avatar_cache.clear_cache() | |
| return "✅ アバターキャッシュをクリアしました" | |
| # キャッシュ管理イベント | |
| refresh_btn.click( | |
| fn=refresh_info, | |
| outputs=[info_display] | |
| ) | |
| clear_inference_cache_btn.click( | |
| fn=clear_inference_cache, | |
| outputs=[cache_status] | |
| ) | |
| clear_avatar_cache_btn.click( | |
| fn=clear_avatar_cache, | |
| outputs=[cache_status] | |
| ) | |
| cleanup_status_btn.click( | |
| fn=lambda: get_cleanup_status(), | |
| outputs=[cache_status] | |
| ) | |
| if __name__ == "__main__": | |
| # Cold Start最適化設定でGradioを起動 | |
| launch_settings = cold_start_optimizer.optimize_gradio_settings() | |
| # allowed_pathsを追加 | |
| launch_settings['allowed_paths'] = [str(EXAMPLES_DIR), str(OUTPUT_DIR)] | |
| demo.launch(**launch_settings) |