""" ST-GCN Model for Fall Detection Spatial-Temporal Graph Convolutional Networks for skeleton-based action recognition. Adapted for binary fall detection (Fall vs Non-Fall) and multi-class fall type classification. References: - ST-GCN Paper: https://arxiv.org/abs/1801.07455 - Official Implementation: https://github.com/yysijie/st-gcn - Fall Detection: Keskes & Noumeir (2021) Input Shape: (N, C, T, V, M) - N: Batch size - C: Number of channels (3: x, y, confidence) - T: Temporal dimension (number of frames) - V: Number of vertices (17 COCO keypoints) - M: Number of persons (1 for single-person scenarios) Output: Class logits for Fall/Non-Fall (binary) or BY/FY/SY/N (multi-class) """ import torch import torch.nn as nn import torch.nn.functional as F from .graph import Graph class STGCNLayer(nn.Module): """ Spatial-Temporal Graph Convolutional Layer. Combines spatial graph convolution and temporal convolution. """ def __init__( self, in_channels, out_channels, kernel_size, stride=1, dropout=0.5, residual=True ): """ Initialize ST-GCN layer. Args: in_channels: Number of input channels out_channels: Number of output channels kernel_size: Tuple (temporal_kernel_size, spatial_kernel_size) stride: Temporal stride for downsampling dropout: Dropout probability residual: Whether to use residual connection """ super(STGCNLayer, self).__init__() assert len(kernel_size) == 2, "Kernel size must be (temporal, spatial)" assert kernel_size[0] % 2 == 1, "Temporal kernel size must be odd" padding = ((kernel_size[0] - 1) // 2, 0) # Temporal padding only # Spatial graph convolution self.gcn = SpatialGraphConv( in_channels, out_channels, kernel_size[1] ) # Temporal convolution self.tcn = nn.Sequential( nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d( out_channels, out_channels, (kernel_size[0], 1), (stride, 1), padding, ), nn.BatchNorm2d(out_channels), nn.Dropout(dropout, inplace=True), ) # Residual connection if not residual: self.residual = lambda x: 0 elif (in_channels == out_channels) and (stride == 1): self.residual = lambda x: x else: self.residual = nn.Sequential( nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=(stride, 1) ), nn.BatchNorm2d(out_channels), ) self.relu = nn.ReLU(inplace=True) def forward(self, x, A): """ Forward pass. Args: x: Input tensor (N, C, T, V) A: Adjacency matrix (K, V, V) where K is number of partitions Returns: Output tensor (N, C', T', V) """ res = self.residual(x) x = self.gcn(x, A) x = self.tcn(x) + res return self.relu(x) class SpatialGraphConv(nn.Module): """ Spatial graph convolutional layer. Applies graph convolution on skeleton graph using adjacency matrix. """ def __init__(self, in_channels, out_channels, kernel_size, bias=True): """ Initialize spatial graph convolution. Args: in_channels: Number of input channels out_channels: Number of output channels kernel_size: Number of adjacency matrix partitions (1 or 3) bias: Whether to include bias term """ super(SpatialGraphConv, self).__init__() self.kernel_size = kernel_size # Convolutional weights for each partition self.conv = nn.Conv2d( in_channels, out_channels * kernel_size, kernel_size=1, bias=bias ) def forward(self, x, A): """ Forward pass. Args: x: Input tensor (N, C, T, V) A: Adjacency matrix (K, V, V) Returns: Output tensor (N, C', T, V) """ assert A.size(0) == self.kernel_size, \ f"Adjacency matrix size {A.size(0)} != kernel size {self.kernel_size}" # Apply convolution x = self.conv(x) # (N, C'*K, T, V) # Split channels for each partition n, kc, t, v = x.size() x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v) # (N, K, C', T, V) # Apply graph convolution with each partition # A: (K, V, V) # x: (N, K, C', T, V) x = torch.einsum('nkctv,kvw->nctw', x, A) # (N, C', T, V) return x.contiguous() class STGCN(nn.Module): """ ST-GCN model for fall detection. Architecture: - Input: (N, 3, 60, 17, 1) - batch, channels, frames, joints, persons - ST-GCN layers: Extract spatial-temporal features - Global pooling: Aggregate features across time and space - FC layers: Classification (binary or multi-class) """ def __init__( self, num_classes=2, in_channels=3, edge_importance_weighting=True, graph_cfg=None, dropout=0.5, **kwargs ): """ Initialize ST-GCN model. Args: num_classes: Number of output classes (2 for binary, 4 for multi-class) in_channels: Number of input channels (3: x, y, confidence) edge_importance_weighting: Whether to learn edge importance weights graph_cfg: Graph configuration (default: spatial labeling) dropout: Dropout probability """ super(STGCN, self).__init__() # Load graph if graph_cfg is None: graph_cfg = {'labeling_mode': 'spatial'} self.graph = Graph(**graph_cfg) # Get adjacency matrix (K, V, V) where K=3 for spatial labeling A = torch.tensor( self.graph.get_adjacency_matrix(normalize=True), dtype=torch.float32, requires_grad=False ) self.register_buffer('A', A) # Number of adjacency matrix partitions spatial_kernel_size = A.size(0) # 3 for spatial labeling # Temporal kernel size (odd numbers for symmetric padding) temporal_kernel_size = 9 # Build ST-GCN layers kernel_size = (temporal_kernel_size, spatial_kernel_size) # Layer configurations: (in_channels, out_channels, stride) self.st_gcn_networks = nn.ModuleList(( STGCNLayer(in_channels, 64, kernel_size, 1, dropout, residual=False), STGCNLayer(64, 64, kernel_size, 1, dropout), STGCNLayer(64, 64, kernel_size, 1, dropout), STGCNLayer(64, 64, kernel_size, 1, dropout), STGCNLayer(64, 128, kernel_size, 2, dropout), STGCNLayer(128, 128, kernel_size, 1, dropout), STGCNLayer(128, 128, kernel_size, 1, dropout), STGCNLayer(128, 256, kernel_size, 2, dropout), STGCNLayer(256, 256, kernel_size, 1, dropout), STGCNLayer(256, 256, kernel_size, 1, dropout), )) # Edge importance weighting if edge_importance_weighting: self.edge_importance = nn.ParameterList([ nn.Parameter(torch.ones(self.A.size())) for _ in self.st_gcn_networks ]) else: self.edge_importance = [1] * len(self.st_gcn_networks) # Fully connected layer for classification self.fcn = nn.Conv2d(256, num_classes, kernel_size=1) def forward(self, x): """ Forward pass. Args: x: Input tensor (N, C, T, V, M) - N: Batch size - C: Number of channels (3) - T: Number of frames (60) - V: Number of joints (17) - M: Number of persons (1) Returns: Output logits (N, num_classes) """ # Reshape input: (N, C, T, V, M) -> (N*M, C, T, V) N, C, T, V, M = x.size() x = x.permute(0, 4, 1, 2, 3).contiguous() # (N, M, C, T, V) x = x.view(N * M, C, T, V) # (N*M, C, T, V) # Forward through ST-GCN layers for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): x = gcn(x, self.A * importance) # Global pooling: (N*M, C, T, V) -> (N*M, C) x = F.avg_pool2d(x, x.size()[2:]) # (N*M, C, 1, 1) x = x.view(N, M, -1, 1, 1).mean(dim=1) # Average across persons: (N, C, 1, 1) # Classification x = self.fcn(x) # (N, num_classes, 1, 1) x = x.view(x.size(0), -1) # (N, num_classes) return x def extract_features(self, x): """ Extract features before classification layer. Args: x: Input tensor (N, C, T, V, M) Returns: Feature tensor (N, 256) """ # Reshape input N, C, T, V, M = x.size() x = x.permute(0, 4, 1, 2, 3).contiguous() x = x.view(N * M, C, T, V) # Forward through ST-GCN layers for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): x = gcn(x, self.A * importance) # Global pooling x = F.avg_pool2d(x, x.size()[2:]) x = x.view(N, M, -1).mean(dim=1) # (N, 256) return x def stgcn_binary(pretrained=False, **kwargs): """ ST-GCN for binary fall detection (Fall vs Non-Fall). Args: pretrained: Whether to load pretrained weights (not implemented) **kwargs: Additional model arguments Returns: ST-GCN model """ model = STGCN(num_classes=2, **kwargs) if pretrained: raise NotImplementedError("Pretrained weights not available") return model def stgcn_multiclass(pretrained=False, **kwargs): """ ST-GCN for multi-class fall detection (BY/FY/SY/N). Args: pretrained: Whether to load pretrained weights (not implemented) **kwargs: Additional model arguments Returns: ST-GCN model """ model = STGCN(num_classes=4, **kwargs) if pretrained: raise NotImplementedError("Pretrained weights not available") return model if __name__ == '__main__': # Test model construction print("Testing ST-GCN Model...") # Binary classification model_binary = stgcn_binary() print(f"\nBinary ST-GCN:") print(f" Parameters: {sum(p.numel() for p in model_binary.parameters()):,}") print(f" Trainable: {sum(p.numel() for p in model_binary.parameters() if p.requires_grad):,}") # Multi-class classification model_multiclass = stgcn_multiclass() print(f"\nMulti-class ST-GCN:") print(f" Parameters: {sum(p.numel() for p in model_multiclass.parameters()):,}") # Test forward pass batch_size = 4 input_tensor = torch.randn(batch_size, 3, 60, 17, 1) print(f"\nInput shape: {input_tensor.shape}") # Binary output output_binary = model_binary(input_tensor) print(f"Binary output shape: {output_binary.shape}") print(f"Binary output: {output_binary}") # Multi-class output output_multiclass = model_multiclass(input_tensor) print(f"Multi-class output shape: {output_multiclass.shape}") # Feature extraction features = model_binary.extract_features(input_tensor) print(f"Feature shape: {features.shape}") print("\nST-GCN model construction successful!")