"""
DittoTalkingHead Streaming Client

Example implementation of a streaming client using WebSocket.
"""
import asyncio
import base64
import json
import logging
import queue
import threading
import time
from pathlib import Path
from typing import Callable, Optional

import cv2
import numpy as np
import pyaudio
import soundfile as sf
import websockets
# Module-level logging configuration shared by both client classes.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DittoStreamingClient:
    """DittoTalkingHead streaming client.

    Opens a WebSocket session to the generation server, sends a base64
    source image followed by raw float32 mono audio chunks, and receives
    generated frames (and optionally a final encoded video) back.
    """

    def __init__(self, server_url="ws://localhost:8000/ws/generate"):
        """
        Args:
            server_url: WebSocket endpoint of the generation server.
        """
        self.server_url = server_url
        self.sample_rate = 16000  # server-side expected rate (16 kHz mono)
        self.chunk_duration = 0.2  # 200 ms of audio per chunk
        self.chunk_size = int(self.sample_rate * self.chunk_duration)
        self.websocket = None
        self.is_connected = False
        # Callbacks: frame_callback(frame, metadata) per received frame,
        # final_video_callback(video_bytes, metadata) once at the end.
        self.frame_callback: Optional[Callable] = None
        self.final_video_callback: Optional[Callable] = None

    async def connect(self, source_image_path: str) -> bool:
        """Connect to the server and start a session.

        Args:
            source_image_path: Path to the source (reference) image file.

        Returns:
            True if the server acknowledged with a "ready" message,
            False on any other reply.

        Raises:
            Exception: Re-raises connection / handshake errors after logging.
        """
        try:
            # Base64-encode the source image for the JSON handshake.
            with open(source_image_path, "rb") as f:
                image_b64 = base64.b64encode(f.read()).decode('utf-8')
            # Open the WebSocket connection.
            self.websocket = await websockets.connect(self.server_url)
            self.is_connected = True
            # Send the initial session configuration.
            await self.websocket.send(json.dumps({
                "source_image": image_b64,
                "sample_rate": self.sample_rate,
                "chunk_duration": self.chunk_duration
            }))
            # Wait for the server's acknowledgement.
            response = await self.websocket.recv()
            data = json.loads(response)
            if data["type"] == "ready":
                logger.info(f"Connected to server: {data['message']}")
                return True
            else:
                logger.error(f"Connection failed: {data}")
                return False
        except Exception as e:
            logger.error(f"Connection error: {e}")
            self.is_connected = False
            raise

    async def disconnect(self):
        """Close the connection; safe to call when not connected."""
        if self.websocket:
            await self.websocket.close()
            self.is_connected = False
            logger.info("Disconnected from server")

    async def stream_audio_file(self, audio_path: str, source_image_path: str):
        """Stream an audio file to the server, paced to real time.

        Args:
            audio_path: Path to the audio file to stream.
            source_image_path: Path to the source image for the session.
        """
        try:
            # Connect and perform the handshake.
            await self.connect(source_image_path)
            # Load the audio.
            audio_data, sr = sf.read(audio_path)
            # BUGFIX: sf.read returns a 2-D array for multi-channel files;
            # the server expects a single mono channel, so down-mix first.
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
            if sr != self.sample_rate:
                import librosa
                audio_data = librosa.resample(
                    audio_data,
                    orig_sr=sr,
                    target_sr=self.sample_rate
                )
            # Receive frames concurrently with sending.
            receive_task = asyncio.create_task(self._receive_frames())
            # Send the audio chunk by chunk.
            total_chunks = 0
            for i in range(0, len(audio_data), self.chunk_size):
                chunk = audio_data[i:i+self.chunk_size]
                if len(chunk) < self.chunk_size:
                    # Zero-pad the trailing partial chunk to full size.
                    chunk = np.pad(chunk, (0, self.chunk_size - len(chunk)))
                # Send as raw little-endian float32 bytes.
                await self.websocket.send(chunk.astype(np.float32).tobytes())
                total_chunks += 1
                # Pace transmission to simulate real-time capture.
                await asyncio.sleep(self.chunk_duration)
                # BUGFIX: clamp so the padded final chunk never reports >100%.
                progress = min((i + self.chunk_size) / len(audio_data) * 100, 100.0)
                logger.info(f"Streaming progress: {progress:.1f}%")
            # Tell the server we are done sending audio.
            await self.websocket.send(json.dumps({"action": "stop"}))
            logger.info(f"Sent {total_chunks} audio chunks")
            # Wait for frame reception to finish.
            await receive_task
        finally:
            await self.disconnect()

    async def stream_microphone(self, source_image_path: str, duration: Optional[float] = None):
        """Stream live microphone audio to the server.

        Args:
            source_image_path: Path to the source image for the session.
            duration: Optional recording length in seconds; when None,
                records until interrupted.
        """
        try:
            # Connect and perform the handshake.
            await self.connect(source_image_path)
            # Receive frames concurrently with sending.
            receive_task = asyncio.create_task(self._receive_frames())
            # Queue for hand-off from the recording thread to the sender.
            audio_queue = queue.Queue()
            stop_event = threading.Event()

            def record_audio():
                """Capture microphone audio in a worker thread."""
                p = pyaudio.PyAudio()
                stream = p.open(
                    format=pyaudio.paFloat32,
                    channels=1,
                    rate=self.sample_rate,
                    input=True,
                    frames_per_buffer=self.chunk_size
                )
                logger.info("Recording started... Press Ctrl+C to stop")
                try:
                    # BUGFIX: asyncio.get_event_loop().time() is invalid in a
                    # worker thread (no running event loop there); use the
                    # monotonic clock for the duration limit instead.
                    start_time = time.monotonic()
                    while not stop_event.is_set():
                        if duration and (time.monotonic() - start_time) > duration:
                            break
                        audio_chunk = stream.read(self.chunk_size, exception_on_overflow=False)
                        audio_queue.put(audio_chunk)
                except Exception as e:
                    logger.error(f"Recording error: {e}")
                finally:
                    stream.stop_stream()
                    stream.close()
                    p.terminate()
                    logger.info("Recording stopped")

            # Start the recording thread.
            record_thread = threading.Thread(target=record_audio)
            record_thread.start()
            try:
                # Forward captured chunks until the recorder stops and the
                # queue drains.
                while record_thread.is_alive() or not audio_queue.empty():
                    try:
                        audio_chunk = audio_queue.get(timeout=0.1)
                        audio_array = np.frombuffer(audio_chunk, dtype=np.float32)
                        await self.websocket.send(audio_array.tobytes())
                    except queue.Empty:
                        continue
                    except KeyboardInterrupt:
                        logger.info("Stopping recording...")
                        break
            finally:
                stop_event.set()
                record_thread.join()
            # Tell the server we are done sending audio.
            await self.websocket.send(json.dumps({"action": "stop"}))
            # Wait for frame reception to finish.
            await receive_task
        finally:
            await self.disconnect()

    async def _receive_frames(self):
        """Receive frames and status messages until completion or error.

        Loops until a "final_video" or "error" message arrives, or the
        server closes the connection. Decoded frames are handed to
        ``self.frame_callback``; final video bytes go to
        ``self.final_video_callback``.
        """
        frame_count = 0
        try:
            while True:
                message = await self.websocket.recv()
                data = json.loads(message)
                if data["type"] == "frame":
                    frame_count += 1
                    logger.info(f"Received frame {data['frame_id']} (FPS: {data.get('fps', 0)})")
                    if self.frame_callback:
                        # Decode the base64 image payload into an OpenCV image.
                        frame_data = base64.b64decode(data["data"])
                        nparr = np.frombuffer(frame_data, np.uint8)
                        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                        self.frame_callback(frame, data)
                elif data["type"] == "progress":
                    logger.info(f"Progress: {data['duration_seconds']:.1f}s processed")
                elif data["type"] == "processing":
                    logger.info(f"Server: {data['message']}")
                elif data["type"] == "final_video":
                    logger.info(f"Received final video ({data['size_bytes']} bytes, {data['duration_seconds']:.1f}s)")
                    if self.final_video_callback:
                        video_data = base64.b64decode(data["data"])
                        self.final_video_callback(video_data, data)
                    break
                elif data["type"] == "error":
                    logger.error(f"Server error: {data['message']}")
                    break
        except websockets.exceptions.ConnectionClosed:
            logger.info("Connection closed by server")
        except Exception as e:
            logger.error(f"Receive error: {e}")
        logger.info(f"Total frames received: {frame_count}")

    def set_frame_callback(self, callback: Callable):
        """Register ``callback(frame, metadata)``, called for each frame."""
        self.frame_callback = callback

    def set_final_video_callback(self, callback: Callable):
        """Register ``callback(video_bytes, metadata)`` for the final video."""
        self.final_video_callback = callback
