Spaces:
Runtime error
Runtime error
"""
FastAPI server for DittoTalkingHead with Phase 3 optimizations.

Implements the /prepare_avatar and /generate_video endpoints, plus health-check,
cache-statistics, token-validation and benchmark handlers.
"""
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import os | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| import torch | |
| import time | |
| from typing import Optional, Dict, Any | |
| import io | |
| import asyncio | |
| from datetime import datetime | |
| import uvicorn | |
| from model_manager import ModelManager | |
| from core.optimization import ( | |
| FixedResolutionProcessor, | |
| GPUOptimizer, | |
| AvatarCache, | |
| AvatarTokenManager, | |
| ColdStartOptimizer | |
| ) | |
# FastAPI application instance
app = FastAPI(
    title="DittoTalkingHead API",
    description="High-performance talking head generation API with Phase 3 optimizations",
    version="3.0.0"
)
# CORS configuration: every origin, method and header is allowed, with credentials.
# NOTE(review): this lets any origin make credentialed cross-origin requests;
# confirm that is intended, or replace "*" with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Process-wide initialization; this module-level code runs at import time.
print("=== API Server Phase 3 - 初期化開始 ===")
# 1. Resolution optimization. FIXED_RESOLUTION is used elsewhere in this file as
#    the square avatar/test image size and as the SDK's "max_size" setting.
resolution_optimizer = FixedResolutionProcessor()
FIXED_RESOLUTION = resolution_optimizer.get_max_dim()
# 2. GPU optimization (applied to the SDK decoder in startup_event).
gpu_optimizer = GPUOptimizer()
# 3. Cold-start optimization, with its persistent cache directory under /tmp.
cold_start_optimizer = ColdStartOptimizer(persistent_dir="/tmp/persistent_model_cache")
# 4. Avatar embedding cache and the token manager built on top of it.
#    NOTE(review): ttl_days=14 presumably makes cached entries expire after
#    14 days — confirm in AvatarCache.
avatar_cache = AvatarCache(cache_dir="/tmp/avatar_cache", ttl_days=14)
token_manager = AvatarTokenManager(avatar_cache)
# Model and SDK setup. USE_PYTORCH selects the PyTorch checkpoint set (True) or
# the TensorRT one (False) when startup_event builds the SDK.
USE_PYTORCH = True
model_manager = ModelManager(cache_dir="/tmp/ditto_models", use_pytorch=USE_PYTORCH)
# Assigned by startup_event(); stays None until that coroutine has run.
SDK = None
# Startup initialization of models and the SDK
# NOTE(review): no startup registration (e.g. an `on_event("startup")`
# decorator or lifespan hook) is visible for this coroutine in the reviewed
# source; unless it is registered elsewhere, it never runs and SDK stays None.
async def startup_event():
    """Prepare model checkpoints and build the global ``SDK`` instance.

    Steps: set up the persistent model cache, let ``model_manager`` set up the
    models, construct ``StreamSDK`` from the PyTorch or TensorRT checkpoint set
    (chosen by ``USE_PYTORCH``), then, if the SDK exposes a
    ``decode_f3d.decoder`` that is a ``torch.nn.Module``, replace it with the
    result of ``gpu_optimizer.optimize_model``. Errors during the GPU
    optimization step are printed and ignored; any other error while building
    the SDK is printed and re-raised.

    Raises:
        RuntimeError: If ``model_manager.setup_models()`` returns a falsy value.
    """
    global SDK
    print("Starting model initialization...")

    cold_start_optimizer.setup_persistent_model_cache("./checkpoints")
    if not model_manager.setup_models():
        raise RuntimeError("Failed to setup models")

    # (data_root, cfg_pkl) pair for the selected backend.
    data_root, cfg_pkl = (
        ("./checkpoints/ditto_pytorch",
         "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl")
        if USE_PYTORCH
        else ("./checkpoints/ditto_trt_Ampere_Plus",
              "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl")
    )

    try:
        from stream_pipeline_offline import StreamSDK
        SDK = StreamSDK(cfg_pkl, data_root)

        has_decoder = hasattr(SDK, 'decode_f3d') and hasattr(SDK.decode_f3d, 'decoder')
        if has_decoder:
            # GPU optimization is best-effort: any failure here is only reported.
            try:
                import torch.nn as nn
                decoder = SDK.decode_f3d.decoder
                if not isinstance(decoder, nn.Module):
                    print("ℹ️ Decoder is not nn.Module, skipping optimization")
                else:
                    SDK.decode_f3d.decoder = gpu_optimizer.optimize_model(decoder)
                    print("✅ Decoder model optimized")
            except Exception as opt_err:
                print(f"⚠️ Skipping GPU optimization: {opt_err}")

        print("✅ SDK initialized with optimizations")
    except Exception as init_err:
        print(f"❌ SDK initialization error: {init_err}")
        raise
