flpelerin committed
Commit 62466de · 1 Parent(s): 6b14517
Files changed (1)
  1. train.py +72 -27
train.py CHANGED
@@ -12,43 +12,43 @@ from util import generate_text
 # Configuration Parameters
 # ============================
 dataset_path = 'flpelerin/tinystories-10k'
+
 num_epochs = 1
 batch_size = 4
 seq_length = 256
 learning_rate = 1e-4
-infer_step = 50
+
 input_len = 50
 num_predict = 250
+
+infer_every = 50
 reset_state_every = 16
-validate_every = 100 # Perform validation every 100 training steps
+validate_every = 100 # Perform validation every 100 training steps
+
+
+
+# ============================
+# Initialize the Device
+# ============================
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Total context size is {batch_size * seq_length} tokens")
 
+
+
+# ============================
+# Initialize the Tokenizer
+# ============================
+
+
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 tokenizer.pad_token = tokenizer.eos_token
 vocab_size = tokenizer.vocab_size
 print(f"Tokenzier has {vocab_size} unique tokens")
 
 
-wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
 
-# Initialize W&B Logging
-wandb.init(
-    project="minGRU-Training",
-    config={
-        "dataset_path": dataset_path,
-        "num_epochs": num_epochs,
-        "batch_size": batch_size,
-        "seq_length": seq_length,
-        "learning_rate": learning_rate,
-        "infer_step": infer_step,
-        "input_len": input_len,
-        "num_predict": num_predict,
-        "reset_state_every": reset_state_every,
-        "validate_every": validate_every,
-        "device": str(device)
-    }
-)
+
 
 # ============================
 # Load and Preprocess Dataset
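Note: with the values above, one training step covers batch_size * seq_length = 4 * 256 = 1,024 tokens, which is exactly what the "Total context size" print reports.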
@@ -65,7 +65,9 @@ def process_function(examples):
 
 tokenized_datasets = dataset.map(process_function, batched=True)
 print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
-wandb.log({"dataset_rows": tokenized_datasets['train'].num_rows, "dataset_token_count": batch_size * seq_length}) # Log dataset stats
+
+
+
 
 # ============================
 # Split Dataset into Train and Validation
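Note: the body of process_function is outside this hunk; only its use through dataset.map is visible. A minimal sketch of what such a tokenization function typically looks like, assuming the dataset exposes a 'text' column and that each row is padded or truncated to the batch_size * seq_length budget printed above (both are assumptions, not lines from this file):

def process_function(examples):
    # Hypothetical sketch; the real implementation lives elsewhere in train.py.
    return tokenizer(
        examples['text'],                    # assumed column name
        truncation=True,
        padding='max_length',
        max_length=batch_size * seq_length,  # assumed per-row token budget
    )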
@@ -76,7 +78,8 @@ valid_dataset = split_dataset['test']
 
 print(f"Training set size: {len(train_dataset)}")
 print(f"Validation set size: {len(valid_dataset)}")
-wandb.log({"train_set_size": len(train_dataset), "valid_set_size": len(valid_dataset)}) # Log set sizes
+
+
 
 # ============================
 # Initialize the Model
@@ -89,17 +92,59 @@ model = minGRULM(
 )
 
 model.to(device)
-print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
-wandb.log({"model_parameters": sum(p.numel() for p in model.parameters())}) # Log model parameter count
+parameters_count = sum(p.numel() for p in model.parameters())
+print(f"Model has {parameters_count:,} parameters")
+
+
+
+# ============================
+# Initialize the Weights and Biases Run
+# ============================
+wandb.login(key="860f8753998c6e6dc356914de07e8855aa2f9642")
+wandb.init(
+    project="minGRU-Training",
+    config={
+        "dataset_path": dataset_path,
+
+        "num_epochs": num_epochs,
+        "batch_size": batch_size,
+        "seq_length": seq_length,
+        "learning_rate": learning_rate,
+
+        "input_len": input_len,
+        "num_predict": num_predict,
+
+        "infer_every": infer_every,
+        "reset_state_every": reset_state_every,
+        "validate_every": validate_every,
+
+        "dataset_rows": tokenized_datasets['train'].num_rows,
+        "dataset_token_count": batch_size * seq_length,
+
+        "train_set_size": len(train_dataset),
+        "valid_set_size": len(valid_dataset),
+
+        "model_parameters": parameters_count,
+
+        "vocab_size": vocab_size,
+        "d_model": 384,
+        "d_inner": 768,
+        "n_layers": 6,
+
+        "device": str(device)
+    }
+)
 
-optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
-h_states = None
 
 # ============================
 # Training Loop with Validation
 # ============================
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+h_states = None
 step = 0
+
 for epoch in range(num_epochs):
     print(f"Starting Epoch {epoch + 1}/{num_epochs}")
     for i in range(0, len(train_dataset), batch_size):
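Note: with the one-off run metadata (hyperparameters, dataset sizes, parameter count, model dimensions) now passed once through wandb.init(config=...), the scattered wandb.log calls this commit removes are no longer needed for static values; only step-wise metrics would still go through wandb.log during the loop. A hedged sketch of that pattern (the metric names and the loss/valid_loss variables are placeholders, not lines from this file):

wandb.log({"train_loss": loss.item()}, step=step)     # after each optimizer step
if step % validate_every == 0:
    wandb.log({"valid_loss": valid_loss}, step=step)  # after the validation pass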
@@ -150,7 +195,7 @@ for epoch in range(num_epochs):
         model.train() # Switch back to training mode
 
         # Perform inference at specified steps
-        if step % infer_step == 0:
+        if step % infer_every == 0:
             with torch.no_grad():
                 # Select a single input from the current batch for inference
                 sample_ids = input_ids[0][:input_len]
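Note: the rename from infer_step to infer_every lines it up with the other cadence parameters (reset_state_every, validate_every), which presumably gate work with the same step-modulo check. A rough sketch of that family of checks inside the loop (the two helper calls are hypothetical stand-ins, not functions from this file):

if step % reset_state_every == 0:
    h_states = None      # periodically drop the recurrent state
if step % validate_every == 0:
    run_validation()     # hypothetical: the validation pass elsewhere in train.py
if step % infer_every == 0:
    generate_sample()    # hypothetical: wraps the torch.no_grad() block above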
 