Baramnuri / model.py
koreashin's picture
Upload 6 files
5a64d6e verified
"""
BaramNuri (바람누리) - Lightweight Driver Behavior Detection Model
A hybrid architecture combining:
- Video Swin Transformer (Stage 1-3) for spatial features
- Selective State Space Model (SSM) for temporal modeling
Trained via Knowledge Distillation from Video Swin-T teacher.
Author: C-Team
License: Apache-2.0
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import swin3d_t, Swin3D_T_Weights
from typing import Dict, Tuple
class SelectiveSSM(nn.Module):
    """
    Selective state-space layer (Mamba-style) wrapped in a pre-norm residual.

    The input is layer-normalised, projected to two ``d_inner``-wide streams
    (a signal stream and a gate stream), passed through a depthwise causal
    1D convolution, and then through a recurrence whose input matrix ``B``,
    readout matrix ``C`` and step size ``delta`` are all computed from the
    signal itself at every time step.
    """

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Width of the incoming/outgoing feature vectors.
            d_state: Size of the hidden state kept per inner channel.
            d_conv: Kernel size of the depthwise temporal convolution.
            expand: Multiplier giving the inner width (``d_model * expand``).
            dropout: Dropout probability applied after the output projection.
        """
        super().__init__()
        inner = d_model * expand

        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = inner

        # One projection produces both the signal stream and the gate stream.
        self.in_proj = nn.Linear(d_model, 2 * inner, bias=False)

        # Depthwise conv; padding of (kernel - 1) plus the crop in forward()
        # means each output step only sees current and earlier steps.
        self.conv1d = nn.Conv1d(
            in_channels=inner,
            out_channels=inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=inner,
        )

        # Emits, per time step: B (d_state), C (d_state) and one delta value.
        self.x_proj = nn.Linear(inner, 2 * d_state + 1, bias=False)

        # log(1..d_state); forward() uses -exp(A_log) as the diagonal of A.
        self.A_log = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float32).log())
        # Per-channel weight of the direct input-to-output skip term.
        self.D = nn.Parameter(torch.ones(inner))

        self.out_proj = nn.Linear(inner, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D]
        Returns:
            y: [B, T, D] (layer output added to the unnormalised input)
        """
        skip = x
        normed = self.layer_norm(x)
        batch, steps, _ = normed.shape

        # [B, T, 2 * d_inner] -> signal u and gate, each [B, T, d_inner]
        u, gate = torch.chunk(self.in_proj(normed), 2, dim=-1)

        # Conv1d wants channels-first; crop the right-hand padding back to T.
        u = self.conv1d(u.transpose(1, 2))[..., :steps].transpose(1, 2)
        u = F.silu(u)

        # Input-dependent SSM parameters.
        b_seq, c_seq, dt = torch.split(
            self.x_proj(u), [self.d_state, self.d_state, 1], dim=-1
        )
        dt = F.softplus(dt)

        # A is kept negative so the per-step decay exp(delta * A) lies in (0, 1].
        a_diag = -torch.exp(self.A_log)
        decay = torch.exp(dt * a_diag.view(1, 1, -1))

        # Sequential scan over time with state of shape [B, d_inner, d_state].
        state = u.new_zeros(batch, self.d_inner, self.d_state)
        per_step = []
        for u_t, b_t, c_t, a_t in zip(
            u.unbind(1), b_seq.unbind(1), c_seq.unbind(1), decay.unbind(1)
        ):
            # h <- A_bar * h + B_t * x_t
            state = state * a_t.unsqueeze(1) + b_t.unsqueeze(1) * u_t.unsqueeze(-1)
            # y_t = <C_t, h> + D * x_t
            per_step.append((c_t.unsqueeze(1) * state).sum(dim=-1) + self.D * u_t)
        y = torch.stack(per_step, dim=1)

        # Gate, project back to d_model, regularise, add the residual.
        y = self.out_proj(y * F.silu(gate))
        return self.dropout(y) + skip
class TemporalSSMBlock(nn.Module):
    """
    Stack of SelectiveSSM layers followed by temporal mean pooling.

    Consumes a [B, T, D] sequence and collapses it to a single [B, D] vector.
    """

    def __init__(self, d_model: int, d_state: int = 16, n_layers: int = 2, dropout: float = 0.1):
        """
        Args:
            d_model: Feature width of the sequence.
            d_state: State size handed to every SelectiveSSM layer.
            n_layers: Number of SelectiveSSM layers applied in series.
            dropout: Dropout probability handed to every SelectiveSSM layer.
        """
        super().__init__()
        stack = []
        for _ in range(n_layers):
            stack.append(SelectiveSSM(d_model, d_state=d_state, dropout=dropout))
        self.ssm_layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, D] sequence
        Returns:
            y: [B, D] representation (mean over the T axis after all layers)
        """
        hidden = x
        for layer in self.ssm_layers:
            hidden = layer(hidden)
        # Average over time to get one vector per sample.
        return hidden.mean(dim=1)
