File size: 13,208 Bytes
c20cb51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# -*- coding: utf-8 -*-
"""
Task 1: Next-Word Prediction using MLP on a Multi-GPU Cluster.

This script is designed to be run using torchrun for distributed training.
Example usage on a 5-GPU machine:
torchrun --nproc_per_node=5 task_1_distributed.py --dataset shakespeare
torchrun --nproc_per_node=5 task_1_distributed.py --dataset linux
"""
import os
import re
import json
import time
import argparse
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import random

# --- Utility Functions for Distributed Training ---

def setup(rank, world_size):
    """Initializes the distributed process group."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default distributed process group, if one is active.

    Safe to call more than once, or after ``setup`` failed (e.g. from a
    ``finally`` clause). When no group is initialized this is a no-op
    instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def is_main_process():
    """Return True if this process should perform rank-0-only work.

    Outside an initialized process group (single-process runs, unit tests,
    or helpers called before ``setup``), the lone process is treated as the
    main one. This avoids ``dist.get_rank`` raising in that case.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0

# --- Data Preprocessing ---

def _fetch_if_missing(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists.

    Uses ``urllib`` instead of shelling out to ``wget``. ``wget`` may not be
    installed, and ``os.system`` silently ignored its failure. Data is
    written to a ``.part`` file and renamed only on success, so an
    interrupted download never leaves a truncated file that later runs
    would mistake for a complete one. Network errors propagate to the caller.
    """
    if os.path.exists(filename):
        return
    import urllib.request

    partial = filename + '.part'
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, filename)
    finally:
        # Only left behind when the download failed part-way.
        if os.path.exists(partial):
            os.remove(partial)


def download_and_preprocess_text(dataset_name):
    """Download (if needed) and normalize the text of one supported corpus.

    The raw file is cached in the current working directory and reused on
    later calls.

    Args:
        dataset_name: ``'shakespeare'`` or ``'linux'``.

    Returns:
        A single string of space-separated tokens.

        * ``shakespeare``: lower-cased. Only ASCII letters, digits,
          whitespace and full stops survive filtering, and every whitespace
          run (including line breaks) then becomes a single space.
        * ``linux``: case is preserved and a broader, code-friendly character
          set is kept. Each source line has its whitespace runs collapsed to
          single spaces and is stripped. Lines are joined with ``' \\n '`` so
          the newline becomes a token of its own.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'shakespeare':
        url = 'https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt'
        filename = 'shakespeare_input.txt'
        _fetch_if_missing(url, filename)
        with open(filename, "r", encoding='utf-8') as f:
            text = f.read()
        # Keep full stops and whitespace; remove other special characters.
        # Whitespace must survive this step: deleting newlines here would
        # glue the last word of one line onto the first word of the next.
        text = re.sub(r'[^a-z0-9\s.]', '', text.lower())
        # Collapse every whitespace run (spaces, newlines, tabs) to one space.
        text = re.sub(r'\s+', ' ', text).strip()
        return text
    elif dataset_name == 'linux':
        url = 'https://cs.stanford.edu/people/karpathy/char-rnn/linux_input.txt'
        filename = 'linux_input.txt'
        _fetch_if_missing(url, filename)
        with open(filename, "r", encoding='utf-8', errors='ignore') as f:
            text = f.read()
        # For code, newlines act as separators and case is preserved;
        # a wider set of punctuation is kept than for prose.
        processed_lines = []
        for line in text.split('\n'):
            processed_line = re.sub(r'[^\w\s\.\(\)\[\]\{\}\=\+\-\*\/,;:"\'#<>&|!~`?]', '', line)
            # Collapse tabs / alignment runs so tokens (split on ' ' later)
            # never contain embedded whitespace.
            processed_lines.append(re.sub(r'\s+', ' ', processed_line).strip())
        return ' \n '.join(processed_lines)  # newline becomes its own token
    else:
        raise ValueError("Invalid dataset name. Choose 'shakespeare' or 'linux'.")

def create_vocabulary_and_pairs(text, context_window_size):
    """Tokenize ``text``, report word frequencies and build training pairs.

    Tokens are the non-empty pieces of ``text.split(' ')``. Index 0 is
    reserved for ``'<pad>'``; the remaining distinct tokens are numbered
    1..N in sorted order. The frequency and vocabulary-size reports are
    printed by the main process only.

    Args:
        text: Preprocessed corpus whose tokens are separated by spaces.
        context_window_size: Number of consecutive tokens used as context
            for predicting the token that follows them. Must be >= 1.

    Returns:
        ``(contexts, targets, word_to_idx, idx_to_word)``:

        * ``contexts``: a ``torch.long`` tensor of shape
          ``(num_pairs, context_window_size)``. Row ``i`` holds the indices
          of tokens ``i .. i + context_window_size - 1``.
        * ``targets``: a ``torch.long`` tensor of shape ``(num_pairs,)``
          holding the index of the token right after each context.
        * ``word_to_idx`` / ``idx_to_word``: dicts mapping tokens to indices
          and back.

        ``num_pairs`` is ``max(num_tokens - context_window_size, 0)``.

    Raises:
        ValueError: if ``context_window_size`` is smaller than 1.
    """
    if context_window_size < 1:
        raise ValueError(f"context_window_size must be >= 1, got {context_window_size}")

    if is_main_process():
        print("Tokenizing text...")
    tokens = [token for token in text.split(' ') if token]  # drop empty strings

    if is_main_process():
        # Report word frequencies
        word_counts = Counter(tokens)
        print("\n--- Vocabulary Report ---")
        print(f"10 Most Frequent Words: {word_counts.most_common(10)}")
        print(f"10 Least Frequent Words: {word_counts.most_common()[:-11:-1]}")

    # A literal '<pad>' token in the corpus must not get its own index.
    # It would be overwritten by the reserved entry below, leaving the
    # largest index equal to vocab_size and out of range for nn.Embedding.
    # Treat it as padding instead.
    vocab = sorted(set(tokens) - {'<pad>'})
    word_to_idx = {word: i for i, word in enumerate(vocab, start=1)}
    word_to_idx['<pad>'] = 0
    idx_to_word = {i: word for word, i in word_to_idx.items()}
    vocab_size = len(word_to_idx)

    if is_main_process():
        print(f"Vocabulary Size: {vocab_size}")

    # Sliding windows built with unfold instead of a Python list of lists,
    # which is slow and memory-hungry for a million-token corpus.
    indexed = torch.tensor([word_to_idx[word] for word in tokens], dtype=torch.long)
    num_pairs = max(len(tokens) - context_window_size, 0)
    if num_pairs == 0:
        contexts = torch.empty((0, context_window_size), dtype=torch.long)
        targets = torch.empty((0,), dtype=torch.long)
    else:
        # unfold yields len - size + 1 windows; the last one has no following
        # token to predict, so keep only the first num_pairs.
        contexts = indexed.unfold(0, context_window_size, 1)[:num_pairs].contiguous()
        targets = indexed[context_window_size:].clone()

    return contexts, targets, word_to_idx, idx_to_word

# --- Model Definition ---

class NextWordPredictor(nn.Module):
    """Fixed-window MLP that scores every vocabulary entry as the next token.

    Each of the ``context_size`` input token ids is embedded. The embeddings
    are concatenated into one flat vector and passed through two ReLU hidden
    layers to produce unnormalized logits over the vocabulary.

    Args:
        vocab_size: Number of token ids. Id 0 is the padding id; its
            embedding is held at zero via ``padding_idx``.
        embedding_dim: Size of each token embedding.
        context_size: Number of token ids per input row.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim):
        super().__init__()
        flat_width = context_size * embedding_dim
        # Registration order is significant: it fixes parameter order,
        # state_dict keys, and RNG consumption during initialization.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.fc1 = nn.Linear(flat_width, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map a batch of token-id rows to ``(batch, vocab_size)`` logits."""
        batch = x.size(0)
        hidden = self.embedding(x).view(batch, -1)
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)