# Health-check endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; confirm it is registered on `app`.
async def health_check():
    """Report server status, GPU availability and avatar-cache statistics.

    ``status`` is always ``"healthy"`` (kept unchanged for existing clients).
    ``sdk_initialized`` reports whether the global ``SDK`` has been built, i.e.
    whether generation requests can actually be served.
    """
    return {
        "status": "healthy",
        "gpu_available": torch.cuda.is_available(),
        "cache_info": avatar_cache.get_cache_info(),
        "optimization_enabled": True,
        # Without this, a server whose model never loaded still looks healthy.
        "sdk_initialized": SDK is not None,
    }
# Avatar preparation endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; the module docstring names a /prepare_avatar endpoint, so confirm it
# is registered on `app`.
async def prepare_avatar(file: UploadFile = File(...)):
    """Upload an avatar image in advance and register an embedding for it.

    Args:
        file: Uploaded image; its declared content type must start with
            ``image/``.

    Returns:
        JSONResponse with ``avatar_token``, ``expires`` and ``cached`` (taken
        from the result of ``token_manager.prepare_avatar``) plus the fixed
        ``resolution`` string.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded as one; 500 for any other failure.
    """
    # content_type can be None when the client omits the header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        image_data = await file.read()

        from PIL import Image
        import numpy as np

        # Fully decode once so corrupt or non-image payloads are rejected as a
        # client error before any token is issued. Pillow reports undecodable
        # or truncated data with OSError subclasses (e.g. UnidentifiedImageError).
        try:
            with Image.open(io.BytesIO(image_data)) as img:
                img.convert('RGB')
        except OSError as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        def encode_appearance(img_data):
            # TODO: use the SDK's real appearance extractor. Placeholder: ignores
            # img_data and returns a random 512-dim float32 vector.
            return np.random.randn(512).astype(np.float32)

        result = token_manager.prepare_avatar(image_data, encode_appearance)

        return JSONResponse(content={
            "avatar_token": result['avatar_token'],
            "expires": result['expires'],
            "cached": result['cached'],
            "resolution": f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}"
        })
    except HTTPException:
        # Keep deliberate client errors (400) instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# Video generation endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; the module docstring names a /generate_video endpoint, so confirm it
# is registered on `app`.
async def generate_video(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    avatar_token: Optional[str] = None,
    avatar_image: Optional[UploadFile] = None
):
    """Generate a talking-head MP4 from an audio upload and an avatar.

    Args:
        background_tasks: Request-scoped task queue; the temporary audio and
            image files are deleted through it.
        file: Audio upload (WAV expected); its declared content type must start
            with ``audio/``.
        avatar_token: Token from the avatar-preparation endpoint (optional).
            An empty string is treated the same as no token.
        avatar_image: Avatar image, used when no ``avatar_token`` is given.

    Returns:
        StreamingResponse with the generated MP4 and ``X-Process-Time`` /
        ``X-Resolution`` headers.

    Raises:
        HTTPException: 400 for a non-audio upload, missing avatar input, or an
            invalid/expired token; 503 if the SDK has not been initialized;
            500 for any other failure.
    """
    # content_type can be None when the client omits the header.
    if not (file.content_type or "").startswith("audio/"):
        raise HTTPException(status_code=400, detail="File must be an audio file")
    # An empty token would otherwise fall through to avatar_image.read() on None.
    if not avatar_token and avatar_image is None:
        raise HTTPException(
            status_code=400,
            detail="Either avatar_token or avatar_image must be provided"
        )
    if SDK is None:
        raise HTTPException(status_code=503, detail="SDK is not initialized")

    # Bound up front so the error handlers can clean up whichever files exist
    # (previously they could hit an unbound local and mask the real error).
    audio_path = image_path = output_path = None

    def remove_files(*paths):
        # Best-effort delete; None entries and already-missing files are ignored.
        for path in paths:
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                pass

    try:
        start_time = time.time()

        # Record the path before writing so a failed upload read is cleaned up too.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
            audio_path = tmp_audio.name
            tmp_audio.write(await file.read())

        if avatar_token:
            embedding = avatar_cache.load_embedding(avatar_token)
            if embedding is None:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid or expired avatar_token"
                )
            print(f"✅ Using cached embedding: {avatar_token[:8]}...")
            # TODO: the cached embedding is not passed to the SDK yet; a blank
            # white placeholder image is rendered in its place.
            from PIL import Image
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
                image_path = tmp_img.name
                placeholder = Image.new('RGB', (FIXED_RESOLUTION, FIXED_RESOLUTION), 'white')
                placeholder.save(tmp_img, format='PNG')
        else:
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
                image_path = tmp_img.name
                tmp_img.write(await avatar_image.read())

        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name

        # Fixed-resolution generation settings.
        setup_kwargs = {
            "max_size": FIXED_RESOLUTION,
            "sampling_timesteps": resolution_optimizer.get_diffusion_steps()
        }

        from inference import run, seed_everything
        seed_everything(1024)

        # Run the blocking inference in the default thread pool so the event
        # loop keeps serving other requests.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            run,
            SDK,
            audio_path,
            image_path,
            output_path,
            {"setup_kwargs": setup_kwargs}
        )

        process_time = time.time() - start_time
        print(f"✅ Video generated in {process_time:.2f}s")

        # Inputs are no longer needed; delete them in the background.
        background_tasks.add_task(remove_files, audio_path, image_path)

        video_path = output_path

        def iterfile(chunk_size=64 * 1024):
            # Stream fixed-size chunks (iterating a binary file directly splits
            # on b"\n"). The finally clause also runs if the generator is closed
            # early, so the video is not leaked when streaming is abandoned.
            try:
                with open(video_path, 'rb') as f:
                    yield from iter(lambda: f.read(chunk_size), b"")
            finally:
                remove_files(video_path)

        return StreamingResponse(
            iterfile(),
            media_type="video/mp4",
            headers={
                "Content-Disposition": f"attachment; filename=talking_head_{int(time.time())}.mp4",
                "X-Process-Time": str(process_time),
                "X-Resolution": f"{FIXED_RESOLUTION}x{FIXED_RESOLUTION}"
            }
        )
    except HTTPException:
        # Preserve deliberate client errors (e.g. the 400 for a bad token).
        remove_files(audio_path, image_path, output_path)
        raise
    except Exception as e:
        remove_files(audio_path, image_path, output_path)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Cache / resource statistics endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; confirm it is registered on `app`.
