File size: 3,523 Bytes
8f24287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Neural network architectures for the ML Training Optimizer environment.

Provides three model architectures of increasing complexity:
- SimpleMLP: For MNIST (Easy task)
- SmallCNN: For FashionMNIST (Medium task)
- DeeperCNN: For CIFAR-10 (Hard task)

All models accept a configurable dropout rate set by the agent.
"""

import torch.nn as nn


class SimpleMLP(nn.Module):
    """
    2-layer MLP for MNIST digit classification.
    ~100k parameters. Trains in ~0.5-1s per epoch on CPU with 4k samples.

    Architecture: 784 → 128 → 64 → 10
    """

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.layers(x)


class SmallCNN(nn.Module):
    """
    Small CNN for FashionMNIST classification.
    ~200k parameters. Trains in ~1-2s per epoch on CPU with 6.5k samples.

    Architecture: 2×(Conv→ReLU→MaxPool) → FC(128) → FC(10)
    """

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


class DeeperCNN(nn.Module):
    """
    Deeper CNN for CIFAR-10 classification.
    ~500k parameters. Trains in ~2-3s per epoch on CPU with 8k samples.

    Architecture: 3×(Conv→BN→ReLU→MaxPool) → FC(256) → FC(10)
    """

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: 3→32, 32×32 → 16×16
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: 32→64, 16×16 → 8×8
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3: 64→128, 8×8 → 4×4
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


def create_model(model_type: str, dropout: float = 0.0) -> nn.Module:
    """Factory function to create a model by name."""
    models = {
        "simple_mlp": SimpleMLP,
        "small_cnn": SmallCNN,
        "deeper_cnn": DeeperCNN,
    }
    if model_type not in models:
        raise ValueError(f"Unknown model type: {model_type}. Choose from {list(models.keys())}")
    return models[model_type](dropout=dropout)