# Usage example and test
async def main():
    """Demo: stream the bundled example audio/image through the client."""
    streaming_client = DittoStreamingClient()

    def show_live_frame(frame, metadata):
        # Display each generated frame as soon as it arrives.
        cv2.imshow("Live Frame", frame)
        cv2.waitKey(1)

    def persist_final_video(video_data, metadata):
        # Write the final encoded video to disk.
        output_path = "output_streaming.mp4"
        with open(output_path, "wb") as f:
            f.write(video_data)
        logger.info(f"Video saved to {output_path}")

    streaming_client.set_frame_callback(show_live_frame)
    streaming_client.set_final_video_callback(persist_final_video)

    # Paths to the test image and sample audio.
    source_image = "example/reference.png"
    audio_file = "example/audio.wav"

    # Bail out early if the reference image is missing.
    if not Path(source_image).exists():
        logger.error(f"Source image not found: {source_image}")
        return

    # Stream from the audio file when it exists.
    if Path(audio_file).exists():
        logger.info("=== Testing audio file streaming ===")
        await streaming_client.stream_audio_file(audio_file, source_image)
    else:
        logger.warning(f"Audio file not found: {audio_file}")

    # Microphone streaming example (5 seconds):
    # logger.info("\n=== Testing microphone streaming (5 seconds) ===")
    # await streaming_client.stream_microphone(source_image, duration=5.0)

    cv2.destroyAllWindows()
