| --- |
| license: mit |
| tags: |
| - table-tennis |
| - ball-detection |
| - pytorch |
| language: |
| - ja |
| --- |
| |
| # TTNet - Table Tennis Ball Detection (Fine-tuned on OpenTTGames) |
|
|
| TTNet(CVPR2020)の学習済みモデル。OpenTTGames データセットで 3フェーズ学習を行ったチェックポイント一式。 |
| 元のリポジトリ: [maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch](https://github.com/maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch) |
|
|
| --- |
|
|
| ## チェックポイント一覧 |
|
|
| | ファイル名 | 説明 | |
| |---|---| |
| | `ttnet_1st_phase_best.pth` | stride=1(120fps想定)Phase 1 best (epoch 17) | |
| | `ttnet_2nd_phase_best.pth` | stride=1 Phase 2 best (epoch 3) | |
| | `ttnet_3rd_phase_best.pth` | stride=1 Phase 3 best (epoch 3) ← stride=1推論用 | |
| | `ttnet_30fps_1st_phase_best.pth` | stride=4(30fps対応)Phase 1 best (epoch 12) | |
| | `ttnet_30fps_2nd_phase_best.pth` | stride=4 Phase 2 best (epoch 7) | |
| | `ttnet_30fps_3rd_phase_best.pth` | stride=4 Phase 3 best (epoch 1) ← stride=4推論用 | |
|
|
| --- |
|
|
| ## 学習環境 |
|
|
| - GPU: NVIDIA GeForce RTX 4070 Ti Super (16GB) |
| - PyTorch 2.4.0 (CUDA 12.1) |
| - データセット: [OpenTTGames](https://lab.osai.ai/datasets/openttgames/) |
| - 学習: game_1〜4(約46,000フレーム) |
| - 検証: game_5(約7,300フレーム) |
|
|
| --- |
|
|
| ## 学習方法 |
|
|
| ### フレーム画像の抽出 |
|
|
| ```bash |
| cd TTNet/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch/prepare_dataset |
| |
| # stride=1(120fps想定) |
| python extract_smooth_labellings.py |
| # → dataset/training/images/ に出力(約39GB) |
| |
| # stride=4(30fps対応) |
| python extract_smooth_labellings.py --frame_stride 4 |
| # → dataset/training/images_stride4/ に出力(約39GB) |
| ``` |
|
|
| ### Phase 1〜3 の学習(stride=1) |
|
|
| ```bash |
| bash train_ttnet.sh |
| # 内部で以下を順番に実行: |
| # Phase 1: ボール全体検出 + セグメンテーション |
| # Phase 2: ローカル検出 + イベント検出 |
| # Phase 3: 全タスク fine-tune |
| # 重みの保存先: TTNet/checkpoints/ttnet_{1,2,3}rd_phase/ |
| ``` |
|
|
| ### Phase 1〜3 の学習(stride=4、30fps対応) |
|
|
| ```bash |
| bash train_ttnet_30fps.sh |
| # stride=4 専用スクリプト。モデル構造は変更なし(9フレーム27ch入力) |
| # 重みの保存先: TTNet/checkpoints/ttnet_30fps_{1,2,3}rd_phase/ |
| ``` |
|
|
| --- |
|
|
| ## 学習結果 |
|
|
| ### stride=1(game_1〜4 学習 / game_5 検証) |
|
|
| | フェーズ | 内容 | Best Val Loss | Best Epoch | |
| |---------|------|--------------|-----------| |
| | Phase 1 | Global検出 + セグメンテーション | 0.1804 | 17 / 30 | |
| | Phase 2 | Local検出 + イベント検出 | 0.1252 | 3 / 30 | |
| | Phase 3 | 全タスク Fine-tune | 0.2865 | 3 / 30 | |
|
|
| ### stride=4(game_1〜4 学習 / game_5 検証) |
|
|
| | フェーズ | 内容 | Best Val Loss | Best Epoch | |
| |---------|------|--------------|-----------| |
| | Phase 1 | Global検出 + セグメンテーション | 0.2664 | 12 / 30 | |
| | Phase 2 | Local検出 + イベント検出 | 0.0961 | 7 / 30 | |
| | Phase 3 | 全タスク Fine-tune | 0.3456 | 1 / 30 | |
|
|
| --- |
|
|
| ## 推論方法 |
|
|
| ```bash |
| source venv_ttnet/bin/activate |
| |
| # stride=1(120fps想定映像向け) |
| python run_ball_trajectory.py \ |
| --player haruya \ |
| --video 1 \ |
| --pretrained_path TTNet/checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_best.pth \ |
| --gpu_idx 0 |
| |
| # stride=4(30fps映像向け) |
| python run_ball_trajectory.py \ |
| --player haruya \ |
| --video 1 \ |
| --pretrained_path TTNet/checkpoints/ttnet_30fps_3rd_phase/ttnet_30fps_3rd_phase_best.pth \ |
| --frame_stride 4 \ |
| --run-id stride4 |
| ``` |
|
|
| --- |
|
|
| ## TTNet のモデル構造(概要) |
|
|
| 1. **Global Stage**: 9フレームスタック(27ch、320×128)を入力。Conv×6 + FC×3 で 448次元の1Dガウス分布を出力し、`argmax` でボール位置を粗く推定(X方向320px + Y方向128px) |
| 2. **Local Stage**: Global推定位置を中心に元画像をクロップして再度同じネットワークに通し、精密な座標を算出 |
| 3. **最終座標**: `bx = x_global × (1920/320) + x_local - 160`、`by = y_global × (1080/128) + y_local - 64` |
| 4. **Segmentation / Event Spotting**: Global中間特徴からデコードして卓球台マスクとバウンス・ネット接触を並列に検出 |
|
|
| --- |
|
|
| ## ドメインシフトについて(重要な注意事項) |
|
|
| - 学習データ(OpenTTGames): 暗い競技会場・固定俯瞰カメラ・120fps |
| - 本研究で適用した映像: 明るい体育館・角度のある手持ちカメラ・30fps |
| - stride=4 による 30fps 対応後も全フレームで `x=-160, y=543` の固定値となり検出不能 |
| - **公式映像(OpenTTGames game_5)では正しく検出できることを確認済み**(ドメインシフトが原因) |
| - このため、事前学習済み重みで様々な映像環境に対応できる **TrackNetV3** に移行した |
| |