TTNet - Table Tennis Ball Detection (Fine-tuned on OpenTTGames)

TTNet(CVPR2020)の学習済みモデル。OpenTTGames データセットで 3フェーズ学習を行ったチェックポイント一式。 元のリポジトリ: maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch


チェックポイント一覧

ファイル名 説明
ttnet_1st_phase_best.pth stride=1(120fps想定)Phase 1 best (epoch 17)
ttnet_2nd_phase_best.pth stride=1 Phase 2 best (epoch 3)
ttnet_3rd_phase_best.pth stride=1 Phase 3 best (epoch 3) ← stride=1推論用
ttnet_30fps_1st_phase_best.pth stride=4(30fps対応)Phase 1 best (epoch 12)
ttnet_30fps_2nd_phase_best.pth stride=4 Phase 2 best (epoch 7)
ttnet_30fps_3rd_phase_best.pth stride=4 Phase 3 best (epoch 1) ← stride=4推論用

学習環境

  • GPU: NVIDIA GeForce RTX 4070 Ti Super (16GB)
  • PyTorch 2.4.0 (CUDA 12.1)
  • データセット: OpenTTGames
    • 学習: game_1〜4(約46,000フレーム)
    • 検証: game_5(約7,300フレーム)

学習方法

フレーム画像の抽出

cd TTNet/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch/prepare_dataset

# stride=1(120fps想定)
python extract_smooth_labellings.py
# → dataset/training/images/ に出力(約39GB)

# stride=4(30fps対応)
python extract_smooth_labellings.py --frame_stride 4
# → dataset/training/images_stride4/ に出力(約39GB)

Phase 1〜3 の学習(stride=1)

bash train_ttnet.sh
# 内部で以下を順番に実行:
#   Phase 1: ボール全体検出 + セグメンテーション
#   Phase 2: ローカル検出 + イベント検出
#   Phase 3: 全タスク fine-tune
# 重みの保存先: TTNet/checkpoints/ttnet_{1,2,3}rd_phase/

Phase 1〜3 の学習(stride=4、30fps対応)

bash train_ttnet_30fps.sh
# stride=4 専用スクリプト。モデル構造は変更なし(9フレーム27ch入力)
# 重みの保存先: TTNet/checkpoints/ttnet_30fps_{1,2,3}rd_phase/

学習結果

stride=1(game_1〜4 学習 / game_5 検証)

フェーズ 内容 Best Val Loss Best Epoch
Phase 1 Global検出 + セグメンテーション 0.1804 17 / 30
Phase 2 Local検出 + イベント検出 0.1252 3 / 30
Phase 3 全タスク Fine-tune 0.2865 3 / 30

stride=4(game_1〜4 学習 / game_5 検証)

フェーズ 内容 Best Val Loss Best Epoch
Phase 1 Global検出 + セグメンテーション 0.2664 12 / 30
Phase 2 Local検出 + イベント検出 0.0961 7 / 30
Phase 3 全タスク Fine-tune 0.3456 1 / 30

推論方法

source venv_ttnet/bin/activate

# stride=1(120fps想定映像向け)
python run_ball_trajectory.py \
  --player haruya \
  --video 1 \
  --pretrained_path TTNet/checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_best.pth \
  --gpu_idx 0

# stride=4(30fps映像向け)
python run_ball_trajectory.py \
  --player haruya \
  --video 1 \
  --pretrained_path TTNet/checkpoints/ttnet_30fps_3rd_phase/ttnet_30fps_3rd_phase_best.pth \
  --frame_stride 4 \
  --run-id stride4

TTNet のモデル構造(概要)

  1. Global Stage: 9フレームスタック(27ch、320×128)を入力。Conv×6 + FC×3 で 448次元の1Dガウス分布を出力し、argmax でボール位置を粗く推定(X方向320px + Y方向128px)
  2. Local Stage: Global推定位置を中心に元画像をクロップして再度同じネットワークに通し、精密な座標を算出
  3. 最終座標: bx = x_global × (1920/320) + x_local - 160by = y_global × (1080/128) + y_local - 64
  4. Segmentation / Event Spotting: Global中間特徴からデコードして卓球台マスクとバウンス・ネット接触を並列に検出

ドメインシフトについて(重要な注意事項)

  • 学習データ(OpenTTGames): 暗い競技会場・固定俯瞰カメラ・120fps
  • 本研究で適用した映像: 明るい体育館・角度のある手持ちカメラ・30fps
  • stride=4 による 30fps 対応後も全フレームで x=-160, y=543 の固定値となり検出不能
  • 公式映像(OpenTTGames game_5)では正しく検出できることを確認済み(ドメインシフトが原因)
  • このため、事前学習済み重みで様々な映像環境に対応できる TrackNetV3 に移行した
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support