File size: 9,308 Bytes
b2ea2e8
 
2fc11ed
 
 
 
 
aed0c2d
2fc11ed
 
c9340cb
 
0f2d26d
 
 
cb65f81
9b48427
62466de
2fc11ed
 
 
 
62466de
2fc11ed
 
62466de
307f9d8
2fc11ed
307f9d8
a563d57
62466de
 
 
 
 
0f2d26d
 
2fc11ed
62466de
 
 
 
2fc11ed
 
 
b53dc87
aed0c2d
0f2d26d
 
 
2fc11ed
 
 
29a6938
 
 
 
87a4d20
29a6938
2fc11ed
 
0f2d26d
62466de
0f2d26d
 
 
beb0496
0f2d26d
 
2fc11ed
0f2d26d
 
62466de
0f2d26d
 
 
2fc11ed
 
 
 
 
 
 
 
62466de
 
 
b2ea2e8
a563d57
b2ea2e8
a563d57
b2ea2e8
a563d57
b2ea2e8
a563d57
 
b2ea2e8
a563d57
 
b2ea2e8
a563d57
 
 
 
 
 
 
b2ea2e8
b53dc87
a563d57
b53dc87
a563d57
 
 
 
 
 
 
 
b53dc87
 
62466de
 
 
 
 
 
ac8f63b
a563d57
62466de
 
 
 
 
 
 
 
 
 
 
b53dc87
62466de
 
 
 
 
 
83c3c6d
 
 
b2ea2e8
 
62466de
 
2fc11ed
0f2d26d
b53dc87
0f2d26d
62466de
 
f10ccaa
62466de
2fc11ed
0f2d26d
 
 
2fc11ed
 
0f2d26d
a563d57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2ea2e8
 
 
 
 
 
 
 
 
 
2fc11ed
aed0c2d
 
 
be0031f
b2ea2e8
 
aed0c2d
de45ebc
2fc11ed
0f2d26d
 
 
 
 
 
 
 
 
2fc11ed
0f2d26d
 
 
 
2fc11ed
0f2d26d
aed0c2d
 
0f2d26d
 
 
62466de
0f2d26d
b2ea2e8
 
 
0f2d26d
 
 
 
 
 
 
 
aed0c2d
b53dc87
 
 
 
b2ea2e8
 
a563d57
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os  # For file operations
import re  # For regex operations
import torch
import math
from transformers import GPT2Tokenizer
from datasets import load_dataset
import numpy as np
import wandb  # Import W&B library

from model import minGRULM
from util import generate_text, generate_name

# ============================
# Configuration Parameters
# ============================
# Dataset and run identity.
dataset_path  = 'flpelerin/tinystories-100k'  # Hugging Face hub dataset id
model_name    = generate_name()  # Example: "mingru-a14c"

# Optimization hyperparameters.
num_epochs    = 1
batch_size    = 4      # sequences per optimizer step
seq_length    = 256    # tokens per sequence
learning_rate = 1e-4

# Inference-preview settings (used by the periodic generation step).
input_len     = 50     # prompt length, in tokens
num_predict   = 250    # tokens to generate after the prompt

# Cadence (in training steps) of periodic actions.
infer_every       = 200  # generate a text sample every N steps
reset_state_every = 16   # drop the recurrent hidden state every N steps
validate_every    = 200  # run validation every N steps; also sizes the val split
save_every        = 500  # Controls checkpointing frequency

# ============================
# Initialize the Device
# ============================

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Total context size is {batch_size * seq_length} tokens")

# ============================
# Initialize the Tokenizer
# ============================

# GPT-2 BPE tokenizer. GPT-2 ships without a dedicated pad token, so EOS is
# reused for the fixed-length padding applied during preprocessing.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size  # size of the model's output distribution
print(f"Tokenizer has {vocab_size} unique tokens")

# ============================
# Load and Preprocess Dataset
# ============================
# Downloads (or reads from the local cache) the raw text dataset from the hub.
dataset = load_dataset(dataset_path)

