# smart-sorting / model_architecture.py
# karthikeya09's picture
# Add Smart Waste Classifier app
# cd4dafb
"""
Model Architecture Definition
Defines the WasteClassifier using transfer learning with MobileNetV2.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class WasteClassifier(nn.Module):
    """
    Waste classification model based on MobileNetV2.

    Builds a torchvision MobileNetV2 backbone (optionally initialised with
    ImageNet pretrained weights) and swaps its classifier head for
    ``Dropout -> Linear(in_features, num_classes)``.
    """

    def __init__(self, num_classes=6, pretrained=True, dropout=0.2):
        """
        Initialize the model.

        Args:
            num_classes (int): Number of output classes (default: 6)
            pretrained (bool): Use ImageNet pretrained weights (default: True)
            dropout (float): Dropout rate of the new head (default: 0.2)
        """
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``. IMAGENET1K_V1 is the checkpoint the legacy flag maps
        # to, so the loaded weights stay the same. Older torchvision releases
        # do not expose the weights enum; keep the legacy keyword for them.
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is None:
            self.backbone = models.mobilenet_v2(pretrained=pretrained)
        else:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.mobilenet_v2(weights=weights)

        # Replace the classifier head, reusing the input width of the
        # layer at index 1 of the stock head.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, 224, 224)

        Returns:
            torch.Tensor: Output logits of shape (batch_size, num_classes)
        """
        return self.backbone(x)
class WasteClassifierResNet(nn.Module):
    """
    Alternative: Waste classification model based on ResNet18.

    Builds a torchvision ResNet18 backbone (optionally initialised with
    ImageNet pretrained weights) and swaps its final fully connected layer
    for ``Dropout -> Linear(in_features, num_classes)``.
    """

    def __init__(self, num_classes=6, pretrained=True, dropout=0.2):
        """
        Initialize the model.

        Args:
            num_classes (int): Number of output classes (default: 6)
            pretrained (bool): Use ImageNet pretrained weights (default: True)
            dropout (float): Dropout rate of the new head (default: 0.2)
        """
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``. IMAGENET1K_V1 is the checkpoint the legacy flag maps
        # to, so the loaded weights stay the same. Older torchvision releases
        # do not expose the weights enum; keep the legacy keyword for them.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            self.backbone = models.resnet18(pretrained=pretrained)
        else:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet18(weights=weights)

        # Replace the final fully connected layer, keeping its input width.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input image batch fed to the backbone

        Returns:
            torch.Tensor: Output of the backbone with the replaced head
        """
        return self.backbone(x)
def create_model(architecture='mobilenet_v2', num_classes=6, pretrained=True, dropout=0.2):
    """
    Build a waste classifier for the requested backbone.

    Args:
        architecture (str): Backbone name, 'mobilenet_v2' or 'resnet18'
        num_classes (int): Number of output classes
        pretrained (bool): Use pretrained weights
        dropout (float): Dropout rate

    Returns:
        nn.Module: Initialized model

    Raises:
        ValueError: If ``architecture`` is not one of the supported names
    """
    # Both model classes take the same keyword arguments.
    options = dict(num_classes=num_classes, pretrained=pretrained, dropout=dropout)

    if architecture == 'mobilenet_v2':
        return WasteClassifier(**options)
    if architecture == 'resnet18':
        return WasteClassifierResNet(**options)

    raise ValueError(f"Unknown architecture: {architecture}")
def count_parameters(model):
    """
    Tally the elements of every parameter that requires gradients.

    Args:
        model (nn.Module): PyTorch model

    Returns:
        int: Number of trainable parameters
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
def print_model_summary(model):
    """
    Print a summary of the model architecture.

    Reports the class name, the total parameter count, the trainable
    (``requires_grad``) parameter count, and the module structure.

    Args:
        model (nn.Module): PyTorch model
    """
    # Tally both counts in one pass. The two differ as soon as any layer is
    # frozen, so reporting the trainable count as the "total" would mislead.
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size

    rule = "=" * 60
    print(rule)
    print("MODEL SUMMARY")
    print(rule)
    print(f"Model architecture: {model.__class__.__name__}")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print()
    print("Model structure:")
    print(model)
    print(rule)
if __name__ == "__main__":
    # Example usage
    print("Creating MobileNetV2 model...")
    model = create_model('mobilenet_v2', num_classes=6, pretrained=True)
    print_model_summary(model)

    # Smoke-test the forward pass in inference mode: eval() turns off
    # dropout and stops batch-norm layers from updating their running
    # statistics; no_grad() skips autograd bookkeeping.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"\nOutput shape: {output.shape}")
    print("Expected shape: (1, 6)")