# --- Training and Evaluation ---

def _global_mean(total, count, device):
    """Return sum(total) / sum(count) across all ranks.

    Each rank only sees its own DistributedSampler shard, so a rank-local
    mean is not the loss over the whole split. CPU tensors are used unless
    the backend is NCCL, which only supports CUDA tensors.
    """
    where = device if dist.get_backend() == 'nccl' else 'cpu'
    buf = torch.tensor([total, count], dtype=torch.float64, device=where)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    summed, n = buf.tolist()
    return summed / n if n else float('nan')


def _run_epoch(model, loader, criterion, device, optimizer=None, scaler=None, desc=""):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Must be called on every rank, because the result is all-reduced.
    Returns the sample-weighted mean loss over all ranks.
    """
    training = optimizer is not None
    model.train(training)
    total, count = 0.0, 0
    pbar = tqdm(loader, desc=desc, disable=not is_main_process())
    with torch.set_grad_enabled(training):
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            # Weight by batch size so a short final batch is not over-counted.
            batch = labels.size(0)
            total += loss.item() * batch
            count += batch
    return _global_mean(total, count, device)


def _build_loaders(contexts, targets, world_size, rank, batch_size, seed=42):
    """Split pairs 80/20 and wrap each split in a DistributedSampler loader.

    Returns ``(train_loader, val_loader, train_sampler)``. The split uses an
    explicitly seeded generator, so every rank derives the same partition
    regardless of the global RNG state.
    """
    dataset = TensorDataset(contexts, targets)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    # Evaluation order does not matter, so skip the per-epoch shuffle.
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=True)
    return train_loader, val_loader, train_sampler


def _plot_loss_curves(history, dataset):
    """Save the train/validation loss curves to ``<dataset>_loss_curve.png``."""
    fig = plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title(f'Training vs. Validation Loss ({dataset.capitalize()})')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'{dataset}_loss_curve.png')
    plt.close(fig)
    print(f"Loss curve saved to {dataset}_loss_curve.png")


def _print_example_predictions(model, dataset, word_to_idx, idx_to_word, context_size, device):
    """Print the model's top-1 next-token prediction for a few fixed prompts."""
    print("\n--- Example Predictions ---")
    model.eval()
    test_sentences = {
        'shakespeare': ["to be or not to", "a horse a horse my", "shall i compare thee to"],
        'linux': ["if (err != 0)", "static const struct file_operations", "return -EINVAL;"]
    }
    for sentence in test_sentences[dataset]:
        context_tokens = sentence.lower().split() if dataset == 'shakespeare' else sentence.split()
        # Out-of-vocabulary words map to the padding index 0.
        context_indices = [word_to_idx.get(w, 0) for w in context_tokens][-context_size:]
        # The MLP needs exactly context_size inputs. Left-pad short prompts
        # with the padding index; otherwise fc1 fails with a shape mismatch
        # (every linux prompt here has fewer than the default 5 tokens).
        context_indices = [0] * (context_size - len(context_indices)) + context_indices
        context_tensor = torch.tensor([context_indices], dtype=torch.long, device=device)
        with torch.no_grad():
            prediction = model(context_tensor)
        predicted_index = torch.argmax(prediction, dim=1).item()
        predicted_word = idx_to_word.get(predicted_index, '<unk>')
        print(f"'{sentence}' -> '{predicted_word}'")


