Spaces:
Runtime error
Runtime error
"""
DittoTalkingHead Streaming API Server
Real-time streaming implementation using WebSocket/SSE
"""
import asyncio
import base64
import json
import logging
import shutil
import tempfile
import time
import traceback
from pathlib import Path
from typing import AsyncGenerator, Optional

import cv2
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse

from stream_pipeline_offline import StreamSDK
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="DittoTalkingHead Streaming API")

# CORS configuration: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# SDK configuration: the two arguments passed to StreamSDK(...) by init_sdk().
CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl"
DATA_ROOT = "checkpoints/ditto_pytorch"
# Global application state
class AppState:
    """Mutable state shared by the request handlers of this process."""

    def __init__(self) -> None:
        # StreamSDK instance, or None while it has not been created yet.
        self.sdk: Optional[StreamSDK] = None
        # Number of streaming sessions currently counted as active.
        self.active_connections: int = 0
        # Handlers refuse new sessions once active_connections reaches this.
        self.max_connections: int = 5


state = AppState()
def init_sdk():
    """Return the shared StreamSDK, constructing it on the first call.

    The instance is built from ``CFG_PKL`` and ``DATA_ROOT`` and cached on
    ``state.sdk``; later calls return the cached object.

    Returns:
        The StreamSDK instance stored on ``state.sdk``.
    """
    if state.sdk is None:
        logger.info("Initializing StreamSDK...")
        state.sdk = StreamSDK(CFG_PKL, DATA_ROOT)
        logger.info("StreamSDK initialized successfully")
    return state.sdk
# Registered as a startup hook so the SDK is actually created when the
# application starts; without the registration this coroutine is never called.
@app.on_event("startup")
async def startup_event():
    """Initialize the SDK when the application starts."""
    init_sdk()
# NOTE(review): the "/" path is inferred from the handler name — confirm.
@app.get("/")
async def root():
    """Health check.

    Returns:
        A dict with a fixed status/service name plus the current and maximum
        connection counts read from ``state``.
    """
    return {
        "status": "ok",
        "service": "DittoTalkingHead Streaming API",
        "active_connections": state.active_connections,
        "max_connections": state.max_connections,
    }
# Frames per second used to estimate the frame count passed to setup_Nd().
# 25 fps equals the "5 frames per 0.2 s chunk" estimate used by sse_generate.
# NOTE(review): confirm against the SDK's actual output frame rate.
_WS_VIDEO_FPS = 25


def _is_positive_number(value):
    """Return True when ``value`` is an int/float (not a bool) greater than 0."""
    return (
        isinstance(value, (int, float))
        and not isinstance(value, bool)
        and value > 0
    )


def _split_into_chunks(samples, chunk_size):
    """Yield ``chunk_size``-long slices of ``samples``, zero-padding the last."""
    for start in range(0, len(samples), chunk_size):
        piece = samples[start:start + chunk_size]
        if len(piece) < chunk_size:
            piece = np.pad(piece, (0, chunk_size - len(piece)))
        yield piece


def _encode_frame_jpeg_b64(frame, quality=80):
    """Convert ``frame`` RGB->BGR, JPEG-encode it and return base64 text."""
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, jpeg = cv2.imencode(
        ".jpg", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), encode_param
    )
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(jpeg).decode("utf-8")


async def _cancel_and_wait(task):
    """Cancel ``task`` if it is still running and wait for it to finish."""
    if task is None or task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


