import logging
import torch
import torch.nn as nn
from torchvision import models

logger = logging.getLogger(__name__)
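

# Sketch: why the 128x128 demo input works and why fc.in_features is 512.
# Assumption: the helpers below are hypothetical and illustrative only; they
# are not part of torchvision and the model does not use them. They apply the
# standard convolution output-size formula
#   out = floor((in + 2 * padding - kernel) / stride) + 1
# to the stride-2 layers of torchvision's ResNet18: conv1, maxpool, and the
# first 3x3 conv of layer2, layer3 and layer4. layer4 emits 512 channels and
# AdaptiveAvgPool2d((1, 1)) collapses any spatial size to 1x1, so the head
# always receives a 512-dimensional vector.


def _conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_map_size(input_size):
    """Hypothetical helper: spatial size of layer4's output for a square input."""
    size = _conv_out(input_size, kernel=7, stride=2, padding=3)  # conv1
    size = _conv_out(size, kernel=3, stride=2, padding=1)        # maxpool
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it
    # with a stride-2 3x3 conv (padding 1) in their first block.
    for _ in range(3):
        size = _conv_out(size, kernel=3, stride=2, padding=1)
    return size


# resnet18_feature_map_size(128) -> 4  (a 4x4x512 map before average pooling)
# resnet18_feature_map_size(224) -> 7  (the usual 7x7x512 for ImageNet inputs)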


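class CarClassifierResNet(nn.Module):
    """ImageNet-pretrained ResNet18 fine-tuned for car classification.

    The stem, layer1 and layer2 stay frozen; layer3, layer4 and a new
    two-layer classifier head are trainable.
    """

    def __init__(self, num_classes):
        super().__init__()

        logger.info("Initializing ResNet18 model...")

        # Load ImageNet-pretrained weights (torchvision >= 0.13 weights API)
        self.model = models.resnet18(weights="DEFAULT")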

        # Freeze every pretrained parameter first...
        for param in self.model.parameters():
            param.requires_grad = False

        # ...then unfreeze the last two residual stages (layer3, layer4) so
        # the highest-level features can adapt to the new task
        for param in self.model.layer3.parameters():
            param.requires_grad = True

        for param in self.model.layer4.parameters():
            param.requires_grad = True

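        # Replace ImageNet's 1000-way fc layer with a dropout-regularized MLP
        # head. Newly created layers default to requires_grad=True, so the
        # head is trainable even though it is added after the freeze above.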
        self.model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        logger.info("ResNet18 model initialized successfully.")

    def forward(self, x):
        return self.model(x)
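

# Sketch: an optimizer over only the trainable parameters.
# Assumption: AdamW, the hypothetical build_optimizer helper and the learning
# rates are illustrative choices, not values from this project. The
# pretrained stages (layer3, layer4) get a smaller learning rate than the
# freshly initialized head; frozen parameters are left out entirely, so the
# optimizer never updates them.
import torch  # re-imported so this sketch is self-contained


def build_optimizer(resnet, backbone_lr=1e-4, head_lr=1e-3, weight_decay=1e-4):
    """Hypothetical helper: AdamW with separate groups for backbone and head.

    `resnet` is the wrapped torchvision model, e.g.
    CarClassifierResNet(num_classes=6).model
    """
    backbone_params = [
        p
        for stage in (resnet.layer3, resnet.layer4)
        for p in stage.parameters()
        if p.requires_grad
    ]
    head_params = [p for p in resnet.fc.parameters() if p.requires_grad]
    # Per-group "lr" entries override the optimizer-wide default
    return torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": backbone_lr},
            {"params": head_params, "lr": head_lr},
        ],
        weight_decay=weight_decay,
    )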


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )

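    model = CarClassifierResNet(num_classes=6)
    # Inference mode for the smoke test: disables dropout and makes
    # BatchNorm use its running statistics instead of per-batch statistics
    model.eval()

    # Batch of 2 RGB images at 128x128; ResNet's adaptive average pool makes
    # the head independent of the input resolution
    dummy_input = torch.randn(2, 3, 128, 128)

    with torch.no_grad():
        output = model(dummy_input)

    print("Output shape:", output.shape)  # torch.Size([2, 6])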

    # Total vs. trainable parameters (trainable = layer3, layer4 and the head)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters()
        if p.requires_grad
    )

    print("Total params:", total_params)
    print("Trainable params:", trainable_params)
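

# Worked count for the classifier head's share of "Trainable params" with
# num_classes=6, as in the demo above (ResNet18's fc.in_features is 512;
# Dropout and ReLU have no parameters). Each nn.Linear(i, o) holds
# i * o weights plus o biases:
#   Linear(512, 256): 512 * 256 + 256 = 131,328
#   Linear(256, 6):   256 * 6   + 6   =   1,542
#   head total                        = 132,870
# The remainder of the trainable count comes from layer3 and layer4.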