Spaces:
Sleeping
Sleeping
| from typing import Optional, Sequence, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
class ResPool3d(nn.Module):
    """Residual max-pooling block for 3D feature maps.

    The input is max-pooled with ``return_indices=True``. Each pooled maximum
    is written back to the input position it was taken from (every other
    position stays zero), and that sparse tensor is added to the input. The
    output therefore has the same shape as the input.

    Args:
        kernel_size: Size of the pooling window.
        stride: Stride of the pooling window; falls back to ``kernel_size``
            when ``None``.
        padding: Implicit padding forwarded to ``F.max_pool3d``.
        dilation: Spacing between window elements.
        ceil_mode: Forwarded to ``F.max_pool3d``.
    """

    __constants__ = ["kernel_size", "stride", "padding", "dilation", "ceil_mode"]
    ceil_mode: bool

    def __init__(
        self,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Optional[Union[int, Tuple[int, ...]]] = None,
        padding: Union[int, Tuple[int, ...]] = 0,
        dilation: Union[int, Tuple[int, ...]] = 1,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Return ``input_tensor`` plus its max-pooled responses scattered back.

        Args:
            input_tensor: Tensor whose last three dimensions are the pooled
                spatial dimensions (D, H, W).

        Returns:
            A tensor with the same shape as ``input_tensor``.
        """
        max_pooled, indices = F.max_pool3d(
            input_tensor,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=True,
        )

        output_shape = input_tensor.shape
        scattered = torch.zeros(
            output_shape,
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )

        # The indices returned by max_pool3d are offsets into the flattened
        # D*H*W volume of each individual (sample, channel) plane, not into
        # the whole tensor. Scatter along the flattened spatial axis so every
        # plane receives its own maxima instead of all of them landing in the
        # first plane.
        scattered = (
            scattered.flatten(start_dim=-3)
            .scatter_(
                -1,
                indices.flatten(start_dim=-3),
                max_pooled.flatten(start_dim=-3),
            )
            .view(output_shape)
        )
        return input_tensor + scattered
class FocalLoss(nn.Module):
    """Alpha-weighted focal loss computed from class logits.

    For each sample the loss is ``alpha_t * (1 - p_t) ** gamma * -log(p_t)``,
    where ``p_t`` is the softmax probability of the labelled class and
    ``alpha_t`` is ``alpha`` for samples labelled ``1`` and ``1 - alpha`` for
    every other label.

    Args:
        alpha: Weight applied to samples whose label equals ``1``.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``"mean"`` or ``"sum"``.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"sum"``.
    """

    def __init__(
        self, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "mean"
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        if self.reduction not in ["mean", "sum"]:
            raise ValueError("reduction must be either 'mean' or 'sum'")

    def forward(self, inputs: Tensor, labels: Tensor) -> Tensor:
        """Compute the reduced focal loss.

        Args:
            inputs: Logits; softmax is taken over the last dimension and the
                labelled class is gathered from it, so a ``(N, num_classes)``
                layout is expected.
            labels: Integer class indices, one per sample (1-D, length N).

        Returns:
            A scalar tensor (mean or sum over the per-sample losses).
        """
        probs = torch.softmax(inputs, dim=-1)

        # Vectorised per-sample alpha selection. This avoids a Python loop
        # with one ``.item()`` call (and, on GPU, one device sync) per label.
        is_positive = (labels == 1).unsqueeze(1)
        alpha_tensor = torch.where(
            is_positive,
            inputs.new_tensor(self.alpha),
            inputs.new_tensor(1 - self.alpha),
        )

        pt = probs.gather(dim=-1, index=labels.unsqueeze(1))
        # The small epsilon keeps log() finite when pt underflows to zero.
        loss = (alpha_tensor * ((1 - pt) ** self.gamma)) * (-torch.log(pt + 1e-12))

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        # Only reachable if ``reduction`` was reassigned after construction.
        return loss
class L2Regularization(nn.Module):
    """L2 penalty over all parameters of a model.

    Computes ``l2_lambda * sum(p ** 2 for every element of every parameter)``.

    Args:
        l2_lambda: Scale factor applied to the summed squared parameters.
    """

    def __init__(self, l2_lambda: float = 0.0001) -> None:
        super().__init__()
        self.l2_lambda = l2_lambda

    def forward(self, model: nn.Module) -> Tensor:
        """Return the scaled sum of squared parameters of ``model``.

        Args:
            model: Module whose ``parameters()`` are penalised.

        Returns:
            A scalar tensor on the device of the model's first parameter, or
            a zero scalar when the model has no parameters.
        """
        parameters = list(model.parameters())
        if not parameters:
            # A parameterless module has nothing to penalise; return zero
            # instead of letting next() raise StopIteration.
            return self.l2_lambda * torch.tensor(0.0)

        l2_reg = torch.tensor(0.0, device=parameters[0].device)
        for parameter in parameters:
            l2_reg += torch.sum(parameter**2)
        return self.l2_lambda * l2_reg
class MCLoss(nn.Module):
    """Weighted combination of three training terms.

    ``total = w1 * cross_entropy + w2 * focal_loss + w3 * l2_regularization``

    Args:
        w1: Weight of the cross-entropy term.
        w2: Weight of the focal-loss term.
        w3: Weight of the L2 regularization term.
        focal_alpha: ``alpha`` forwarded to ``FocalLoss``.
        focal_gamma: ``gamma`` forwarded to ``FocalLoss``.
        l2_lambda: ``l2_lambda`` forwarded to ``L2Regularization``.
    """

    def __init__(
        self,
        w1: float = 0.6,
        w2: float = 0.3,
        w3: float = 0.1,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        l2_lambda: float = 0.0001,
    ) -> None:
        super().__init__()
        self.w1, self.w2, self.w3 = w1, w2, w3
        self.focal = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self.l2 = L2Regularization(l2_lambda=l2_lambda)

    def forward(
        self, outputs: Tensor, labels: Tensor, mlp3d_instance: nn.Module
    ) -> Tensor:
        """Return the weighted sum of the three loss terms.

        Args:
            outputs: Logits passed to both cross entropy and the focal loss.
            labels: Target class indices.
            mlp3d_instance: Model whose parameters feed the L2 term.
        """
        weighted_ce = self.w1 * F.cross_entropy(outputs, labels)
        weighted_focal = self.w2 * self.focal(outputs, labels)
        weighted_l2 = self.w3 * self.l2(mlp3d_instance)
        return weighted_ce + weighted_focal + weighted_l2
class MLP3d(nn.Module):
    """3D-convolutional posture classifier.

    Architecture: three ``Conv3d -> BatchNorm3d -> activation`` stages with
    8, 16 and 32 output channels ("same" padding), a ``ResPool3d`` block, and
    a fully connected head whose input width is inferred from a dummy forward
    pass through the convolutional stack.

    Input shape (as noted by the original author):
        (N, C=2, D=7, H=12, W=11)

    Args:
        input_channel_num: Number of input channels.
        output_class_num: Number of output classes (width of the last layer).
        input_shape: Spatial shape ``(D, H, W)`` used to infer the flattened
            feature count.
        conv_kernel_size: Kernel size shared by all convolutions.
        pool_kernel_size: Kernel size and stride of the ``ResPool3d`` block.
        activation_name: One of ``relu``, ``silu``, ``gelu``, ``leakyrelu``
            (case-insensitive, surrounding whitespace ignored).
        fc_dims: Hidden widths of the fully connected head.

    Raises:
        ValueError: If ``activation_name`` is not supported.
    """

    def __init__(
        self,
        input_channel_num: int,
        output_class_num: int,
        input_shape: Tuple[int, int, int] = (7, 12, 11),
        conv_kernel_size: Union[int, Tuple[int, int, int]] = (3, 5, 5),
        pool_kernel_size: Union[int, Tuple[int, int, int]] = 2,
        activation_name: str = "SiLU",
        fc_dims: Sequence[int] = (7392, 1848, 256),
    ) -> None:
        super().__init__()
        self.input_shape = input_shape
        self.kernel_size = conv_kernel_size
        self.pool_kernel_size = pool_kernel_size
        self.activation_name = activation_name
        self.fc_dims = list(fc_dims)

        # Module order (and therefore state-dict keys) is: conv, bn, act for
        # each stage, followed by the residual pooling block.
        conv_modules: list[nn.Module] = []
        in_channels = input_channel_num
        for out_channels in (8, 16, 32):
            conv_modules += [
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=self.kernel_size,
                    padding="same",
                ),
                nn.BatchNorm3d(num_features=out_channels),
                self._build_activation(),
            ]
            in_channels = out_channels
        conv_modules.append(
            ResPool3d(
                kernel_size=self.pool_kernel_size,
                stride=self.pool_kernel_size,
                padding=0,
            )
        )
        self.conv_layers = nn.Sequential(*conv_modules)

        flattened_features = self._infer_flattened_features(input_channel_num)
        fc_layer_sizes = [flattened_features, *self.fc_dims, output_class_num]

        fc_modules: list[nn.Module] = []
        last_linear_idx = len(fc_layer_sizes) - 2
        for idx, (in_features, out_features) in enumerate(
            zip(fc_layer_sizes, fc_layer_sizes[1:])
        ):
            fc_modules.append(
                nn.Linear(in_features=in_features, out_features=out_features)
            )
            # No activation after the final (logit) layer.
            if idx < last_linear_idx:
                fc_modules.append(self._build_activation())
        self.fc_layers = nn.Sequential(*fc_modules)

    def _build_activation(self) -> nn.Module:
        """Instantiate the activation selected by ``self.activation_name``."""
        activation_map = {
            "relu": nn.ReLU,
            "silu": nn.SiLU,
            "gelu": nn.GELU,
            "leakyrelu": nn.LeakyReLU,
        }
        normalized_name = self.activation_name.lower().strip()
        activation_cls = activation_map.get(normalized_name)
        if activation_cls is None:
            supported = ", ".join(sorted(activation_map))
            raise ValueError(
                f"Unsupported activation '{self.activation_name}'. Supported: {supported}"
            )
        return activation_cls()

    def _infer_flattened_features(self, input_channel_num: int) -> int:
        """Return the per-sample feature count produced by ``conv_layers``.

        A zero-filled dummy batch of one sample is pushed through the
        convolutional stack. The stack is temporarily switched to eval mode
        so the BatchNorm layers do not fold this dummy batch into their
        running statistics; the previous mode is restored afterwards.
        """
        was_training = self.conv_layers.training
        self.conv_layers.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(
                    1,
                    input_channel_num,
                    self.input_shape[0],
                    self.input_shape[1],
                    self.input_shape[2],
                )
                conv_output = self.conv_layers(dummy_input)
        finally:
            self.conv_layers.train(was_training)
        return int(torch.flatten(conv_output, start_dim=1).shape[1])

    def forward(self, x: Tensor) -> Tensor:
        """Run the conv stack, flatten per sample, and apply the FC head."""
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc_layers(x)
        return x