Spaces:
Sleeping
Sleeping
"""
ST-GCN Model for Fall Detection

Spatial-Temporal Graph Convolutional Networks for skeleton-based action recognition.
Adapted for binary fall detection (Fall vs Non-Fall) and multi-class fall type classification.

References:
    - ST-GCN Paper: https://arxiv.org/abs/1801.07455
    - Official Implementation: https://github.com/yysijie/st-gcn
    - Fall Detection: Keskes & Noumeir (2021)

Input Shape: (N, C, T, V, M)
    - N: Batch size
    - C: Number of channels (3: x, y, confidence)
    - T: Temporal dimension (number of frames)
    - V: Number of vertices (17 COCO keypoints)
    - M: Number of persons (1 for single-person scenarios)

Output: Class logits for Fall/Non-Fall (binary) or BY/FY/SY/N (multi-class)
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .graph import Graph | |
class STGCNLayer(nn.Module):
    """
    Spatial-Temporal Graph Convolutional Layer.

    Applies a spatial graph convolution followed by a temporal convolution
    over the frame axis, with an optional residual connection.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dropout=0.5,
        residual=True
    ):
        """
        Initialize ST-GCN layer.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            kernel_size: Tuple (temporal_kernel_size, spatial_kernel_size)
            stride: Temporal stride for downsampling
            dropout: Dropout probability
            residual: Whether to use residual connection

        Raises:
            ValueError: If kernel_size is not a 2-tuple or the temporal
                kernel size is even.
        """
        super(STGCNLayer, self).__init__()
        # Explicit exceptions instead of `assert`: asserts are silently
        # stripped under `python -O`, which would disable this validation.
        if len(kernel_size) != 2:
            raise ValueError("Kernel size must be (temporal, spatial)")
        if kernel_size[0] % 2 != 1:
            raise ValueError("Temporal kernel size must be odd")
        # Symmetric padding on the temporal axis only; keeps T unchanged
        # when stride == 1.
        padding = ((kernel_size[0] - 1) // 2, 0)

        # Spatial graph convolution over the skeleton adjacency
        self.gcn = SpatialGraphConv(
            in_channels,
            out_channels,
            kernel_size[1]
        )

        # Temporal convolution: BN -> ReLU -> (Kt x 1) conv -> BN -> Dropout
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        # Residual connection. Proper modules (nn.Identity / Conv-BN
        # projection) are used instead of lambdas so the layer remains
        # picklable (torch.save of the whole model works) and the branch
        # shows up in repr() and state_dict.
        if not residual:
            self.residual = None  # forward() treats None as "add 0"
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = nn.Identity()
        else:
            # Shapes differ: 1x1 conv projects channels and matches the
            # temporal stride.
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)
                ),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V)
            A: Adjacency matrix (K, V, V) where K is number of partitions

        Returns:
            Output tensor (N, C', T', V)
        """
        res = 0 if self.residual is None else self.residual(x)
        x = self.gcn(x, A)
        x = self.tcn(x) + res
        return self.relu(x)
class SpatialGraphConv(nn.Module):
    """
    Spatial graph convolutional layer.

    Applies graph convolution on the skeleton graph using a partitioned
    adjacency matrix (one learned channel group per partition).
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        """
        Initialize spatial graph convolution.

        Args:
            in_channels: Number of input channels
            out_channels: Number of output channels
            kernel_size: Number of adjacency matrix partitions (1 or 3)
            bias: Whether to include bias term
        """
        super(SpatialGraphConv, self).__init__()
        self.kernel_size = kernel_size
        # A single 1x1 conv produces out_channels features for each of the
        # K partitions at once; the output is split into K groups in forward.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=1,
            bias=bias
        )

    def forward(self, x, A):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V)
            A: Adjacency matrix (K, V, V)

        Returns:
            Output tensor (N, C', T, V)

        Raises:
            ValueError: If A's partition count does not match kernel_size.
        """
        # Explicit exception instead of `assert` so validation survives
        # `python -O`.
        if A.size(0) != self.kernel_size:
            raise ValueError(
                f"Adjacency matrix size {A.size(0)} != kernel size {self.kernel_size}"
            )

        # Apply convolution
        x = self.conv(x)  # (N, C'*K, T, V)

        # Split channels for each partition
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)  # (N, K, C', T, V)

        # Aggregate neighbors per partition and sum over partitions:
        # out[n,c,t,w] = sum_k sum_v x[n,k,c,t,v] * A[k,v,w]
        x = torch.einsum('nkctv,kvw->nctw', x, A)  # (N, C', T, V)

        return x.contiguous()
class STGCN(nn.Module):
    """
    ST-GCN model for fall detection.

    Architecture:
        - Input: (N, 3, 60, 17, 1) - batch, channels, frames, joints, persons
        - ST-GCN layers: Extract spatial-temporal features
        - Global pooling: Aggregate features across time and space
        - FC layer: Classification (binary or multi-class)
    """

    def __init__(
        self,
        num_classes=2,
        in_channels=3,
        edge_importance_weighting=True,
        graph_cfg=None,
        dropout=0.5,
        **kwargs
    ):
        """
        Initialize ST-GCN model.

        Args:
            num_classes: Number of output classes (2 for binary, 4 for multi-class)
            in_channels: Number of input channels (3: x, y, confidence)
            edge_importance_weighting: Whether to learn edge importance weights
            graph_cfg: Graph configuration (default: spatial labeling)
            dropout: Dropout probability
        """
        super(STGCN, self).__init__()

        # Skeleton graph; defaults to the spatial partitioning strategy of
        # the ST-GCN paper.
        if graph_cfg is None:
            graph_cfg = {'labeling_mode': 'spatial'}
        self.graph = Graph(**graph_cfg)

        # Adjacency matrix (K, V, V) where K=3 for spatial labeling.
        # Registered as a buffer so it follows .to(device)/.cuda() and is
        # saved in state_dict without being a trainable parameter (no
        # requires_grad flag needed).
        A = torch.tensor(
            self.graph.get_adjacency_matrix(normalize=True),
            dtype=torch.float32
        )
        self.register_buffer('A', A)

        # Number of adjacency matrix partitions
        spatial_kernel_size = A.size(0)  # 3 for spatial labeling
        # Temporal kernel size (odd for symmetric padding)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        # Backbone: channels widen 64 -> 128 -> 256 while the temporal
        # dimension is halved at each stride-2 stage.
        self.st_gcn_networks = nn.ModuleList((
            STGCNLayer(in_channels, 64, kernel_size, 1, dropout, residual=False),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 128, kernel_size, 2, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 256, kernel_size, 2, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
        ))

        # Optional learned edge-importance mask, one per layer.
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # 1x1 conv acts as the classification head on pooled features.
        self.fcn = nn.Conv2d(256, num_classes, kernel_size=1)

    def _encode(self, x):
        """
        Shared backbone: run the ST-GCN layers and pool globally.

        Used by both forward() and extract_features() so the backbone logic
        lives in exactly one place.

        Args:
            x: Input tensor (N, C, T, V, M)

        Returns:
            Pooled feature tensor (N, 256, 1, 1), averaged across persons.
        """
        # Fold persons into the batch: (N, C, T, V, M) -> (N*M, C, T, V)
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # (N, M, C, T, V)
        x = x.view(N * M, C, T, V)

        # Each layer sees the shared adjacency scaled by its own importance.
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = gcn(x, self.A * importance)

        # Global average pooling over time and joints, then over persons.
        x = F.avg_pool2d(x, x.size()[2:])        # (N*M, C, 1, 1)
        x = x.view(N, M, -1, 1, 1).mean(dim=1)   # (N, C, 1, 1)
        return x

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V, M)
                - N: Batch size
                - C: Number of channels (3)
                - T: Number of frames (60)
                - V: Number of joints (17)
                - M: Number of persons (1)

        Returns:
            Output logits (N, num_classes)
        """
        x = self._encode(x)
        x = self.fcn(x)               # (N, num_classes, 1, 1)
        return x.view(x.size(0), -1)  # (N, num_classes)

    def extract_features(self, x):
        """
        Extract features before classification layer.

        Args:
            x: Input tensor (N, C, T, V, M)

        Returns:
            Feature tensor (N, 256)
        """
        return self._encode(x).view(x.size(0), -1)
def stgcn_binary(pretrained=False, **kwargs):
    """
    Build an ST-GCN for binary fall detection (Fall vs Non-Fall).

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Extra arguments forwarded to the STGCN constructor

    Returns:
        Two-class ST-GCN model

    Raises:
        NotImplementedError: If pretrained weights are requested.
    """
    net = STGCN(num_classes=2, **kwargs)
    if pretrained:
        raise NotImplementedError("Pretrained weights not available")
    return net
def stgcn_multiclass(pretrained=False, **kwargs):
    """
    Build an ST-GCN for multi-class fall detection (BY/FY/SY/N).

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Extra arguments forwarded to the STGCN constructor

    Returns:
        Four-class ST-GCN model

    Raises:
        NotImplementedError: If pretrained weights are requested.
    """
    net = STGCN(num_classes=4, **kwargs)
    if pretrained:
        raise NotImplementedError("Pretrained weights not available")
    return net
def _demo():
    """Smoke-test model construction, forward pass, and feature extraction."""
    print("Testing ST-GCN Model...")

    # Binary classification
    model_binary = stgcn_binary()
    print("\nBinary ST-GCN:")  # plain string: no placeholders, no f-string needed
    print(f" Parameters: {sum(p.numel() for p in model_binary.parameters()):,}")
    print(f" Trainable: {sum(p.numel() for p in model_binary.parameters() if p.requires_grad):,}")

    # Multi-class classification
    model_multiclass = stgcn_multiclass()
    print("\nMulti-class ST-GCN:")
    print(f" Parameters: {sum(p.numel() for p in model_multiclass.parameters()):,}")

    # Test forward pass on a dummy clip: (batch, channels, frames, joints, persons)
    batch_size = 4
    input_tensor = torch.randn(batch_size, 3, 60, 17, 1)
    print(f"\nInput shape: {input_tensor.shape}")

    # Binary output
    output_binary = model_binary(input_tensor)
    print(f"Binary output shape: {output_binary.shape}")
    print(f"Binary output: {output_binary}")

    # Multi-class output
    output_multiclass = model_multiclass(input_tensor)
    print(f"Multi-class output shape: {output_multiclass.shape}")

    # Feature extraction
    features = model_binary.extract_features(input_tensor)
    print(f"Feature shape: {features.shape}")

    print("\nST-GCN model construction successful!")


if __name__ == '__main__':
    _demo()