# Author: YoungjaeDev
# fix: resolve HF Spaces import error - switched to a self-contained structure
# commit: 8133f1d
"""
ST-GCN Model for Fall Detection
Spatial-Temporal Graph Convolutional Networks for skeleton-based action recognition.
Adapted for binary fall detection (Fall vs Non-Fall) and multi-class fall type classification.
References:
- ST-GCN Paper: https://arxiv.org/abs/1801.07455
- Official Implementation: https://github.com/yysijie/st-gcn
- Fall Detection: Keskes & Noumeir (2021)
Input Shape: (N, C, T, V, M)
- N: Batch size
- C: Number of channels (3: x, y, confidence)
- T: Temporal dimension (number of frames)
- V: Number of vertices (17 COCO keypoints)
- M: Number of persons (1 for single-person scenarios)
Output: Class logits for Fall/Non-Fall (binary) or BY/FY/SY/N (multi-class)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .graph import Graph
class STGCNLayer(nn.Module):
    """
    One spatial-temporal graph convolution block.

    A spatial graph convolution (``SpatialGraphConv``) is followed by a
    temporal 2D convolution over the frame axis (BatchNorm -> ReLU -> Conv ->
    BatchNorm -> Dropout), plus an optional residual shortcut and a final ReLU.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dropout=0.5,
        residual=True
    ):
        """
        Build the ST-GCN layer.

        Args:
            in_channels: Channels of the incoming feature map
            out_channels: Channels produced by this block
            kernel_size: Pair (temporal_kernel_size, spatial_kernel_size)
            stride: Stride along the temporal axis (downsampling factor)
            dropout: Dropout probability applied after the temporal conv
            residual: If False, the shortcut contributes nothing
        """
        super(STGCNLayer, self).__init__()
        assert len(kernel_size) == 2, "Kernel size must be (temporal, spatial)"
        assert kernel_size[0] % 2 == 1, "Temporal kernel size must be odd"

        t_kernel, s_kernel = kernel_size
        # Symmetric padding on the frame axis only; joints keep full width.
        t_pad = (t_kernel - 1) // 2

        # Spatial graph convolution over the skeleton adjacency partitions.
        self.gcn = SpatialGraphConv(in_channels, out_channels, s_kernel)

        # Temporal convolution stack.
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (t_kernel, 1),
                (stride, 1),
                (t_pad, 0),
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        # Shortcut: zero when disabled, identity when shapes match,
        # otherwise a 1x1 projection with matching temporal stride.
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels and stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)
                ),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V)
            A: Adjacency matrix (K, V, V), K = number of partitions

        Returns:
            Output tensor (N, C', T', V)
        """
        shortcut = self.residual(x)
        out = self.gcn(x, A)
        out = self.tcn(out) + shortcut
        return self.relu(out)
class SpatialGraphConv(nn.Module):
    """
    Graph convolution over the skeleton's spatial dimension.

    A single 1x1 convolution produces ``kernel_size`` groups of output
    channels at once; each group is then aggregated with its own adjacency
    partition via an einsum contraction over the joint axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        """
        Build the spatial graph convolution.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            kernel_size: Number of adjacency matrix partitions (1 or 3)
            bias: Whether to include bias term
        """
        super(SpatialGraphConv, self).__init__()
        self.kernel_size = kernel_size
        # One conv yields all K partition-specific feature maps in one shot.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=1,
            bias=bias
        )

    def forward(self, x, A):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V)
            A: Adjacency matrix (K, V, V)

        Returns:
            Output tensor (N, C', T, V)
        """
        assert A.size(0) == self.kernel_size, \
            f"Adjacency matrix size {A.size(0)} != kernel size {self.kernel_size}"

        feats = self.conv(x)                                    # (N, C'*K, T, V)
        n, kc, t, v = feats.size()
        # Separate the partition axis from the channel axis.
        feats = feats.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        # Aggregate neighbors per partition and sum over partitions.
        out = torch.einsum('nkctv,kvw->nctw', feats, A)         # (N, C', T, V)
        return out.contiguous()
class STGCN(nn.Module):
    """
    ST-GCN model for fall detection.

    Architecture:
    - Input: (N, 3, 60, 17, 1) - batch, channels, frames, joints, persons
    - ST-GCN layers: Extract spatial-temporal features
    - Global pooling: Aggregate features across time, space, and persons
    - 1x1 conv head: Classification (binary or multi-class)
    """

    def __init__(
        self,
        num_classes=2,
        in_channels=3,
        edge_importance_weighting=True,
        graph_cfg=None,
        dropout=0.5,
        **kwargs
    ):
        """
        Initialize ST-GCN model.

        Args:
            num_classes: Number of output classes (2 for binary, 4 for multi-class)
            in_channels: Number of input channels (3: x, y, confidence)
            edge_importance_weighting: Whether to learn per-layer edge importance weights
            graph_cfg: Graph configuration passed to Graph (default: spatial labeling)
            dropout: Dropout probability for the ST-GCN layers
            **kwargs: Ignored; accepted for factory-function compatibility
        """
        super(STGCN, self).__init__()

        # Build the skeleton graph and keep its (K, V, V) adjacency stack as a
        # non-trainable buffer so it follows the module across devices.
        if graph_cfg is None:
            graph_cfg = {'labeling_mode': 'spatial'}
        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.get_adjacency_matrix(normalize=True),
            dtype=torch.float32,
            requires_grad=False
        )
        self.register_buffer('A', A)

        spatial_kernel_size = A.size(0)  # number of partitions (3 for spatial labeling)
        temporal_kernel_size = 9         # odd, so temporal padding is symmetric
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        # Backbone: 10 ST-GCN layers widening 64 -> 128 -> 256, with temporal
        # stride-2 downsampling at each width change.
        self.st_gcn_networks = nn.ModuleList((
            STGCNLayer(in_channels, 64, kernel_size, 1, dropout, residual=False),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 128, kernel_size, 2, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 256, kernel_size, 2, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
        ))

        # Optional learnable mask multiplied element-wise with A per layer.
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # 1x1 convolution acting as the final fully connected classifier.
        self.fcn = nn.Conv2d(256, num_classes, kernel_size=1)

    def _pooled_features(self, x):
        """
        Shared backbone pass: ST-GCN layers followed by global pooling.

        Used by both forward() and extract_features() so the two paths
        cannot drift apart.

        Args:
            x: Input tensor (N, C, T, V, M)

        Returns:
            Feature tensor (N, 256), averaged over time, joints, and persons
        """
        N, C, T, V, M = x.size()
        # Fold persons into the batch: (N, C, T, V, M) -> (N*M, C, T, V)
        x = x.permute(0, 4, 1, 2, 3).contiguous().view(N * M, C, T, V)

        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = gcn(x, self.A * importance)

        # Global average pooling over time and joints: (N*M, C, T', V) -> (N*M, C, 1, 1)
        x = F.avg_pool2d(x, x.size()[2:])
        # Average across persons: (N, M, C) -> (N, C)
        return x.view(N, M, -1).mean(dim=1)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V, M)
                - N: Batch size
                - C: Number of channels (3)
                - T: Number of frames (60)
                - V: Number of joints (17)
                - M: Number of persons (1)

        Returns:
            Output logits (N, num_classes)
        """
        x = self._pooled_features(x)                   # (N, 256)
        # The 1x1 conv head expects (N, C, 1, 1); restore the spatial dims.
        x = self.fcn(x.unsqueeze(-1).unsqueeze(-1))    # (N, num_classes, 1, 1)
        return x.view(x.size(0), -1)                   # (N, num_classes)

    def extract_features(self, x):
        """
        Extract features before the classification layer.

        Args:
            x: Input tensor (N, C, T, V, M)

        Returns:
            Feature tensor (N, 256)
        """
        return self._pooled_features(x)
def stgcn_binary(pretrained=False, **kwargs):
    """
    Factory for the binary (Fall vs Non-Fall) ST-GCN variant.

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Forwarded to the STGCN constructor

    Returns:
        ST-GCN model with a 2-way classification head

    Raises:
        NotImplementedError: If pretrained weights are requested
    """
    model = STGCN(num_classes=2, **kwargs)
    if pretrained:
        raise NotImplementedError("Pretrained weights not available")
    return model
def stgcn_multiclass(pretrained=False, **kwargs):
    """
    Factory for the multi-class (BY/FY/SY/N) ST-GCN variant.

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Forwarded to the STGCN constructor

    Returns:
        ST-GCN model with a 4-way classification head

    Raises:
        NotImplementedError: If pretrained weights are requested
    """
    model = STGCN(num_classes=4, **kwargs)
    if pretrained:
        raise NotImplementedError("Pretrained weights not available")
    return model
if __name__ == '__main__':
    # Smoke test: build both variants, run forward passes, extract features.
    print("Testing ST-GCN Model...")

    def count_params(model, trainable_only=False):
        # Total parameter count, optionally restricted to trainable tensors.
        if trainable_only:
            return sum(p.numel() for p in model.parameters() if p.requires_grad)
        return sum(p.numel() for p in model.parameters())

    # Binary classification variant
    binary_model = stgcn_binary()
    print(f"\nBinary ST-GCN:")
    print(f" Parameters: {count_params(binary_model):,}")
    print(f" Trainable: {count_params(binary_model, trainable_only=True):,}")

    # Multi-class classification variant
    multiclass_model = stgcn_multiclass()
    print(f"\nMulti-class ST-GCN:")
    print(f" Parameters: {count_params(multiclass_model):,}")

    # Forward pass on a dummy batch of 4 clips (3 channels, 60 frames,
    # 17 joints, 1 person)
    demo_input = torch.randn(4, 3, 60, 17, 1)
    print(f"\nInput shape: {demo_input.shape}")

    binary_out = binary_model(demo_input)
    print(f"Binary output shape: {binary_out.shape}")
    print(f"Binary output: {binary_out}")

    multiclass_out = multiclass_model(demo_input)
    print(f"Multi-class output shape: {multiclass_out.shape}")

    # Feature extraction path
    feats = binary_model.extract_features(demo_input)
    print(f"Feature shape: {feats.shape}")

    print("\nST-GCN model construction successful!")