@app.websocket("/ws/generate")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint - real-time streaming.

    Message flow implemented by this handler:

    1. The client sends a JSON config with ``source_image`` (base64,
       required), ``sample_rate`` (default 16000) and ``chunk_duration``
       (default 0.2).
    2. The server answers with a ``ready`` message carrying ``chunk_size``.
    3. Each binary message is read as float32 samples, split into
       ``chunk_size`` pieces (the last one zero-padded) and passed to
       ``run_chunk``; a ``progress`` message follows. Meanwhile a background
       task sends ``frame`` messages for frames found in ``writer_queue``.
    4. A text message ``{"action": "stop"}`` ends the session: the SDK is
       closed and the MP4 is sent as ``final_video`` (or ``error`` if the
       output file does not exist).

    Any failure is reported to the client as an ``error`` message when the
    socket is still usable. The scratch directory, the frame task and the
    connection counter are always released.
    """
    # Connection limit check
    if state.active_connections >= state.max_connections:
        await websocket.close(code=1008, reason="Server busy")
        return

    # Take the slot before the first await so that concurrent handshakes
    # cannot all pass the limit check above.
    state.active_connections += 1
    frame_task = None
    workdir = None
    try:
        await websocket.accept()
        logger.info(f"New WebSocket connection. Active: {state.active_connections}")

        # Receive the initial configuration
        config = await websocket.receive_json()
        if not isinstance(config, dict):
            await websocket.send_json({"type": "error", "message": "config must be a JSON object"})
            return
        source_image_b64 = config.get("source_image")
        sample_rate = config.get("sample_rate", 16000)
        chunk_duration = config.get("chunk_duration", 0.2)
        if not source_image_b64:
            await websocket.send_json({"type": "error", "message": "source_image is required"})
            return
        # Reject values that would give an empty chunk or divide by zero below.
        if (
            not _is_positive_number(sample_rate)
            or not _is_positive_number(chunk_duration)
            or int(sample_rate * chunk_duration) < 1
        ):
            await websocket.send_json({
                "type": "error",
                "message": "sample_rate and chunk_duration must be positive numbers",
            })
            return
        chunk_size = int(sample_rate * chunk_duration)

        # Decode the image into a private scratch directory that also holds
        # the output video; the whole directory is removed in ``finally``.
        workdir = tempfile.mkdtemp(prefix="ditto_ws_")
        source_path = str(Path(workdir) / "source.png")
        output_path = str(Path(workdir) / "output.mp4")
        Path(source_path).write_bytes(base64.b64decode(source_image_b64))

        # SDK setup.
        # NOTE(review): setup/run_chunk/close are called synchronously on the
        # event-loop thread, and the SDK object is shared by all connections.
        sdk_instance = init_sdk()
        sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

        await websocket.send_json({
            "type": "ready",
            "message": "Ready to receive audio chunks",
            "chunk_size": chunk_size,
        })

        # Frame sending task
        async def send_frames():
            frame_count = 0
            last_frame_time = time.time()
            while True:
                try:
                    current_time = time.time()
                    if sdk_instance.writer_queue.qsize() > 0:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            frame_b64 = _encode_frame_jpeg_b64(frame)
                            # FPS from the time since the previous sent frame
                            fps = (
                                1.0 / (current_time - last_frame_time)
                                if current_time > last_frame_time
                                else 0
                            )
                            last_frame_time = current_time
                            await websocket.send_json({
                                "type": "frame",
                                "frame_id": frame_count,
                                "timestamp": current_time,
                                "fps": round(fps, 2),
                                "data": frame_b64,
                            })
                            frame_count += 1
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"Error sending frame: {e}")
                await asyncio.sleep(0.01)  # poll every 10 ms

        frame_task = asyncio.create_task(send_frames())

        # Receive and process audio chunks
        total_samples = 0
        processing_start = time.time()
        while True:
            message = await websocket.receive()
            if message.get("type") == "websocket.disconnect":
                raise WebSocketDisconnect(message.get("code") or 1000)

            audio_bytes = message.get("bytes")
            text = message.get("text")
            if audio_bytes is not None:
                audio = np.frombuffer(audio_bytes, dtype=np.float32)
                # Feed every sample: longer messages are split rather than
                # truncated, and the tail is zero-padded to a full chunk.
                for piece in _split_into_chunks(audio, chunk_size):
                    sdk_instance.run_chunk(piece)
                    total_samples += chunk_size

                # Send progress information
                elapsed = time.time() - processing_start
                await websocket.send_json({
                    "type": "progress",
                    "samples_processed": total_samples,
                    "duration_seconds": total_samples / sample_rate,
                    "elapsed_seconds": elapsed,
                })
            elif text is not None:
                # Command message
                command = json.loads(text)
                if isinstance(command, dict) and command.get("action") == "stop":
                    logger.info("Received stop command")
                    break

        # Stop pushing live frames before finalizing
        await _cancel_and_wait(frame_task)

        # Estimate the frame count and pass it to setup_Nd
        estimated_frames = int(total_samples / sample_rate * _WS_VIDEO_FPS)
        sdk_instance.setup_Nd(estimated_frames)

        await websocket.send_json({"type": "processing", "message": "Finalizing video..."})

        # Close the SDK, then look for the final MP4
        sdk_instance.close()

        if Path(output_path).exists():
            mp4_data = Path(output_path).read_bytes()
            await websocket.send_json({
                "type": "final_video",
                "size_bytes": len(mp4_data),
                "duration_seconds": total_samples / sample_rate,
                "data": base64.b64encode(mp4_data).decode("utf-8"),
            })
        else:
            await websocket.send_json({
                "type": "error",
                "message": "Failed to generate final video",
            })
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        logger.error(traceback.format_exc())
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Best effort: the socket may already be closed.
            logger.debug("Could not deliver error message to client")
    finally:
        # Make sure the frame task never outlives the connection.
        await _cancel_and_wait(frame_task)
        state.active_connections -= 1
        logger.info(f"Connection closed. Active: {state.active_connections}")
        # Cleanup of the source image and the output video
        if workdir is not None:
            shutil.rmtree(workdir, ignore_errors=True)
# NOTE(review): no route decorator for this handler is visible in this file
# view; it must be registered on ``app`` (a POST route, since it takes a file
# upload) for the endpoint to be reachable — confirm the intended path.
async def sse_generate(
    source_image: UploadFile = File(...),
    sample_rate: int = 16000,
    max_duration: float = 10.0
):
    """SSE endpoint - streaming via Server-Sent Events.

    Demo implementation: random noise is generated as audio, fed to the SDK
    in 0.2 s chunks, and 160x160 JPEG thumbnails of frames found in
    ``writer_queue`` are emitted as ``thumbnail`` events, followed by a
    ``complete`` event (or an ``error`` event on failure).

    Args:
        source_image: Uploaded source image.
        sample_rate: Sample rate used to size the generated audio chunks.
        max_duration: Seconds of audio to generate.

    Returns:
        A ``text/event-stream`` StreamingResponse.

    Raises:
        HTTPException: 503 when the connection limit is reached, 400 when
            ``sample_rate`` yields an empty audio chunk.
    """
    if state.active_connections >= state.max_connections:
        raise HTTPException(status_code=503, detail="Server busy")

    chunk_duration = 0.2
    chunk_size = int(sample_rate * chunk_duration)
    if chunk_size < 1:
        raise HTTPException(status_code=400, detail="sample_rate is too small")

    state.active_connections += 1
    # Read the upload now: the generator below only runs while the response
    # is being streamed, i.e. after this handler has already returned.
    try:
        image_bytes = await source_image.read()
    except BaseException:
        state.active_connections -= 1
        raise

    async def generate() -> AsyncGenerator[str, None]:
        workdir = None
        try:
            # Private scratch directory for the image and the output video
            workdir = tempfile.mkdtemp(prefix="ditto_sse_")
            source_path = str(Path(workdir) / "source.png")
            output_path = str(Path(workdir) / "output.mp4")
            Path(source_path).write_bytes(image_bytes)

            # SDK setup
            sdk_instance = init_sdk()
            sdk_instance.setup(source_path, output_path, online_mode=True, max_size=1024)

            yield f"data: {json.dumps({'type': 'start', 'message': 'Processing started'})}\n\n"

            # Demo: generate dummy audio and process it
            num_chunks = int(max_duration / chunk_duration)
            for i in range(num_chunks):
                # Dummy audio chunk (a real implementation would read an audio stream)
                audio_chunk = np.random.randn(chunk_size).astype(np.float32) * 0.1
                sdk_instance.run_chunk(audio_chunk)

                # Build the thumbnail event inside the try, but yield outside
                # it so closing the generator is never swallowed.
                event = None
                if sdk_instance.writer_queue.qsize() > 0:
                    try:
                        frame = sdk_instance.writer_queue.get_nowait()
                        if frame is not None:
                            # Low-resolution thumbnail
                            thumbnail = cv2.resize(frame, (160, 160))
                            _, jpeg = cv2.imencode('.jpg', cv2.cvtColor(thumbnail, cv2.COLOR_RGB2BGR))
                            frame_b64 = base64.b64encode(jpeg).decode('utf-8')
                            event = f"data: {json.dumps({'type': 'thumbnail', 'frame_id': i, 'data': frame_b64})}\n\n"
                    except Exception as exc:
                        # Best effort: a bad frame must not end the stream.
                        logger.warning(f"Skipping thumbnail {i}: {exc}")
                if event is not None:
                    yield event

                await asyncio.sleep(chunk_duration)

            # Done
            estimated_frames = num_chunks * 5  # rough estimate
            sdk_instance.setup_Nd(estimated_frames)
            sdk_instance.close()
            yield f"data: {json.dumps({'type': 'complete', 'frames': estimated_frames})}\n\n"
        except Exception as e:
            logger.error(f"SSE error: {e}")
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        finally:
            state.active_connections -= 1
            if workdir is not None:
                shutil.rmtree(workdir, ignore_errors=True)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
# NOTE(review): no route decorator for this handler is visible in this file
# view; it must be registered on ``app`` (a GET route) to be reachable —
# confirm the intended path.
async def test_page():
    """Serve a static HTML page for manual testing.

    Returns:
        An HTMLResponse with the page markup. The embedded script only logs
        the WebSocket endpoint URL.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <title>DittoTalkingHead Streaming Test</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 20px; }
        .container { max-width: 800px; margin: 0 auto; }
        #live-frame { max-width: 100%; border: 1px solid #ccc; }
        #status { margin: 10px 0; padding: 10px; background: #f0f0f0; }
        .controls { margin: 20px 0; }
        button { padding: 10px 20px; margin: 5px; }
    </style>
</head>
<body>
    <div class="container">
        <h1>DittoTalkingHead Streaming Test</h1>
        <div class="controls">
            <input type="file" id="source-image" accept="image/*">
            <button id="start-btn">Start Streaming</button>
            <button id="stop-btn" disabled>Stop</button>
        </div>
        <div id="status">Ready</div>
        <img id="live-frame" style="display: none;">
        <video id="final-video" controls style="display: none; width: 100%;"></video>
    </div>
    <script>
        // WebSocket実装はstreaming_api_guide.mdを参照
        console.log('WebSocket endpoint: ws://localhost:8000/ws/generate');
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
# Script entry point: configure the GPU (when present) and serve ``app``.
if __name__ == "__main__":
    import uvicorn
    import torch

    # GPU settings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True

    logger.info("Starting DittoTalkingHead Streaming API Server...")
    logger.info(f"GPU available: {torch.cuda.is_available()}")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
        access_log=True,
    )