metadata
license: mit
tags:
- table-tennis
- ball-detection
- pytorch
language:
- ja
TTNet - Table Tennis Ball Detection (Fine-tuned on OpenTTGames)
TTNet(CVPR2020)の学習済みモデル。OpenTTGames データセットで 3フェーズ学習を行ったチェックポイント一式。 元のリポジトリ: maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch
チェックポイント一覧
| ファイル名 | 説明 |
|---|---|
ttnet_1st_phase_best.pth |
stride=1(120fps想定)Phase 1 best (epoch 17) |
ttnet_2nd_phase_best.pth |
stride=1 Phase 2 best (epoch 3) |
ttnet_3rd_phase_best.pth |
stride=1 Phase 3 best (epoch 3) ← stride=1推論用 |
ttnet_30fps_1st_phase_best.pth |
stride=4(30fps対応)Phase 1 best (epoch 12) |
ttnet_30fps_2nd_phase_best.pth |
stride=4 Phase 2 best (epoch 7) |
ttnet_30fps_3rd_phase_best.pth |
stride=4 Phase 3 best (epoch 1) ← stride=4推論用 |
学習環境
- GPU: NVIDIA GeForce RTX 4070 Ti Super (16GB)
- PyTorch 2.4.0 (CUDA 12.1)
- データセット: OpenTTGames
- 学習: game_1〜4(約46,000フレーム)
- 検証: game_5(約7,300フレーム)
学習方法
フレーム画像の抽出
cd TTNet/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch/prepare_dataset
# stride=1(120fps想定)
python extract_smooth_labellings.py
# → dataset/training/images/ に出力(約39GB)
# stride=4(30fps対応)
python extract_smooth_labellings.py --frame_stride 4
# → dataset/training/images_stride4/ に出力(約39GB)
Phase 1〜3 の学習(stride=1)
bash train_ttnet.sh
# 内部で以下を順番に実行:
# Phase 1: ボール全体検出 + セグメンテーション
# Phase 2: ローカル検出 + イベント検出
# Phase 3: 全タスク fine-tune
# 重みの保存先: TTNet/checkpoints/ttnet_{1,2,3}rd_phase/
Phase 1〜3 の学習(stride=4、30fps対応)
bash train_ttnet_30fps.sh
# stride=4 専用スクリプト。モデル構造は変更なし(9フレーム27ch入力)
# 重みの保存先: TTNet/checkpoints/ttnet_30fps_{1,2,3}rd_phase/
学習結果
stride=1(game_1〜4 学習 / game_5 検証)
| フェーズ | 内容 | Best Val Loss | Best Epoch |
|---|---|---|---|
| Phase 1 | Global検出 + セグメンテーション | 0.1804 | 17 / 30 |
| Phase 2 | Local検出 + イベント検出 | 0.1252 | 3 / 30 |
| Phase 3 | 全タスク Fine-tune | 0.2865 | 3 / 30 |
stride=4(game_1〜4 学習 / game_5 検証)
| フェーズ | 内容 | Best Val Loss | Best Epoch |
|---|---|---|---|
| Phase 1 | Global検出 + セグメンテーション | 0.2664 | 12 / 30 |
| Phase 2 | Local検出 + イベント検出 | 0.0961 | 7 / 30 |
| Phase 3 | 全タスク Fine-tune | 0.3456 | 1 / 30 |
推論方法
source venv_ttnet/bin/activate
# stride=1(120fps想定映像向け)
python run_ball_trajectory.py \
--player haruya \
--video 1 \
--pretrained_path TTNet/checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_best.pth \
--gpu_idx 0
# stride=4(30fps映像向け)
python run_ball_trajectory.py \
--player haruya \
--video 1 \
--pretrained_path TTNet/checkpoints/ttnet_30fps_3rd_phase/ttnet_30fps_3rd_phase_best.pth \
--frame_stride 4 \
--run-id stride4
TTNet のモデル構造(概要)
- Global Stage: 9フレームスタック(27ch、320×128)を入力。Conv×6 + FC×3 で 448次元の1Dガウス分布を出力し、
argmaxでボール位置を粗く推定(X方向320px + Y方向128px) - Local Stage: Global推定位置を中心に元画像をクロップして再度同じネットワークに通し、精密な座標を算出
- 最終座標:
bx = x_global × (1920/320) + x_local - 160、by = y_global × (1080/128) + y_local - 64 - Segmentation / Event Spotting: Global中間特徴からデコードして卓球台マスクとバウンス・ネット接触を並列に検出
ドメインシフトについて(重要な注意事項)
- 学習データ(OpenTTGames): 暗い競技会場・固定俯瞰カメラ・120fps
- 本研究で適用した映像: 明るい体育館・角度のある手持ちカメラ・30fps
- stride=4 による 30fps 対応後も全フレームで
x=-160, y=543の固定値となり検出不能 - 公式映像(OpenTTGames game_5)では正しく検出できることを確認済み(ドメインシフトが原因)
- このため、事前学習済み重みで様々な映像環境に対応できる TrackNetV3 に移行した