ajitsi commited on
Commit
9a437fa
·
1 Parent(s): 53635f6

adding config files

Browse files
Files changed (2) hide show
  1. config.json +19 -0
  2. train.py +17 -6
config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "TinyVGG",
3
+ "num_epochs": 10,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.001,
6
+ "input_shape": 3,
7
+ "hidden_units": 64,
8
+ "output_shape": 10,
9
+ "kernel_size": 3,
10
+ "stride": 1,
11
+ "padding": 0,
12
+ "pooling_kernel_size": 2,
13
+ "pooling_stride": 2,
14
+ "framework": "PyTorch",
15
+ "image_size": 64,
16
+ "author": "Ajit",
17
+ "description": "TinyVGG model for image classification with 64x64 input images."
18
+ }
19
+
train.py CHANGED
@@ -9,13 +9,24 @@ from torchvision import transforms
9
  from timeit import default_timer as timer
10
  from pathlib import Path
11
  from get_data import fetch_data
 
 
 
 
12
 
13
  def train_torch():
 
 
 
 
 
 
 
 
14
  # Setup hyperparameters
15
- NUM_EPOCHS = 10
16
- BATCH_SIZE = 32
17
- HIDDEN_UNITS = 10
18
- LEARNING_RATE = 0.001
19
 
20
  # Define the URL and paths
21
  # Define the URL and paths
@@ -46,8 +57,8 @@ def train_torch():
46
  batch_size=BATCH_SIZE)
47
  # Create model with help from model_builder.py
48
  model = model_builder.TinyVGG(
49
- input_shape = 3,
50
- hidden_units=HIDDEN_UNITS,
51
  output_shape=len(class_names)
52
  ).to(device)
53
 
 
9
  from timeit import default_timer as timer
10
  from pathlib import Path
11
  from get_data import fetch_data
12
+ import json
13
+
14
+
15
+
16
 
17
  def train_torch():
18
+ # Load config.json
19
+ with open("config.json", "r") as f:
20
+ config = json.load(f)
21
+
22
+ # Access configuration parameters
23
+ input_shape = config["input_shape"]
24
+ hidden_units = config["hidden_units"]
25
+
26
  # Setup hyperparameters
27
+ NUM_EPOCHS = config["num_epochs"]
28
+ BATCH_SIZE = config["batch_size"]
29
+ LEARNING_RATE = config["learning_rate"]
 
30
 
31
  # Define the URL and paths
32
  # Define the URL and paths
 
57
  batch_size=BATCH_SIZE)
58
  # Create model with help from model_builder.py
59
  model = model_builder.TinyVGG(
60
+ input_shape = input_shape,
61
+ hidden_units=hidden_units,
62
  output_shape=len(class_names)
63
  ).to(device)
64