async def get_cache_info():
    """Collect avatar-cache, GPU-memory and cold-start statistics into one dict."""
    stats = {}
    stats["avatar_cache"] = avatar_cache.get_cache_info()
    stats["gpu_memory"] = gpu_optimizer.get_memory_stats()
    stats["cold_start_stats"] = cold_start_optimizer.get_optimization_stats()
    return stats
# Avatar token validation endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; confirm it is registered on `app`.
async def validate_token(token: str):
    """Return the stored info for an avatar token.

    Raises:
        HTTPException: 404 when ``token_manager`` has no info for ``token``.
    """
    token_info = token_manager.get_token_info(token)
    if token_info is not None:
        return token_info
    raise HTTPException(status_code=404, detail="Token not found")
# Performance benchmark endpoint
# NOTE(review): no route decorator is visible for this handler in the reviewed
# source; confirm it is registered on `app`.
async def run_benchmark(duration_seconds: int = 16):
    """Time one end-to-end generation on synthetic input.

    Builds ``duration_seconds`` of silent 16 kHz float32 audio and a blank white
    FIXED_RESOLUTION x FIXED_RESOLUTION image, runs ``inference.run`` with the
    fixed-resolution setup_kwargs, and reports the timing.

    Args:
        duration_seconds: Length of the synthetic test audio in seconds; must
            be positive.

    Returns:
        dict with ``audio_duration_seconds``, ``process_time_seconds``,
        ``realtime_factor`` (processing time / audio duration), ``performance``
        (from ``validate_performance_improvement``) and ``optimization_config``.

    Raises:
        HTTPException: 400 if ``duration_seconds`` is not positive; 503 if the
            SDK has not been initialized; 500 for any failure during the run.
    """
    # A non-positive duration would fail in np.zeros or divide by zero in
    # realtime_factor; reject it up front as a client error.
    if duration_seconds <= 0:
        raise HTTPException(status_code=400, detail="duration_seconds must be positive")
    if SDK is None:
        raise HTTPException(status_code=503, detail="SDK is not initialized")

    # Bound before the try so the finally clause can always clean up.
    audio_path = image_path = output_path = None
    try:
        import numpy as np
        from scipy.io import wavfile
        from PIL import Image

        # Silent test audio.
        sample_rate = 16000
        audio_data = np.zeros(duration_seconds * sample_rate, dtype=np.float32)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
            audio_path = tmp_audio.name
            wavfile.write(audio_path, sample_rate, audio_data)

        # Blank test image.
        test_img = Image.new('RGB', (FIXED_RESOLUTION, FIXED_RESOLUTION), 'white')
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
            image_path = tmp_img.name
            test_img.save(image_path)

        # Output path.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_output:
            output_path = tmp_output.name

        # Benchmark run (timing includes the inference import and seeding).
        start_time = time.time()
        from inference import run, seed_everything
        seed_everything(1024)
        setup_kwargs = {
            "max_size": FIXED_RESOLUTION,
            "sampling_timesteps": resolution_optimizer.get_diffusion_steps()
        }
        # Run the blocking inference off the event loop so other requests
        # (e.g. health checks) are not stalled for the whole benchmark.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, run, SDK, audio_path, image_path, output_path,
            {"setup_kwargs": setup_kwargs}
        )
        process_time = time.time() - start_time

        perf_result = resolution_optimizer.validate_performance_improvement(
            original_time=duration_seconds * 1.9,  # estimated pre-optimization time: 1.9 s per audio second
            optimized_time=process_time
        )

        return {
            "audio_duration_seconds": duration_seconds,
            "process_time_seconds": process_time,
            "realtime_factor": process_time / duration_seconds,
            "performance": perf_result,
            "optimization_config": resolution_optimizer.get_performance_config()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort removal of the synthetic inputs and the generated video,
        # on success and on failure alike.
        for path in (audio_path, image_path, output_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                pass
if __name__ == "__main__":
    # Start the server. A single worker is used because the model runs on the GPU.
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)