Spaces:
Sleeping
Sleeping
| """ | |
| Model Architecture Definition | |
| Defines the WasteClassifier using transfer learning with MobileNetV2. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class WasteClassifier(nn.Module):
    """
    Waste classification model based on MobileNetV2.

    Uses transfer learning: a MobileNetV2 backbone (optionally initialised
    with ImageNet weights) whose classifier head is replaced by
    Dropout -> Linear(in_features, num_classes).
    """

    def __init__(self, num_classes=6, pretrained=True, dropout=0.2):
        """
        Initialize the model.

        Args:
            num_classes (int): Number of output classes (default: 6)
            pretrained (bool): Use ImageNet pretrained weights (default: True)
            dropout (float): Dropout rate applied before the final linear
                layer (default: 0.2)
        """
        super().__init__()
        self.backbone = self._build_backbone(pretrained)
        # classifier[1] is the Linear layer of the stock head; read its input
        # width instead of hard-coding it.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, num_classes),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Return a MobileNetV2, preferring the torchvision `weights=` API."""
        weights_enum = getattr(models, "MobileNet_V2_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.mobilenet_v2(pretrained=pretrained)
        # IMAGENET1K_V1 is exactly what the deprecated pretrained=True loaded,
        # so initial weights are unchanged; this only avoids the deprecation.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.mobilenet_v2(weights=weights)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, 224, 224)

        Returns:
            torch.Tensor: Output logits of shape (batch_size, num_classes)
        """
        return self.backbone(x)
class WasteClassifierResNet(nn.Module):
    """
    Alternative: Waste classification model based on ResNet18.

    A ResNet18 backbone (optionally initialised with ImageNet weights) whose
    final fully connected layer is replaced by
    Dropout -> Linear(in_features, num_classes).
    """

    def __init__(self, num_classes=6, pretrained=True, dropout=0.2):
        """
        Initialize the model.

        Args:
            num_classes (int): Number of output classes (default: 6)
            pretrained (bool): Use ImageNet pretrained weights (default: True)
            dropout (float): Dropout rate applied before the final linear
                layer (default: 0.2)
        """
        super().__init__()
        self.backbone = self._build_backbone(pretrained)
        # Read the width of the stock fc layer rather than hard-coding it.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, num_classes),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Return a ResNet18, preferring the torchvision `weights=` API."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # IMAGENET1K_V1 is exactly what the deprecated pretrained=True loaded.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): Batch of input images.

        Returns:
            torch.Tensor: Output logits of shape (batch_size, num_classes)
        """
        return self.backbone(x)
def create_model(architecture='mobilenet_v2', num_classes=6, pretrained=True, dropout=0.2):
    """
    Build a waste classifier for the requested backbone.

    Args:
        architecture (str): Backbone name, either 'mobilenet_v2' or 'resnet18'
        num_classes (int): Number of output classes
        pretrained (bool): Start from ImageNet pretrained weights
        dropout (float): Dropout rate of the classification head

    Returns:
        nn.Module: Freshly constructed model

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """
    # Both constructors take the same keyword arguments.
    head_options = {
        "num_classes": num_classes,
        "pretrained": pretrained,
        "dropout": dropout,
    }
    if architecture == 'mobilenet_v2':
        return WasteClassifier(**head_options)
    if architecture == 'resnet18':
        return WasteClassifierResNet(**head_options)
    raise ValueError(f"Unknown architecture: {architecture}")
def count_parameters(model, *, trainable_only=True):
    """
    Count the parameters of a model.

    Args:
        model (nn.Module): PyTorch model
        trainable_only (bool): If True (default), count only parameters with
            ``requires_grad=True``; if False, count every parameter.

    Returns:
        int: Number of parameters matching the filter (0 for a model with
        no parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
def print_model_summary(model):
    """
    Print a summary of the model architecture.

    Reports the model's class name, its total and trainable parameter
    counts, and the module's printed structure.

    Args:
        model (nn.Module): PyTorch model
    """
    # Total and trainable counts differ whenever some parameters are frozen
    # (requires_grad=False), so report both; gather them in a single pass.
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size

    print("=" * 60)
    print("MODEL SUMMARY")
    print("=" * 60)
    print(f"Model architecture: {type(model).__name__}")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print()
    print("Model structure:")
    print(model)
    print("=" * 60)
if __name__ == "__main__":
    # Example usage: build the default model and check its output shape.
    num_classes = 6
    print("Creating MobileNetV2 model...")
    model = create_model('mobilenet_v2', num_classes=num_classes, pretrained=True)
    print_model_summary(model)

    # Test forward pass. eval() disables dropout and stops any normalization
    # layers from updating running statistics with the random input;
    # no_grad() skips building an autograd graph this smoke test never uses.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        output = model(dummy_input)
    print(f"\nOutput shape: {output.shape}")
    print(f"Expected shape: (1, {num_classes})")