def process_function(examples):
    """Tokenize a batch of examples to one fixed length.

    Each example's 'text' field is padded (with the EOS pad token) or
    truncated so every row yields exactly seq_length * batch_size token ids.
    """
    target_length = seq_length * batch_size
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=target_length,
    )

# Tokenize the entire dataset up front; batched=True feeds many rows per call.
tokenized_datasets = dataset.map(process_function, batched=True)
print(f"Dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")

# ============================
# Split Dataset into Train and Validation
# ============================
# Hold out 1/validate_every of the rows (0.5% with validate_every=200) so a
# full validation pass stays cheap relative to how often it runs.
split_dataset = tokenized_datasets['train'].train_test_split(test_size=(1/validate_every))
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")

# ============================
# Initialize the Model
# ============================
# minGRU language model (defined in model.py).
model = minGRULM(
    vocab_size = vocab_size,  # matches the tokenizer's vocabulary
    d_model = 384,            # hidden width
    d_inner = 768,            # presumably the FFN/expansion width — confirm in model.py
    n_layers = 6              # number of stacked layers
)

model.to(device)
# Total parameter count (trainable + frozen), used for logging only.
parameters_count = sum(p.numel() for p in model.parameters())
print(f"Model has {parameters_count:,} parameters")

# ============================
# Symbolic Link Configuration
# ============================
symlink_path = 'pytorch_model.bin'

def update_symlink(target_path, symlink_path):
    """
    Creates or updates a symbolic link pointing to the target path.

    Args:
        target_path (str): The file path the symlink should point to.
        symlink_path (str): The symlink's path.
    """
    try:
        # islink() also catches broken symlinks, which exists() reports False for.
        stale = os.path.islink(symlink_path) or os.path.exists(symlink_path)
        if stale:
            os.remove(symlink_path)
        os.symlink(target_path, symlink_path)
    except OSError as e:
        # Symlinks can fail (e.g. insufficient privileges on some systems);
        # training should continue regardless.
        print(f"Warning: Failed to create symlink {symlink_path} -> {target_path}. Error: {e}")
    else:
        print(f"Updated symlink: {symlink_path} -> {target_path}")

# ============================
# Load Checkpoint from pytorch_model.bin if Exists
# ============================
# Resume from the symlinked checkpoint when one is present; otherwise start fresh.
if os.path.exists(symlink_path):
    try:
        # weights_only=True restricts torch.load to tensor/container data,
        # preventing arbitrary code execution from a tampered pickle file.
        state_dict = torch.load(symlink_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded model weights from {symlink_path}")
    except Exception as e:
        # Any failure (corrupt file, architecture mismatch) falls back to scratch.
        print(f"Error loading model from {symlink_path}: {e}")
        print("Starting training from scratch.")
else:
    print("No checkpoint found. Starting training from scratch.")

# ============================
# Initialize the Weights and Biases Run
# ============================
# SECURITY: never hard-code an API key in source (the previous key was
# committed here and must be considered leaked/rotated). With key=None,
# wandb.login() falls back to the environment / ~/.netrc, so reading
# WANDB_API_KEY keeps local workflows working without embedding a secret.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    project="minGRU-Training",
    name=model_name,
    # Log every knob that affects this run so it is reproducible from W&B alone.
    config={
        "dataset_path": dataset_path,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "seq_length": seq_length,
        "learning_rate": learning_rate,
        "input_len": input_len,
        "num_predict": num_predict,
        "infer_every": infer_every,
        "reset_state_every": reset_state_every,
        "validate_every": validate_every,
        "save_every": save_every,  # Logging the new variable
        "dataset_rows": tokenized_datasets['train'].num_rows,
        "dataset_token_count": batch_size * seq_length,
        "train_set_size": len(train_dataset),
        "valid_set_size": len(valid_dataset),
        "model_parameters": parameters_count,
        "vocab_size": vocab_size,
        "d_model": model.d_model,
        "d_inner": model.d_inner,
        "n_layers": model.n_layers,
        "device": str(device),
        "model_name": model_name  # Log model_name
    }
)

# ============================
# Training Loop with Validation and Checkpointing
# ============================
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
h_states = None  # recurrent hidden states carried across consecutive batches
step = 0

