Spaces:
Runtime error
Runtime error
| """ | |
| ストリーミング実装のテストスクリプト | |
| """ | |
| import numpy as np | |
| import soundfile as sf | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from stream_pipeline_offline import StreamSDK | |
| # テスト設定 | |
| CFG_PKL = "checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" | |
| DATA_ROOT = "checkpoints/ditto_pytorch" | |
| EXAMPLES_DIR = Path("example") | |
| def test_streaming(): | |
| """ストリーミング機能の基本テスト""" | |
| print("=== ストリーミング機能テスト開始 ===") | |
| # テスト用の音声を生成(3秒のサイン波) | |
| duration = 3.0 # seconds | |
| sample_rate = 16000 | |
| t = np.linspace(0, duration, int(sample_rate * duration)) | |
| audio_data = np.sin(2 * np.pi * 440 * t) * 0.5 # 440Hz | |
| # SDKの初期化 | |
| print("1. SDK初期化...") | |
| sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
| print("✅ SDK初期化完了") | |
| # セットアップ | |
| print("\n2. ストリーミングモードでセットアップ...") | |
| src_img = str(EXAMPLES_DIR / "reference.png") | |
| tmp_out = tempfile.mktemp(suffix=".mp4") | |
| sdk.setup(src_img, tmp_out, online_mode=True, max_size=1024) | |
| N_total = int(np.ceil(duration * 20)) # 20fps | |
| sdk.setup_Nd(N_total) | |
| print("✅ セットアップ完了") | |
| # チャンク単位で音声を送信 | |
| print("\n3. チャンク単位で音声送信...") | |
| chunk_sec = 0.2 # 200ms | |
| chunk_samples = int(sample_rate * chunk_sec) | |
| chunks_sent = 0 | |
| frames_received = 0 | |
| start_time = time.time() | |
| for i in range(0, len(audio_data), chunk_samples): | |
| chunk = audio_data[i:i + chunk_samples] | |
| if len(chunk) < chunk_samples: | |
| chunk = np.pad(chunk, (0, chunk_samples - len(chunk))) | |
| sdk.run_chunk(chunk) | |
| chunks_sent += 1 | |
| # キューからフレームを確認 | |
| while sdk.writer_queue.qsize() > 0: | |
| try: | |
| frame = sdk.writer_queue.get_nowait() | |
| if frame is not None: | |
| frames_received += 1 | |
| print(f" フレーム {frames_received} 受信 (チャンク {chunks_sent})") | |
| except: | |
| break | |
| time.sleep(0.05) # 少し待機 | |
| # 残りのフレームを待つ | |
| print("\n4. 残りのフレームを処理...") | |
| timeout = 5.0 # 5秒タイムアウト | |
| timeout_start = time.time() | |
| while time.time() - timeout_start < timeout: | |
| if sdk.writer_queue.qsize() > 0: | |
| try: | |
| frame = sdk.writer_queue.get_nowait() | |
| if frame is not None: | |
| frames_received += 1 | |
| print(f" フレーム {frames_received} 受信") | |
| except: | |
| pass | |
| else: | |
| time.sleep(0.1) | |
| # クローズ | |
| print("\n5. SDKクローズ...") | |
| sdk.close() | |
| elapsed = time.time() - start_time | |
| # 結果 | |
| print("\n=== テスト結果 ===") | |
| print(f"✅ 送信チャンク数: {chunks_sent}") | |
| print(f"✅ 受信フレーム数: {frames_received}") | |
| print(f"✅ 処理時間: {elapsed:.2f}秒") | |
| print(f"✅ 出力ファイル: {tmp_out}") | |
| # 期待される結果の確認 | |
| expected_frames = int(duration * 20) # 20fps | |
| if frames_received >= expected_frames * 0.8: # 80%以上 | |
| print("✅ テスト成功!") | |
| else: | |
| print(f"⚠️ 期待フレーム数 ({expected_frames}) に対して受信数が少ない") | |
| return True | |
| def test_writer_queue(): | |
| """writer_queueの動作確認""" | |
| print("\n=== writer_queue 動作確認 ===") | |
| sdk = StreamSDK(CFG_PKL, DATA_ROOT) | |
| # キューの存在確認 | |
| if hasattr(sdk, 'writer_queue'): | |
| print("✅ writer_queue が存在します") | |
| print(f" キューサイズ: {sdk.writer_queue.qsize()}") | |
| print(f" 最大サイズ: {sdk.writer_queue.maxsize}") | |
| else: | |
| print("❌ writer_queue が見つかりません") | |
| return False | |
| return True | |
| if __name__ == "__main__": | |
| # writer_queueの確認 | |
| if not test_writer_queue(): | |
| print("基本的な要件が満たされていません") | |
| exit(1) | |
| # ストリーミングテスト | |
| try: | |
| test_streaming() | |
| except Exception as e: | |
| print(f"❌ エラー: {e}") | |
| import traceback | |
| traceback.print_exc() |