""" BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model A hybrid architecture combining: - Video Swin Transformer (Stage 1-3) for spatial features - Selective State Space Model (SSM) for temporal modeling Trained via Knowledge Distillation from Video Swin-T teacher. Author: C-Team License: Apache-2.0 """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision.models.video import swin3d_t, Swin3D_T_Weights from typing import Dict, Tuple class SelectiveSSM(nn.Module): """ Selective State Space Model (Mamba-style) Key: Dynamically generates B, C, delta based on input - Important information is remembered - Less important information is quickly forgotten """ def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = 0.1): super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = d_model * expand # Input projection (expansion) self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) # 1D convolution (local context) self.conv1d = nn.Conv1d( self.d_inner, self.d_inner, kernel_size=d_conv, padding=d_conv - 1, groups=self.d_inner ) # SSM parameter generator (selective!) self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False) # A parameter (learnable diagonal matrix) self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32))) self.D = nn.Parameter(torch.ones(self.d_inner)) # Output projection self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: [B, T, D] Returns: y: [B, T, D] """ residual = x x = self.layer_norm(x) B, T, D = x.shape # Input projection -> [B, T, 2*d_inner] xz = self.in_proj(x) x, z = xz.chunk(2, dim=-1) # 1D Conv (capture local context) x = x.transpose(1, 2) x = self.conv1d(x)[:, :, :T] x = x.transpose(1, 2) x = F.silu(x) # Selective SSM parameter generation x_ssm = self.x_proj(x) B_t = x_ssm[:, :, :self.d_state] C_t = x_ssm[:, :, self.d_state:self.d_state*2] delta = F.softplus(x_ssm[:, :, -1:]) # A parameter (negative for stability) A = -torch.exp(self.A_log) # Discretization: A_bar = exp(delta * A) A_bar = torch.exp(delta * A.view(1, 1, -1)) # SSM scan h = torch.zeros(B, self.d_inner, self.d_state, device=x.device, dtype=x.dtype) outputs = [] for t in range(T): x_t = x[:, t, :] B_t_t = B_t[:, t, :] C_t_t = C_t[:, t, :] A_bar_t = A_bar[:, t, :] # h = A_bar * h + B_t * x h = h * A_bar_t.unsqueeze(1) + B_t_t.unsqueeze(1) * x_t.unsqueeze(-1) # y = C_t * h + D * x y_t = (C_t_t.unsqueeze(1) * h).sum(dim=-1) + self.D * x_t outputs.append(y_t) y = torch.stack(outputs, dim=1) # Gating y = y * F.silu(z) # Output projection y = self.out_proj(y) y = self.dropout(y) return y + residual class TemporalSSMBlock(nn.Module): """ Temporal SSM Block for video Takes [B, T, C] sequence and applies SSM layers """ def __init__(self, d_model: int, d_state: int = 16, n_layers: int = 2, dropout: float = 0.1): super().__init__() self.ssm_layers = nn.ModuleList([ SelectiveSSM(d_model, d_state=d_state, dropout=dropout) for _ in range(n_layers) ]) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: [B, T, D] sequence Returns: y: [B, D] final representation """ for ssm in self.ssm_layers: x = ssm(x) return x.mean(dim=1) class BaramNuri(nn.Module): """ BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model Architecture: 1. Video Swin-T Stages 1-3 (spatial features, 384 dim) 2. Selective SSM Block (temporal modeling) 3. Classification Head Parameters: 14.20M (49% reduction from teacher's 27.86M) Performance: 96.17% accuracy, 0.9504 Macro F1 """ CLASS_NAMES = ["정상", "졸음운전", "물건찾기", "휴대폰 사용", "운전자 폭행"] CLASS_NAMES_EN = ["normal", "drowsy_driving", "searching_object", "phone_usage", "driver_assault"] def __init__( self, num_classes: int = 5, pretrained: bool = True, d_state: int = 16, ssm_layers: int = 2, dropout: float = 0.2, ): super().__init__() self.num_classes = num_classes # Load Video Swin-T backbone (only Stage 1-3) if pretrained: print("Loading Swin backbone (Kinetics-400 pretrained)...") full_swin = swin3d_t(weights=Swin3D_T_Weights.KINETICS400_V1) else: full_swin = swin3d_t(weights=None) # Patch embedding self.patch_embed = full_swin.patch_embed # Use only Stage 1-3 (features[0:5]) for 384 dim output self.features = nn.Sequential(*[full_swin.features[i] for i in range(5)]) # Stage 3 output: 384 dim self.feature_dim = 384 # Global average pooling self.avgpool = nn.AdaptiveAvgPool3d(output_size=1) # SSM temporal modeling block self.temporal_ssm = TemporalSSMBlock( d_model=self.feature_dim, d_state=d_state, n_layers=ssm_layers, dropout=dropout, ) # Classification head self.head = nn.Sequential( nn.LayerNorm(self.feature_dim), nn.Dropout(p=dropout), nn.Linear(self.feature_dim, num_classes), ) # Initialize head self._init_head() # Delete Stage 4 parameters (memory saving) del full_swin def _init_head(self): """Initialize head weights""" for m in self.head.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) def extract_features(self, x: torch.Tensor) -> torch.Tensor: """ Extract features (for knowledge distillation) Args: x: [B, C, T, H, W] Returns: features: [B, feature_dim] """ # Patch embedding x = self.patch_embed(x) # Swin Stages x = self.features(x) B, T, H, W, C = x.shape # Spatial average -> [B, T, C] sequence x = x.mean(dim=[2, 3]) # SSM temporal modeling x = self.temporal_ssm(x) return x def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass Args: x: [B, C, T, H, W] video tensor Returns: logits: [B, num_classes] """ features = self.extract_features(x) logits = self.head(features) return logits def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Return both features and logits (for knowledge distillation) """ features = self.extract_features(x) logits = self.head(features) return logits, features def predict(self, x: torch.Tensor, return_english: bool = False) -> Dict: """ Inference prediction Args: x: [1, C, T, H, W] single video return_english: Return English class names Returns: dict with class, confidence, class_name """ self.eval() with torch.no_grad(): logits = self.forward(x) probs = F.softmax(logits, dim=-1)[0] class_idx = probs.argmax().item() class_names = self.CLASS_NAMES_EN if return_english else self.CLASS_NAMES return { "class": class_idx, "confidence": probs[class_idx].item(), "class_name": class_names[class_idx], "all_probs": { name: probs[i].item() for i, name in enumerate(class_names) } } @classmethod def from_pretrained(cls, checkpoint_path: str, device: str = 'cpu'): """ Load pretrained model from checkpoint Args: checkpoint_path: Path to .pth file device: 'cpu' or 'cuda' Returns: Loaded model in eval mode """ model = cls(num_classes=5, pretrained=True) checkpoint = torch.load(checkpoint_path, map_location=device) if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint) model = model.to(device) model.eval() return model def count_parameters(model: nn.Module) -> int: """Count total model parameters""" return sum(p.numel() for p in model.parameters()) if __name__ == "__main__": print("=" * 60) print("BaramNuri Model Test") print("=" * 60) # Create model model = BaramNuri(num_classes=5, pretrained=True) # Parameter count total_params = count_parameters(model) print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)") # Test with dummy input dummy_input = torch.randn(2, 3, 30, 224, 224) print(f"\nInput shape: {dummy_input.shape}") # Forward pass model.eval() with torch.no_grad(): output = model(dummy_input) print(f"Output shape: {output.shape}") # Single sample prediction test single_input = torch.randn(1, 3, 30, 224, 224) prediction = model.predict(single_input) print(f"\nPrediction (Korean): {prediction['class_name']} ({prediction['confidence']:.2%})") prediction_en = model.predict(single_input, return_english=True) print(f"Prediction (English): {prediction_en['class_name']} ({prediction_en['confidence']:.2%})") print("\nModel test passed!")