for epoch in range(num_epochs):
    print(f"Starting Epoch {epoch + 1}/{num_epochs}")
    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i + batch_size]
        input_ids = torch.tensor(batch['input_ids']).to(device)

        # Periodically reset the hidden state so stale context does not
        # carry across unrelated samples forever; otherwise keep it.
        if step % reset_state_every == 0:
            h_states = None

        optimizer.zero_grad()
        try:
            # forward returns (logits, new hidden states, loss)
            # -- signature inferred from usage here; confirm against model.py.
            _, h_states, loss = model.forward(input_ids, h_states)
            loss.backward()
            optimizer.step()
        except Exception as e:
            # Best-effort training: report the failure and move on.
            print(f"Error during training step {step + 1}: {e}")
            continue  # Skip to the next batch

        step += 1

        # Mean/variance of the hidden states, logged to monitor state drift.
        avg_states = None
        var_states = None
        if h_states is not None:
            try:
                avg_states = sum(torch.mean(h).item() for h in h_states) / len(h_states)
                var_states = torch.var(torch.cat(h_states, dim=0)).item()
            except Exception:
                pass  # leave stats as None if the states cannot be reduced

        # Log step information
        wandb.log({
            "loss": loss.item(),
            "average_hidden_state": avg_states,
            "variance_hidden_state": var_states,
            "step": step
        })
        print(f"Epoch: {epoch + 1}/{num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden States: average = {avg_states}, variance = {var_states}")

        # Perform validation at specified intervals
        if step % validate_every == 0:
            validation_loss = 0.0
            valid_steps = 0

            model.eval()  # disable train-only behavior (e.g. dropout) during eval
            with torch.no_grad():
                for vi in range(0, len(valid_dataset), batch_size):
                    val_batch = valid_dataset[vi:vi + batch_size]
                    val_input_ids = torch.tensor(val_batch['input_ids']).to(device)

                    # Fresh (None) hidden state for each validation batch.
                    _, _, val_loss = model.forward(val_input_ids, None)
                    validation_loss += val_loss.item()
                    valid_steps += 1
            model.train()

            avg_validation_loss = validation_loss / valid_steps if valid_steps > 0 else float('inf')
            # Log validation loss
            wandb.log({"validation_loss": avg_validation_loss, "step": step})
            print(f"----- Validation after Step {step}: Average Loss = {avg_validation_loss:.4f} -----")

        # Perform inference at specified steps
        if step % infer_every == 0:
            model.eval()
            with torch.no_grad():
                if input_ids.size(1) < input_len:
                    # BUG FIX: this branch used `continue`, which silently
                    # skipped the checkpointing block below as well; it now
                    # only skips the generation preview.
                    print("Input length is shorter than input_len. Skipping inference.")
                else:
                    # Use the first sequence of the current batch as the prompt.
                    sample_ids = input_ids[0][:input_len]
                    input_text = tokenizer.decode(sample_ids, skip_special_tokens=True)
                    print(f"Input for Inference: {input_text}")

                    prompt = sample_ids.unsqueeze(0)  # Shape: [1, input_len]
                    generated_text = generate_text(model, tokenizer, prompt, num_predict)
                    print(f"Generated Text:\n{generated_text}\n")
                    # Optionally, log generated text (e.g., as HTML to preserve formatting)
                    # wandb.log({"generated_text": wandb.Html(f"<pre>{generated_text}</pre>")}, step=step)
            model.train()

        # Perform checkpointing at specified steps
        if step % save_every == 0:
            # BUG FIX: the old code appended a literal 'k' (f"{step}k"), so
            # step 500 was mislabeled "500k"; name checkpoints by the raw step.
            checkpoint_path = f"{model_name}-{step}.bin"
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved model checkpoint at step {step} to {checkpoint_path}")

                # Point the stable symlink at the newest checkpoint.
                update_symlink(checkpoint_path, symlink_path)

                # Optionally, log the checkpoint to W&B
                # wandb.save(checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint at step {step}: {e}")