def _visualize_embeddings(model, dataset, word_to_idx, num_words=200):
    """Project up to ``num_words`` random word embeddings to 2-D with t-SNE and save the plot."""
    print("\n--- Visualizing Embeddings with t-SNE ---")
    # '<pad>' has a fixed all-zero embedding and carries no information.
    words = [w for w in word_to_idx if w != '<pad>']
    if len(words) > num_words:
        words = random.sample(words, num_words)
    if len(words) < 2:
        print("Not enough words to visualize embeddings; skipping.")
        return

    indices = [word_to_idx[w] for w in words]
    embeddings = model.embedding.weight.detach()[indices].cpu().numpy()
    # t-SNE requires perplexity < n_samples.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1))
    embeddings_2d = tsne.fit_transform(embeddings)

    fig = plt.figure(figsize=(16, 16))
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for word, (x, y) in zip(words, embeddings_2d):
        plt.annotate(word, (x, y), alpha=0.7)
    plt.title(f't-SNE Visualization of Word Embeddings ({dataset.capitalize()})')
    plt.grid(True)
    plt.savefig(f'{dataset}_embeddings.png')
    plt.close(fig)
    print(f"Embedding visualization saved to {dataset}_embeddings.png")


def train(rank, world_size, args):
    """Run distributed training, then report and visualize results on rank 0.

    Args:
        rank: Rank of this process. It is also used as the CUDA device index,
            so this assumes one process per GPU on a single node.
        world_size: Number of participating processes.
        args: Parsed CLI namespace (``dataset``, ``context_size``,
            ``embedding_dim``, ``hidden_dim``, ``epochs``, ``batch_size``,
            ``lr``).

    Files written by rank 0:

    * ``<dataset>_processed.txt``
    * ``<dataset>_word_to_idx.json``
    * ``<dataset>_model.pth`` (after every epoch)
    * ``<dataset>_loss_curve.png``
    * ``<dataset>_embeddings.png``
    """
    setup(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")
        # Make this rank's GPU current, so CUDA work without an explicit
        # device does not all land on cuda:0.
        torch.cuda.set_device(device)

        # --- 1. Data Loading and Preprocessing ---
        processed_path = f"{args.dataset}_processed.txt"
        if is_main_process():
            print(f"--- Starting training for dataset: {args.dataset} ---")
            raw_text = download_and_preprocess_text(args.dataset)
            # Save preprocessed text for other processes to load
            with open(processed_path, "w", encoding='utf-8') as f:
                f.write(raw_text)

        # All ranks wait here until rank 0 has written the processed file.
        dist.barrier()

        with open(processed_path, "r", encoding='utf-8') as f:
            raw_text = f.read()

        contexts, targets, word_to_idx, idx_to_word = create_vocabulary_and_pairs(raw_text, args.context_size)
        vocab_size = len(word_to_idx)

        if is_main_process():
            with open(f'{args.dataset}_word_to_idx.json', 'w') as f:
                json.dump(word_to_idx, f)

        train_loader, val_loader, train_sampler = _build_loaders(
            contexts, targets, world_size, rank, args.batch_size)

        # --- 2. Model, Optimizer, and Loss ---
        model = NextWordPredictor(vocab_size, args.embedding_dim, args.context_size, args.hidden_dim).to(device)
        ddp_model = DDP(model, device_ids=[rank])

        criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding
        optimizer = optim.AdamW(ddp_model.parameters(), lr=args.lr)
        scaler = torch.cuda.amp.GradScaler()  # For mixed precision

        # --- 3./4. Training and Validation Loops ---
        history = {'train_loss': [], 'val_loss': []}
        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)
            tag = f"Epoch {epoch+1}/{args.epochs}"
            avg_train_loss = _run_epoch(ddp_model, train_loader, criterion, device,
                                        optimizer=optimizer, scaler=scaler, desc=f"{tag} [Train]")
            avg_val_loss = _run_epoch(ddp_model, val_loader, criterion, device, desc=f"{tag} [Val]")
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)

            if is_main_process():
                print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
                # Save model checkpoint
                torch.save(ddp_model.module.state_dict(), f'{args.dataset}_model.pth')

        if is_main_process():
            print("--- Training Complete ---")
            if history['val_loss']:
                print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")

            # --- 5. Reporting and Visualization ---
            _plot_loss_curves(history, args.dataset)
            # `model` is the module DDP wraps, so it already holds the final
            # weights; reloading the checkpoint from disk is unnecessary.
            _print_example_predictions(model, args.dataset, word_to_idx, idx_to_word,
                                       args.context_size, device)
            _visualize_embeddings(model, args.dataset, word_to_idx)
    finally:
        # Release the process group even if training raises.
        cleanup()

# --- Main Execution ---

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-GPU Next Word Prediction Trainer")
    parser.add_argument('--dataset', type=str, required=True, choices=['shakespeare', 'linux'], help='Dataset to use.')
    # Model Hyperparameters
    parser.add_argument('--context_size', type=int, default=5, help='Number of context words.')
    parser.add_argument('--embedding_dim', type=int, default=64, help='Dimension of word embeddings.')
    parser.add_argument('--hidden_dim', type=int, default=1024, help='Dimension of hidden layers.')
    # Training Hyperparameters
    parser.add_argument('--epochs', type=int, default=50, help='Number of training epochs.')
    parser.add_argument('--batch_size', type=int, default=16384, help='Batch size per GPU.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate.')

    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    if world_size < 1:
        print("This script requires at least one GPU.")
    else:
        # Use torch.multiprocessing.spawn to launch DDP processes
        # Note: For cluster environments, torchrun is the preferred method.
        # This script is designed for torchrun.
        rank = int(os.environ["LOCAL_RANK"])
        train(rank, world_size, args)