# Batch-processing client
class BatchStreamingClient:
    """Client that processes several requests concurrently with a cap."""

    def __init__(self, server_url="ws://localhost:8000/ws/generate", max_parallel=3):
        """
        Args:
            server_url: WebSocket endpoint of the generation server.
            max_parallel: Maximum number of sessions run at once.
        """
        self.server_url = server_url
        self.max_parallel = max_parallel

    async def process_batch(self, tasks: list):
        """Run every task concurrently, at most ``max_parallel`` at a time.

        Args:
            tasks: Dicts with "id", "audio_path" and "image_path" keys.

        Returns:
            A list (input order) of task ids, or exception objects for
            tasks that failed.
        """
        limiter = asyncio.Semaphore(self.max_parallel)

        async def run_one(task):
            # One dedicated client per task; the semaphore caps concurrency.
            async with limiter:
                worker = DittoStreamingClient(self.server_url)
                await worker.stream_audio_file(
                    task["audio_path"],
                    task["image_path"]
                )
                return task["id"]

        return await asyncio.gather(
            *(run_one(task) for task in tasks),
            return_exceptions=True
        )
if __name__ == "__main__":
    # Single-client test: streams the bundled example assets via main().
    asyncio.run(main())

    # Batch-processing example:
    # batch_client = BatchStreamingClient()
    # tasks = [
    #     {"id": 1, "audio_path": "audio1.wav", "image_path": "image1.png"},
    #     {"id": 2, "audio_path": "audio2.wav", "image_path": "image2.png"},
    # ]
    # asyncio.run(batch_client.process_batch(tasks))