# CourtKeyNet / config.yaml
# Last commit 8254e28 (verified) by Cracked-ANJ: "Upload folder using huggingface_hub"
---
# Model Configuration (Manuscript Section 4)
model:
  name: "CourtKeyNet"

  # Octave Feature Extractor
  ofe:
    channels_per_band: 64  # C for the high/mid/low paths
    stem_channels: 64

  # Polar Transform Attention
  pta:
    enabled: true
    radial_bins: 64  # Fine-grained polar grid, per the manuscript
    angular_bins: 128

  # Keypoint Localization
  heatmap_sigma: 8.0  # Increased from 2.0 to help convergence
  num_keypoints: 4
  feature_dim: 128

  # Transformer Refinement
  transformer:
    num_layers: 2
    num_heads: 4
    dim_feedforward: 512
    dropout: 0.1

  # Quadrilateral Constraint Module
  qcm:
    hidden_dims: [64, 128]
    output_dim: 128
# Training Configuration
train:
  epochs: 300
  batch_size: 47
  num_workers: 4
  imgsz: 640
  early_stopping_patience: 20  # Stop if no improvement for N epochs

  # Optimizer (Manuscript: AdamW)
  optimizer: "adamw"
  # Initial learning rate; reduced further to prevent divergence with the
  # high loss weights below.
  lr0: 0.00005
  lrf: 0.01  # Final LR multiplier
  weight_decay: 0.0005
  # NOTE(review): momentum is an SGD-style setting; with AdamW the trainer
  # presumably maps it to beta1 -- confirm.
  momentum: 0.937

  # LR Scheduler
  warmup_epochs: 5
  scheduler: "cosine"

  # Loss Weights (Manuscript Eq. 47)
  # Geometric losses (edge/diagonal/angle) are disabled initially (set to 0)
  # so the model learns keypoints first.
  loss_weights:
    keypoint: 10.0  # Primary loss for coordinate accuracy
    heatmap: 20.0  # Heatmap supervision
    edge: 0.0  # Disabled initially - enable after the model converges
    diagonal: 0.0  # Disabled initially
    angle: 0.0  # Disabled initially

  # Geometric loss warmup: linearly scale the geometric losses over N epochs.
  # Only relevant when edge/diagonal/angle > 0; enabled in fine-tuning only.
  # geo_warmup_epochs: 10

  # Training Tricks
  mixed_precision: true
  grad_clip: 1.0
  ema_decay: 0.9999

  # Checkpointing
  save_interval: 5
  save_best: true

  # Paths
  project: "runs/courtkeynet"
  name: "exp"
# Validation
val:
  batch_size: 32
  interval: 1  # Validate every N epochs
# Weights & Biases Configuration
wandb:
  enabled: true
  project: "CourtKeyNet"
  # Set to your wandb username or team name, or leave null for the default.
  entity: null
  name: "CourtKeyNet"  # Run name (auto-generated if null)
  tags: ["badminton", "court-detection", "keypoint"]
  log_freq: 100  # Log every N batches
# Compute device: "cuda" for GPU, or "cpu"
device: 'cuda'