TTNet-TableTennis / README.md
Hruno's picture
Upload README.md with huggingface_hub
ce59861 verified
|
Raw
History Blame Contribute Delete
4.84 kB
---
license: mit
tags:
- table-tennis
- ball-detection
- pytorch
language:
- ja
---
# TTNet - Table Tennis Ball Detection (Fine-tuned on OpenTTGames)
TTNet(CVPR2020)の学習済みモデル。OpenTTGames データセットで 3フェーズ学習を行ったチェックポイント一式。
元のリポジトリ: [maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch](https://github.com/maudzung/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch)
---
## チェックポイント一覧
| ファイル名 | 説明 |
|---|---|
| `ttnet_1st_phase_best.pth` | stride=1(120fps想定)Phase 1 best (epoch 17) |
| `ttnet_2nd_phase_best.pth` | stride=1 Phase 2 best (epoch 3) |
| `ttnet_3rd_phase_best.pth` | stride=1 Phase 3 best (epoch 3) ← stride=1推論用 |
| `ttnet_30fps_1st_phase_best.pth` | stride=4(30fps対応)Phase 1 best (epoch 12) |
| `ttnet_30fps_2nd_phase_best.pth` | stride=4 Phase 2 best (epoch 7) |
| `ttnet_30fps_3rd_phase_best.pth` | stride=4 Phase 3 best (epoch 1) ← stride=4推論用 |
---
## 学習環境
- GPU: NVIDIA GeForce RTX 4070 Ti Super (16GB)
- PyTorch 2.4.0 (CUDA 12.1)
- データセット: [OpenTTGames](https://lab.osai.ai/datasets/openttgames/)
- 学習: game_1〜4(約46,000フレーム)
- 検証: game_5(約7,300フレーム)
---
## 学習方法
### フレーム画像の抽出
```bash
cd TTNet/TTNet-Real-time-Analysis-System-for-Table-Tennis-Pytorch/prepare_dataset
# stride=1(120fps想定)
python extract_smooth_labellings.py
# → dataset/training/images/ に出力(約39GB)
# stride=4(30fps対応)
python extract_smooth_labellings.py --frame_stride 4
# → dataset/training/images_stride4/ に出力(約39GB)
```
### Phase 1〜3 の学習(stride=1)
```bash
bash train_ttnet.sh
# 内部で以下を順番に実行:
# Phase 1: ボール全体検出 + セグメンテーション
# Phase 2: ローカル検出 + イベント検出
# Phase 3: 全タスク fine-tune
# 重みの保存先: TTNet/checkpoints/ttnet_{1,2,3}rd_phase/
```
### Phase 1〜3 の学習(stride=4、30fps対応)
```bash
bash train_ttnet_30fps.sh
# stride=4 専用スクリプト。モデル構造は変更なし(9フレーム27ch入力)
# 重みの保存先: TTNet/checkpoints/ttnet_30fps_{1,2,3}rd_phase/
```
---
## 学習結果
### stride=1(game_1〜4 学習 / game_5 検証)
| フェーズ | 内容 | Best Val Loss | Best Epoch |
|---------|------|--------------|-----------|
| Phase 1 | Global検出 + セグメンテーション | 0.1804 | 17 / 30 |
| Phase 2 | Local検出 + イベント検出 | 0.1252 | 3 / 30 |
| Phase 3 | 全タスク Fine-tune | 0.2865 | 3 / 30 |
### stride=4(game_1〜4 学習 / game_5 検証)
| フェーズ | 内容 | Best Val Loss | Best Epoch |
|---------|------|--------------|-----------|
| Phase 1 | Global検出 + セグメンテーション | 0.2664 | 12 / 30 |
| Phase 2 | Local検出 + イベント検出 | 0.0961 | 7 / 30 |
| Phase 3 | 全タスク Fine-tune | 0.3456 | 1 / 30 |
---
## 推論方法
```bash
source venv_ttnet/bin/activate
# stride=1(120fps想定映像向け)
python run_ball_trajectory.py \
--player haruya \
--video 1 \
--pretrained_path TTNet/checkpoints/ttnet_3rd_phase/ttnet_3rd_phase_best.pth \
--gpu_idx 0
# stride=4(30fps映像向け)
python run_ball_trajectory.py \
--player haruya \
--video 1 \
--pretrained_path TTNet/checkpoints/ttnet_30fps_3rd_phase/ttnet_30fps_3rd_phase_best.pth \
--frame_stride 4 \
--run-id stride4
```
---
## TTNet のモデル構造(概要)
1. **Global Stage**: 9フレームスタック(27ch、320×128)を入力。Conv×6 + FC×3 で 448次元の1Dガウス分布を出力し、`argmax` でボール位置を粗く推定(X方向320px + Y方向128px)
2. **Local Stage**: Global推定位置を中心に元画像をクロップして再度同じネットワークに通し、精密な座標を算出
3. **最終座標**: `bx = x_global × (1920/320) + x_local - 160``by = y_global × (1080/128) + y_local - 64`
4. **Segmentation / Event Spotting**: Global中間特徴からデコードして卓球台マスクとバウンス・ネット接触を並列に検出
---
## ドメインシフトについて(重要な注意事項)
- 学習データ(OpenTTGames): 暗い競技会場・固定俯瞰カメラ・120fps
- 本研究で適用した映像: 明るい体育館・角度のある手持ちカメラ・30fps
- stride=4 による 30fps 対応後も全フレームで `x=-160, y=543` の固定値となり検出不能
- **公式映像(OpenTTGames game_5)では正しく検出できることを確認済み**(ドメインシフトが原因)
- このため、事前学習済み重みで様々な映像環境に対応できる **TrackNetV3** に移行した