--- license: mit tags: - table-tennis - ball-detection - pytorch language: - ja --- # TTNet - Table Tennis Ball Detection (Fine-tuned on OpenTTGames) TTNet(CVPR2020)の学習済みモデル。OpenTTGames データセットで 3フェーズ学習を行ったチェックポイント一式。 元のリポジトリ: [maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch](https://github.com/maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch) --- ## チェックポイント一覧 | ファイル名 | 説明 | |---|---| | `ttnet_1st_phase_best.pth` | stride=1(120fps想定)Phase 1 best (epoch 17) | | `ttnet_2nd_phase_best.pth` | stride=1 Phase 2 best (epoch 3) | | `ttnet_3rd_phase_best.pth` | stride=1 Phase 3 best (epoch 3) ← stride=1推論用 | | `ttnet_30fps_1st_phase_best.pth` | stride=4(30fps対応)Phase 1 best (epoch 12) | | `ttnet_30fps_2nd_phase_best.pth` | stride=4 Phase 2 best (epoch 7) | | `ttnet_30fps_3rd_phase_best.pth` | stride=4 Phase 3 best (epoch 1) ← stride=4推論用 | --- ## 学習環境 - GPU: NVIDIA GeForce RTX 4070 Ti Super (16GB) - PyTorch 2.4.0 (CUDA 12.1) - データセット: [OpenTTGames](https://lab.osai.ai/datasets/openttgames/) - 学習: game_1〜4(約46,000フレーム) - 検証: game_5(約7,300フレーム) --- ## 学習方法 ### フレーム画像の抽出 ```bash cd TTNet/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch/prepare_dataset # stride=1(120fps想定) python extract_smooth_labellings.py # → dataset/training/images/ に出力(約39GB) # stride=4(30fps対応) python extract_smooth_labellings.py --frame_stride 4 # → dataset/training/images_stride4/ に出力(約39GB) ``` ### Phase 1〜3 の学習(stride=1) ```bash bash train_ttnet.sh # 内部で以下を順番に実行: # Phase 1: ボール全体検出 + セグメンテーション # Phase 2: ローカル検出 + イベント検出 # Phase 3: 全タスク fine-tune # 重みの保存先: TTNet/checkpoints/ttnet_{1,2,3}rd_phase/ ``` ### Phase 1〜3 の学習(stride=4、30fps対応) ```bash bash train_ttnet_30fps.sh # stride=4 専用スクリプト。モデル構造は変更なし(9フレーム27ch入力) # 重みの保存先: TTNet/checkpoints/ttnet_30fps_{1,2,3}rd_phase/ ``` --- ## 学習結果 ### stride=1(game_1〜4 学習 / game_5 検証) | フェーズ | 内容 | Best Val Loss | Best Epoch | |---------|------|--------------|-----------| | Phase 1 | Global検出 + セグメンテーション | 0.1804 | 17 / 30 | | Phase 2 | Local検出 + イベント検出 | 0.1252 | 3 / 30 | | Phase 3 | 全タスク Fine-tune | 0.2865 | 3 / 30 | ### stride=4(game_1〜4 学習 / game_5 検証) | フェーズ | 内容 | Best Val Loss | Best Epoch | |---------|------|--------------|-----------| | Phase 1 | Global検出 + セグメンテーション | 0.2664 | 12 / 30 | | Phase 2 | Local検出 + イベント検出 | 0.0961 | 7 / 30 | | Phase 3 | 全タスク Fine-tune | 0.3456 | 1 / 30 | --- ## 推論方法 ```bash source venv_ttnet/bin/activate # stride=1(120fps想定映像向け) python run_ball_trajectory.py \ --player haruya \ --video 1 \ --pretrained_path TTNet/checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_best.pth \ --gpu_idx 0 # stride=4(30fps映像向け) python run_ball_trajectory.py \ --player haruya \ --video 1 \ --pretrained_path TTNet/checkpoints/ttnet_30fps_3rd_phase/ttnet_30fps_3rd_phase_best.pth \ --frame_stride 4 \ --run-id stride4 ``` --- ## TTNet のモデル構造(概要) 1. **Global Stage**: 9フレームスタック(27ch、320×128)を入力。Conv×6 + FC×3 で 448次元の1Dガウス分布を出力し、`argmax` でボール位置を粗く推定(X方向320px + Y方向128px) 2. **Local Stage**: Global推定位置を中心に元画像をクロップして再度同じネットワークに通し、精密な座標を算出 3. **最終座標**: `bx = x_global × (1920/320) + x_local - 160`、`by = y_global × (1080/128) + y_local - 64` 4. **Segmentation / Event Spotting**: Global中間特徴からデコードして卓球台マスクとバウンス・ネット接触を並列に検出 --- ## ドメインシフトについて(重要な注意事項) - 学習データ(OpenTTGames): 暗い競技会場・固定俯瞰カメラ・120fps - 本研究で適用した映像: 明るい体育館・角度のある手持ちカメラ・30fps - stride=4 による 30fps 対応後も全フレームで `x=-160, y=543` の固定値となり検出不能 - **公式映像(OpenTTGames game_5)では正しく検出できることを確認済み**(ドメインシフトが原因) - このため、事前学習済み重みで様々な映像環境に対応できる **TrackNetV3** に移行した