Akimotorakiyu committed on
Commit
8140041
·
verified ·
1 Parent(s): d45252f

Upload configuration_mnist_cnn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_mnist_cnn.py +117 -0
configuration_mnist_cnn.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedConfig
2
+ from typing import List, Optional
3
+
4
+
5
class MNISTCNNConfig(PreTrainedConfig):
    """
    Configuration class for custom MNIST CNN model.

    This configuration contains all the hyperparameters needed to build the model.

    Args:
        conv_channels: Channel counts for the convolutional layers. Must be a
            list with exactly two entries. ``None`` (the default) selects
            ``[32, 64]``. The list is copied, so later changes to the caller's
            list do not affect the config.
        conv_kernel_size: Convolution kernel size; must be positive.
        conv_padding: Convolution padding (stored as given, not validated).
        pool_kernel_size: Pooling kernel size (stored as given, not validated).
        pool_stride: Pooling stride (stored as given, not validated).
        conv_dropout: Dropout rate for the convolutional part, in ``[0, 1]``.
        fc_dropout: Dropout rate for the fully connected part, in ``[0, 1]``.
        hidden_size: Size of the fully connected hidden layer.
        input_channels: Number of input image channels (1 for grayscale MNIST).
        num_classes: Number of output classes; must be positive.
        image_size: Side length used to derive ``flattened_size``.
        normalize_mean: Normalization mean stored on the config.
        normalize_std: Normalization standard deviation stored on the config.
        **kwargs: Forwarded unchanged to the base configuration class.

    Raises:
        ValueError: If ``conv_channels`` is not a list of length 2, if
            ``conv_kernel_size`` or ``num_classes`` is not positive, or if a
            dropout rate lies outside ``[0, 1]``.
    """

    model_type = "mnist_cnn"

    def __init__(
        self,
        # Convolutional layers
        conv_channels: Optional[List[int]] = None,
        conv_kernel_size: int = 3,
        conv_padding: int = 1,
        pool_kernel_size: int = 2,
        pool_stride: int = 2,
        # Dropout rates
        conv_dropout: float = 0.25,
        fc_dropout: float = 0.5,
        # Fully connected layers
        hidden_size: int = 512,
        # Input/output
        input_channels: int = 1,  # MNIST is grayscale
        num_classes: int = 10,  # Digits 0-9
        # Image dimensions
        image_size: int = 28,
        # Normalization parameters
        normalize_mean: float = 0.1307,
        normalize_std: float = 0.3081,
        **kwargs,
    ):
        # A literal list in the signature would be created once and shared by
        # every instance that relies on the default; build a fresh one instead.
        if conv_channels is None:
            conv_channels = [32, 64]

        # Validate parameters
        if not isinstance(conv_channels, list) or len(conv_channels) != 2:
            raise ValueError(
                f"`conv_channels` must be a list of 2 integers, got {conv_channels}"
            )

        if conv_kernel_size <= 0:
            raise ValueError(
                f"`conv_kernel_size` must be positive, got {conv_kernel_size}"
            )

        if not (0 <= conv_dropout <= 1):
            raise ValueError(
                f"`conv_dropout` must be between 0 and 1, got {conv_dropout}"
            )

        if not (0 <= fc_dropout <= 1):
            raise ValueError(f"`fc_dropout` must be between 0 and 1, got {fc_dropout}")

        if num_classes <= 0:
            raise ValueError(f"`num_classes` must be positive, got {num_classes}")

        # Set configuration attributes
        # Copy so the config never aliases a list owned by the caller.
        self.conv_channels = list(conv_channels)
        self.conv_kernel_size = conv_kernel_size
        self.conv_padding = conv_padding
        self.pool_kernel_size = pool_kernel_size
        self.pool_stride = pool_stride
        self.conv_dropout = conv_dropout
        self.fc_dropout = fc_dropout
        self.hidden_size = hidden_size
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.image_size = image_size
        self.normalize_mean = normalize_mean
        self.normalize_std = normalize_std

        # Calculate the size of flattened features after conv layers
        # After two 2x2 pooling operations: 28 -> 14 -> 7
        # NOTE(review): the divisor 4 is hard-coded; it does not take
        # pool_kernel_size, pool_stride, conv_kernel_size or conv_padding into
        # account, so it is only known to match the default values -- confirm
        # against the model implementation before changing those parameters.
        pooled_side = image_size // 4
        self.flattened_size = self.conv_channels[-1] * pooled_side * pooled_side

        super().__init__(**kwargs)
79
+
80
+
81
+ # Example configurations for different model variants
82
def create_small_config():
    """Return an MNISTCNNConfig with fewer channels and a smaller hidden layer."""
    # Only these fields are overridden; everything else keeps the class defaults.
    small_overrides = {
        "conv_channels": [16, 32],
        "hidden_size": 256,
        "conv_dropout": 0.2,
        "fc_dropout": 0.4,
    }
    return MNISTCNNConfig(**small_overrides)
90
+
91
+
92
def create_large_config():
    """Return an MNISTCNNConfig with more channels and a larger hidden layer."""
    # Only these fields are overridden; everything else keeps the class defaults.
    large_overrides = {
        "conv_channels": [64, 128],
        "hidden_size": 1024,
        "conv_dropout": 0.3,
        "fc_dropout": 0.6,
    }
    return MNISTCNNConfig(**large_overrides)
100
+
101
+
102
if __name__ == "__main__":
    # Manual smoke test: build the default config, save it, and print the
    # derived flattened sizes of the two example variants.
    save_dir = "mnist-cnn-config"  # single source for the path used below

    # Create and test configuration
    config = MNISTCNNConfig()
    print("Default configuration:")
    print(config)

    # Save configuration
    config.save_pretrained(save_dir)
    print(f"\nConfiguration saved to '{save_dir}'")

    # Test different configurations
    small_config = create_small_config()
    large_config = create_large_config()

    print(f"\nSmall config flattened size: {small_config.flattened_size}")
    print(f"Large config flattened size: {large_config.flattened_size}")