| """ |
| Vehicle Warning Lights Classifier - Model Architecture |
| ResNet50 with custom classification head for 68 classes |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision.models as models |
|
|
|
|
class WarningLightsResNet50(nn.Module):
    """
    ResNet50-based classifier for vehicle warning light types (68 by default).

    Architecture:
        Backbone: torchvision ResNet50 (ImageNet IMAGENET1K_V1 weights only
            when ``pretrained=True``; randomly initialised otherwise)
        Classifier head (replaces ``resnet.fc``):
            - Dropout(dropout_rate)        # 0.3 by default
            - Linear(2048 -> 512)
            - ReLU
            - BatchNorm1d(512)
            - Dropout(dropout_rate / 2)    # 0.15 by default
            - Linear(512 -> num_classes)

    Input Shape:  (batch, 3, 224, 224)
    Output Shape: (batch, num_classes) raw logits

    Parameters: 24,593,028 total with the default num_classes=68
    """

    def __init__(self, num_classes=68, dropout_rate=0.3, pretrained=False):
        """
        Build the ResNet50 backbone and swap its final layer for the custom head.

        Args:
            num_classes (int): Number of output classes (default: 68)
            dropout_rate (float): Probability for the first head dropout; the
                second dropout uses half of it (default: 0.3)
            pretrained (bool): Load ImageNet (IMAGENET1K_V1) backbone weights
                (default: False).
                Note: Set to True only during initial training
        """
        super().__init__()

        backbone_weights = "IMAGENET1K_V1" if pretrained else None
        self.resnet = models.resnet50(weights=backbone_weights)

        num_features = self.resnet.fc.in_features

        # NOTE: nn.Sequential indices become state_dict keys
        # (resnet.fc.1 / .3 / .5 hold parameters), so keep this exact layer
        # order or previously saved checkpoints will no longer load.
        self.resnet.fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            # In train mode BatchNorm1d needs more than one sample per batch.
            nn.BatchNorm1d(512),
            nn.Dropout(p=dropout_rate / 2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, 3, 224, 224)

        Returns:
            torch.Tensor: Output logits of shape (batch, num_classes)
        """
        return self.resnet(x)

    def get_num_parameters(self):
        """
        Count model parameters.

        Returns:
            dict: ``{'total': int, 'trainable': int, 'frozen': int}``, where
            trainable parameters are those with ``requires_grad=True``.
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return {
            'total': total_params,
            'trainable': trainable_params,
            'frozen': total_params - trainable_params,
        }
|
|
|
|
| |
# Alias: ResNet50Classifier and WarningLightsResNet50 are the same class object.
# NOTE(review): presumably kept so code importing the older name keeps
# working — confirm with callers before removing.
ResNet50Classifier = WarningLightsResNet50
|
|
|
|
def create_model(num_classes=68, pretrained=False, checkpoint_path=None):
    """
    Factory function to create and optionally load a model.

    Args:
        num_classes (int): Number of classes
        pretrained (bool): Use ImageNet pretrained weights for backbone
        checkpoint_path (str): Path to checkpoint file (.pth). The file may
            contain either a bare state_dict or a dict with a
            "model_state_dict" entry (optionally with "epoch" and "val_acc").

    Returns:
        WarningLightsResNet50: Model instance. When a checkpoint is loaded,
        the model is moved to CUDA if available (CPU otherwise) and switched
        to eval mode; without a checkpoint it is returned as constructed.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match this architecture (e.g. a different
            ``num_classes``).

    Example:
        # Create new model
        model = create_model(num_classes=68, pretrained=True)

        # Load trained model
        model = create_model(checkpoint_path="car_warning_lights_resnet50.pth")
    """
    model = WarningLightsResNet50(
        num_classes=num_classes,
        dropout_rate=0.3,
        pretrained=pretrained,
    )

    # Falsy path (None or "") means "fresh model": nothing to load.
    if not checkpoint_path:
        return model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code from it — only load checkpoints from trusted sources.
    # Consider weights_only=True once supported torch versions and the
    # checkpoint contents allow it.
    ckpt = torch.load(checkpoint_path, map_location=device)

    # Training checkpoints wrap the weights together with metadata;
    # plain exports are the state_dict itself.
    is_wrapped = isinstance(ckpt, dict) and "model_state_dict" in ckpt
    state_dict = ckpt["model_state_dict"] if is_wrapped else ckpt

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Emoji are written as escapes (U+2705 check mark, U+1F680 rocket) so
    # the source stays ASCII-only: a non-UTF-8 round-trip cannot split these
    # string literals across lines.
    metadata = ckpt if isinstance(ckpt, dict) else {}
    epoch = metadata.get("epoch")
    val_acc = metadata.get("val_acc")

    if epoch is not None:
        print(f"\u2705 Loaded checkpoint (epoch: {epoch})")
    else:
        print("\u2705 Loaded weights")

    if isinstance(val_acc, (int, float)):
        print(f" val_acc: {val_acc:.2f}%")

    print(f"\U0001F680 Model running on: {device}")

    return model
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build an untrained model, report its size and structure,
    # then push one dummy image through it and check the output shape.
    print("=" * 70)
    print("Vehicle Warning Lights ResNet50 - Model Architecture")
    print("=" * 70)

    num_classes = 68
    model = WarningLightsResNet50(num_classes=num_classes, dropout_rate=0.3)
    # eval() is required here: BatchNorm1d in train mode rejects a batch of 1.
    model.eval()

    params = model.get_num_parameters()
    print(f"\nTotal parameters: {params['total']:,}")
    print(f"Trainable parameters: {params['trainable']:,}")

    print("\nModel Architecture:")
    print(model)

    print("\nTesting forward pass...")
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")

    expected_shape = (1, num_classes)
    # Explicit check rather than `assert`, which `python -O` strips out.
    if tuple(output.shape) != expected_shape:
        raise RuntimeError(f"Expected {expected_shape}, got {output.shape}")
    # U+2705 (check mark) as an escape keeps this source line ASCII-only.
    print("\u2705 Forward pass successful!")
|
|