flpelerin committed
Commit aed0c2d · Parent: 87a4d20

adding wandb

Files changed (1)
  1. train.py +33 -18
train.py CHANGED
@@ -3,6 +3,7 @@ import math
from transformers import GPT2Tokenizer
from datasets import load_dataset
import numpy as np
+ import wandb  # Import W&B library

from model import minGRULM
from util import generate_text
@@ -10,21 +11,16 @@ from util import generate_text
# ============================
# Configuration Parameters
# ============================
-
dataset_path = 'flpelerin/tinystories-10k'
-
num_epochs = 1
batch_size = 4
seq_length = 256
learning_rate = 1e-4
infer_step = 50
-
input_len = 50
num_predict = 250
-
reset_state_every = 16
validate_every = 100  # Perform validation every 100 training steps
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Total context size is {batch_size * seq_length} tokens")

@@ -33,10 +29,25 @@ tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size
print(f"Tokenizer has {vocab_size} unique tokens")

+ # Initialize W&B Logging
+ wandb.init(project="minGRU-Training", config={
+     "dataset_path": dataset_path,
+     "num_epochs": num_epochs,
+     "batch_size": batch_size,
+     "seq_length": seq_length,
+     "learning_rate": learning_rate,
+     "infer_step": infer_step,
+     "input_len": input_len,
+     "num_predict": num_predict,
+     "reset_state_every": reset_state_every,
+     "validate_every": validate_every,
+     "device": str(device),
+     },
+     settings=wandb.Settings(api_key="860f8753998c6e6dc356914de07e8855aa2f9642"))
+
# ============================
# Load and Preprocess Dataset
# ============================
-
dataset = load_dataset(dataset_path)

def process_function(examples):
@@ -49,23 +60,22 @@ def process_function(examples):

tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
+ wandb.log({"dataset_rows": tokenized_datasets['train'].num_rows, "dataset_token_count": batch_size * seq_length})  # Log dataset stats

# ============================
# Split Dataset into Train and Validation
# ============================
-
- # Split the training set into 90% train and 10% validation
split_dataset = tokenized_datasets['train'].train_test_split(test_size=1 / validate_every)
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
+ wandb.log({"train_set_size": len(train_dataset), "valid_set_size": len(valid_dataset)})  # Log set sizes

# ============================
# Initialize the Model
# ============================
-
model = minGRULM(
    vocab_size = vocab_size,
    d_model = 384,
@@ -75,6 +85,7 @@ model = minGRULM(

model.to(device)
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
+ wandb.log({"model_parameters": sum(p.numel() for p in model.parameters())})  # Log model parameter count

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

@@ -83,7 +94,6 @@ h_states = None
# ============================
# Training Loop with Validation
# ============================
-
step = 0
for epoch in range(num_epochs):
print(f"Starting Epoch {epoch + 1}/{num_epochs}")
@@ -92,10 +102,9 @@ for epoch in range(num_epochs):
input_ids = torch.tensor(batch['input_ids']).to(device)

# Reset hidden states if needed
- h_states = h_states if (step % reset_state_every != 0) else None
+ h_states = h_states if (step % reset_state_every != 0) else None
str_states = (
-     ''.join(['{:.3f}, '.format(h_states[0][0][0][j].item()) for j in range(10)])
-     if h_states is not None else 'None'
+     ''.join(['{:.3f}, '.format(h_states[0][0][0][j].item()) for j in range(10)]) if h_states is not None else 'None'
)

optimizer.zero_grad()
@@ -104,6 +113,13 @@ for epoch in range(num_epochs):
optimizer.step()

step += 1
+ # Log step information
+ wandb.log({
+     "step": step,
+     "epoch": epoch + 1,
+     "loss": loss.item(),
+     "hidden_state": str_states
+ })
print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden State: {str_states}")

# Perform validation at specified intervals
@@ -122,12 +138,9 @@ for epoch in range(num_epochs):
validation_loss += val_loss.item()
valid_steps += 1

- # Optionally, limit the number of batches for faster validation
- # Uncomment the following lines to validate on only the first 100 batches
- # if valid_steps >= 100:
- #     break
-
avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')
+ # Log validation loss
+ wandb.log({"validation_loss": avg_validation_loss, "step": step})
print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")
model.train()  # Switch back to training mode

@@ -142,3 +155,5 @@ for epoch in range(num_epochs):
prompt = sample_ids.unsqueeze(0)  # Shape: [1, input_len]
generated_text = generate_text(model, tokenizer, prompt, num_predict)
print(f"Generated Text:\n{generated_text}\n")
+ # Optionally, log generated text (e.g., as HTML to preserve formatting)
+ # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
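For reference, W&B can also read its credentials from the standard WANDB_API_KEY environment variable rather than having the key embedded in train.py. The snippet below is a minimal sketch (not part of this commit) of the same initialization done that way; the project name and config values are the ones already used in the script, and the final log call uses placeholder values purely for illustration.

import os
import wandb

# Sketch only: take the API key from the environment instead of hardcoding it.
# wandb.login() with no arguments also picks up WANDB_API_KEY automatically.
wandb.login(key=os.environ["WANDB_API_KEY"])

run = wandb.init(project="minGRU-Training", config={
    "dataset_path": "flpelerin/tinystories-10k",
    "num_epochs": 1,
    "batch_size": 4,
    "seq_length": 256,
    "learning_rate": 1e-4,
})

# The wandb.log calls added in the diff then work unchanged, e.g. (placeholder values):
run.log({"loss": 0.0, "step": 0})
run.finish()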