class BaramNuri(nn.Module):
    """
    BaramNuri - Lightweight Driver Behavior Detection Model

    Architecture:
    1. Video Swin-T patch embedding + first five ``features`` entries
       (Stages 1-3, 384-dim output)
    2. Selective SSM block over the time axis (temporal modeling)
    3. Classification head (LayerNorm -> Dropout -> Linear)

    Reported by the authors:
    Parameters: 14.20M (49% reduction from teacher's 27.86M)
    Performance: 96.17% accuracy, 0.9504 Macro F1
    """

    # Human-readable labels, indexed by class id (Korean and English).
    CLASS_NAMES = ["정상", "졸음운전", "물건찾기", "휴대폰 사용", "운전자 폭행"]
    CLASS_NAMES_EN = ["normal", "drowsy_driving", "searching_object", "phone_usage", "driver_assault"]

    def __init__(
        self,
        num_classes: int = 5,
        pretrained: bool = True,
        d_state: int = 16,
        ssm_layers: int = 2,
        dropout: float = 0.2,
    ):
        """
        Args:
            num_classes: Number of output logits.
            pretrained: If True, initialise the Swin backbone from the
                Kinetics-400 weights (may trigger a download).
            d_state: State size of each SelectiveSSM layer.
            ssm_layers: Number of SelectiveSSM layers in the temporal block.
            dropout: Dropout probability for the SSM layers and the head.
        """
        super().__init__()
        self.num_classes = num_classes

        # Load Video Swin-T backbone (only Stage 1-3 are kept below)
        if pretrained:
            print("Loading Swin backbone (Kinetics-400 pretrained)...")
            full_swin = swin3d_t(weights=Swin3D_T_Weights.KINETICS400_V1)
        else:
            full_swin = swin3d_t(weights=None)

        # Patch embedding
        self.patch_embed = full_swin.patch_embed

        # Use only Stage 1-3 (features[0:5]) for 384 dim output
        self.features = nn.Sequential(*[full_swin.features[i] for i in range(5)])

        # Stage 3 output: 384 dim
        self.feature_dim = 384

        # Global average pooling. Not called by any method of this class
        # (extract_features averages with Tensor.mean); kept so the public
        # attribute remains available. It holds no parameters.
        self.avgpool = nn.AdaptiveAvgPool3d(output_size=1)

        # SSM temporal modeling block
        self.temporal_ssm = TemporalSSMBlock(
            d_model=self.feature_dim,
            d_state=d_state,
            n_layers=ssm_layers,
            dropout=dropout,
        )

        # Classification head
        self.head = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(p=dropout),
            nn.Linear(self.feature_dim, num_classes),
        )
        self._init_head()

        # Drop the local reference to the full backbone so the unused
        # Stage 4 parameters can be freed (memory saving).
        del full_swin

    def _init_head(self):
        """Initialize head Linear layers: truncated-normal weights, zero bias."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features (for knowledge distillation)

        Args:
            x: [B, C, T, H, W]
        Returns:
            features: [B, feature_dim]
        Raises:
            ValueError: If the backbone output is not 5-dimensional.
        """
        # Patch embedding
        x = self.patch_embed(x)
        # Swin Stages
        x = self.features(x)

        # The backbone output is treated as [B, T, H, W, C]; fail loudly on
        # anything else instead of averaging over the wrong axes.
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D backbone output [B, T, H, W, C], got shape {tuple(x.shape)}"
            )

        # Spatial average -> [B, T, C] sequence
        x = x.mean(dim=[2, 3])
        # SSM temporal modeling
        x = self.temporal_ssm(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            logits: [B, num_classes]
        """
        features = self.extract_features(x)
        logits = self.head(features)
        return logits

    def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return both logits and features (for knowledge distillation)

        Args:
            x: [B, C, T, H, W] video tensor
        Returns:
            (logits, features): the head output and the pooled features
            that were fed to the head.
        """
        features = self.extract_features(x)
        logits = self.head(features)
        return logits, features

    def predict(self, x: torch.Tensor, return_english: bool = False) -> Dict:
        """
        Inference prediction

        Switches the module to eval mode (and leaves it there), runs a
        no-grad forward pass and reports the result for the first sample
        in the batch.

        Args:
            x: [1, C, T, H, W] single video
            return_english: Return English class names
        Returns:
            dict with ``class`` (index), ``confidence``, ``class_name`` and
            ``all_probs`` (name -> probability for every output class)
        """
        self.eval()
        with torch.no_grad():
            logits = self.forward(x)
            probs = F.softmax(logits, dim=-1)[0]
            class_idx = probs.argmax().item()

        class_names = self.CLASS_NAMES_EN if return_english else self.CLASS_NAMES

        def label_for(index: int) -> str:
            # A head built with num_classes larger than the built-in label
            # list has no name for the extra classes; use a generic label
            # rather than raising IndexError.
            if index < len(class_names):
                return class_names[index]
            return f"class_{index}"

        # Iterate over the probabilities actually produced, so heads with
        # fewer or more than five outputs are both handled.
        all_probs = {
            label_for(i): probs[i].item()
            for i in range(probs.shape[0])
        }

        return {
            "class": class_idx,
            "confidence": probs[class_idx].item(),
            "class_name": label_for(class_idx),
            "all_probs": all_probs,
        }

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, device: str = 'cpu'):
        """
        Load pretrained model from checkpoint

        Args:
            checkpoint_path: Path to .pth file
            device: 'cpu' or 'cuda'
        Returns:
            Loaded model in eval mode
        """
        # Read the checkpoint first so a bad path fails before the model
        # is built.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source, or pass weights_only=True if the installed
        # PyTorch version and the checkpoint contents allow it.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # The strict load below replaces every weight, so fetching the
        # Kinetics-400 backbone weights first would be a wasted download
        # (and would break offline use).
        model = cls(num_classes=5, pretrained=False)
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()
        return model
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
if __name__ == "__main__":
    # Smoke test: build the model, report its size, and run a forward pass
    # plus a single-clip prediction on random input.
    rule = "=" * 60
    print(rule)
    print("BaramNuri Model Test")
    print(rule)

    # Build the model (pretrained=True loads the Kinetics-400 backbone weights).
    model = BaramNuri(num_classes=5, pretrained=True)

    # Report the parameter count.
    total_params = count_parameters(model)
    print(f"\nTotal parameters: {total_params:,} ({total_params/1e6:.2f}M)")

    # Random clips laid out as [C, T, H, W] = [3, 30, 224, 224].
    clip_shape = (3, 30, 224, 224)

    # Batched forward pass in eval mode without gradient tracking.
    dummy_input = torch.randn(2, *clip_shape)
    print(f"\nInput shape: {dummy_input.shape}")
    model.eval()
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Output shape: {output.shape}")

    # Single-sample prediction, reported with Korean and then English labels.
    single_input = torch.randn(1, *clip_shape)
    prediction = model.predict(single_input)
    print(f"\nPrediction (Korean): {prediction['class_name']} ({prediction['confidence']:.2%})")
    prediction_en = model.predict(single_input, return_english=True)
    print(f"Prediction (English): {prediction_en['class_name']} ({prediction_en['confidence']:.2%})")

    print("\nModel test passed!")