# smashfix-v1 / src/models.py — branch "v1-try-deploy", commit 0d0412d (uncertainrods)
"""
Neural Network Architecture Definitions
========================================
Defines the deep learning model architectures for badminton shot classification.
Provides two model builders for different feature input types.
Architectures:
1. build_lstm_pose(input_shape, num_classes)
- Conv1D feature extractor: 128 filters, kernel_size=4
- Stacked LSTM layers: 128 units β†’ 64 units
- Batch normalization for training stability
- Dropout regularization (0.3-0.4) to prevent overfitting
- Softmax classifier for multi-class output
2. build_tcn_hybrid(pose_shape, cnn_shape, num_classes)
- CNN Branch: Dilated causal Conv1D (TCN-style) with residual connections + self-attention + GRU
- Pose Branch: Causal Conv1D + self-attention + GRU with batch normalization
- Late fusion: Concatenation of branch outputs
- L2 regularization (7.6e-5) on all kernels
- Multi-input model for simultaneous pose+visual features
Design Rationale:
- Conv1D captures local temporal patterns in pose sequences
- LSTM/GRU models long-range temporal dependencies
- Causal convolutions ensure no future information leakage
- Dilated convolutions expand receptive field efficiently
- Late fusion allows each modality to learn independently
Input/Output Specifications:
Pose Model:
Input: (batch, sequence_length, 99) - normalized pose features
Output: (batch, num_classes) - shot type probabilities
Hybrid Model:
Inputs: [(batch, T, cnn_dim), (batch, T, 99)]
Output: (batch, num_classes) - shot type probabilities
Dependencies:
External: tensorflow, keras
Author: IPD Research Team
Version: 1.0.0
"""
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (
Input, Conv1D, MaxPooling1D, LSTM, Dense, Dropout, BatchNormalization,
GRU, SpatialDropout1D, Concatenate, ReLU
)
from tensorflow.keras.regularizers import l2
def build_lstm_pose(input_shape, num_classes):
    """Construct and compile the Conv1D + stacked-LSTM pose classifier.

    Args:
        input_shape: (sequence_length, feature_dim) shape of the pose input.
        num_classes: number of shot categories for the softmax head.

    Returns:
        A compiled Sequential model (adam optimizer, categorical
        cross-entropy loss, accuracy metric).
    """
    net = Sequential()
    # Local temporal pattern extraction over the pose sequence.
    net.add(Conv1D(128, 4, activation='relu', input_shape=input_shape))
    net.add(MaxPooling1D(pool_size=3))
    net.add(Dropout(0.3))
    # Stacked recurrent layers capture longer-range dependencies.
    net.add(LSTM(128, return_sequences=True, activation='relu'))
    net.add(Dropout(0.4))
    net.add(BatchNormalization())
    net.add(LSTM(64, activation='relu'))
    net.add(Dropout(0.3))
    # Classification head.
    net.add(Dense(64, activation='relu'))
    net.add(Dropout(0.2))
    net.add(Dense(num_classes, activation='softmax'))
    net.compile(optimizer='adam', loss='categorical_crossentropy',
                metrics=['accuracy'])
    return net
def build_tcn_hybrid(pose_shape, cnn_shape, num_classes):
    """Build and compile the two-branch TCN + self-attention hybrid classifier.

    Architecture:
        - Visual branch: 1x1 channel projection -> four residual dilated
          causal Conv1D blocks (dilations 1,2,4,8) -> multi-head
          self-attention -> GRU -> Dense.
        - Pose branch: causal Conv1D -> multi-head self-attention -> GRU
          (with batch norm) -> Dense.
        - Fusion: concatenation -> two Dense layers -> softmax output.

    Args:
        pose_shape: (timesteps, pose_dim) shape of the pose input.
        cnn_shape: (timesteps, cnn_dim) shape of the visual-feature input.
        num_classes: number of output shot categories.

    Returns:
        A compiled multi-input keras Model taking [cnn, pose] tensors and
        producing class probabilities. Hyperparameters were tuned with
        Optuna (94.6% accuracy during tuning, 92% on test).
    """
    from tensorflow.keras.layers import Add, MultiHeadAttention, LayerNormalization

    # Optuna-tuned hyperparameters.
    reg = l2(7.6e-5)
    drop = 0.22
    n_filters = 80
    n_gru = 80
    n_heads = 8
    key_dim = 32
    n_fusion = 80

    def attention_block(seq):
        # Residual multi-head self-attention + layer norm over time steps,
        # letting the network weight the most informative frames.
        attended = MultiHeadAttention(num_heads=n_heads, key_dim=key_dim,
                                      dropout=0.2)(seq, seq)
        return LayerNormalization()(Add()([seq, attended]))

    # --- Visual (CNN-feature) branch: deep TCN + attention ---
    cnn_in = Input(shape=cnn_shape, name="cnn_input")
    vis = Conv1D(n_filters, 1, kernel_regularizer=reg)(cnn_in)  # channel projection
    vis = ReLU()(BatchNormalization()(vis))
    for d in (1, 2, 4, 8):  # stacked dilations widen the receptive field
        skip = vis
        vis = Conv1D(n_filters, 3, padding="causal", dilation_rate=d,
                     kernel_regularizer=reg)(vis)
        vis = ReLU()(BatchNormalization()(vis))
        vis = SpatialDropout1D(drop)(vis)
        vis = Add()([vis, skip])  # residual shortcut for gradient flow
    vis = attention_block(vis)
    vis = GRU(n_gru, dropout=drop)(vis)
    vis = Dense(max(n_gru // 2, 32), activation="relu",
                kernel_regularizer=reg)(vis)
    vis = Dropout(drop)(vis)

    # --- Pose branch: causal Conv1D + attention + GRU ---
    pose_in = Input(shape=pose_shape, name="pose_input")
    pos = Conv1D(n_filters, 3, padding="causal", activation="relu",
                 kernel_regularizer=reg)(pose_in)
    pos = SpatialDropout1D(drop)(BatchNormalization()(pos))
    pos = attention_block(pos)
    pos = GRU(n_gru, dropout=drop)(pos)
    pos = BatchNormalization()(pos)
    pos = Dense(max(n_gru // 2, 32), activation="relu",
                kernel_regularizer=reg)(pos)
    pos = Dropout(drop)(pos)

    # --- Late fusion and classifier head ---
    merged = Concatenate()([vis, pos])
    merged = Dense(n_fusion, activation="relu", kernel_regularizer=reg)(merged)
    merged = Dropout(min(drop + 0.1, 0.5))(merged)
    merged = Dense(n_fusion // 2, activation="relu",
                   kernel_regularizer=reg)(merged)
    merged = Dropout(drop)(merged)
    probs = Dense(num_classes, activation="softmax")(merged)

    model = Model([cnn_in, pose_in], probs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5.6e-4),
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
        metrics=["accuracy"],
    )
    return model