import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from datasets import load_dataset, concatenate_datasets
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
import math

# --------------------------------------------------
# 1. Loading datasets from Hugging Face
# --------------------------------------------------
def load_hf_datasets():
    """Load pre-training corpora from the Hugging Face Hub and concatenate them.

    The repository IDs below are kept as written in the original script; some
    may need an owner/namespace prefix on the Hub to resolve.
    """
    bookcorpus = load_dataset("bookcorpus", split="train")  # ~11K books
    wiki = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")  # Wikipedia
    fineweb = load_dataset("fineweb", split="train")
    arabic_raw_text = load_dataset("ARABIC-RAW-TEXT", split="train")
    tinybooks = load_dataset("tiny-textbooks", split="train")
    cc_trajectories = load_dataset("CC-Bench-trajectories", split="train")
    textbook = load_dataset("TextbookReasoning", split="train")
    megascience = load_dataset("MegaScience", split="train")
    corpora = [bookcorpus, wiki, fineweb, arabic_raw_text, tinybooks,
               cc_trajectories, textbook, megascience]
    # concatenate_datasets requires identical features across datasets, so keep
    # only the shared "text" column (assumes each corpus exposes one).
    corpora = [ds.select_columns(["text"]) for ds in corpora]
    return concatenate_datasets(corpora)
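
# Optional sketch, not used by the rest of this script: for very large corpora
# (e.g. fineweb) `datasets` also supports streaming, which avoids downloading
# the full dataset up front. The helper name and its use are assumptions for
# illustration only; the pipeline below expects an in-memory (map-style) dataset.
def load_streaming_corpus(repo_id, split="train"):
    return load_dataset(repo_id, split=split, streaming=True)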

# --------------------------------------------------
# 2. Tokenization (BPE)
# --------------------------------------------------
def train_tokenizer(dataset, vocab_size=30000):
    """Train a byte-level BPE tokenizer on the dataset's text column."""
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    # Byte-level pre-tokenization/decoding so the tokenizer actually is the
    # byte-level BPE the docstring describes and can round-trip arbitrary text.
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]"]
    )

    # Stream the texts in batches so training does not require loading the
    # whole corpus into memory at once.
    def batch_iterator(batch_size=1000):
        for i in range(0, len(dataset), batch_size):
            yield dataset[i:i + batch_size]["text"]

    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    return tokenizer
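
# Optional sketch, not invoked by this script: the trained tokenizer can be
# serialized to JSON and reloaded later instead of being retrained on every
# run. The path "tokenizer.json" is an arbitrary choice, not something defined
# elsewhere in this file.
def save_and_reload_tokenizer(tokenizer, path="tokenizer.json"):
    tokenizer.save(path)
    return Tokenizer.from_file(path)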

# --------------------------------------------------
# 3. Preparing DataLoader
# --------------------------------------------------
class TextDataset(Dataset):
    """Sliding-window language-modeling dataset: each sample pairs a window of
    token ids with the same window shifted one position to the right."""

    def __init__(self, encoded_text, seq_length=128):
        self.data = encoded_text
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_length]          # input tokens
        y = self.data[idx + 1:idx + self.seq_length + 1]  # next-token targets
        return torch.tensor(x), torch.tensor(y)
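
# A tiny worked example of the sliding window above (toy token ids, assumed
# for illustration only; this helper is defined but never called): with
# seq_length=4, index 0 yields the input [5, 6, 7, 8] and target [6, 7, 8, 9],
# i.e. the model learns to predict each next token.
def _demo_window():
    demo = TextDataset([5, 6, 7, 8, 9], seq_length=4)
    x, y = demo[0]
    assert x.tolist() == [5, 6, 7, 8] and y.tolist() == [6, 7, 8, 9]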

# --------------------------------------------------
# 4. Transformer Model
# --------------------------------------------------
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True so inputs are (batch, seq, d_model), matching the
        # DataLoader batches and the positional-encoding broadcast below.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=d_model * 4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # Causal mask so each position attends only to earlier positions,
        # as required for autoregressive language modeling.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.size(1)).to(x.device)
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoder(x)
        x = self.transformer(x, mask=causal_mask)
        return self.fc(x)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017)."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq, d_model); pe[:seq] broadcasts over the batch dimension.
        return x + self.pe[:x.size(1), :]
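
# A minimal shape check, defined but never called: it verifies that the model
# maps a (batch, seq_length) batch of token ids to (batch, seq_length,
# vocab_size) logits under the batch-first layout used by the DataLoader below.
# The toy sizes here are assumptions for illustration only.
def _demo_shapes(vocab_size=100, seq_length=16, batch_size=2):
    model = TransformerModel(vocab_size)
    dummy = torch.randint(0, vocab_size, (batch_size, seq_length))
    logits = model(dummy)
    assert logits.shape == (batch_size, seq_length, vocab_size)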

# --------------------------------------------------
# 5. Training and Generation
# --------------------------------------------------
def main():
    # Configuration
    SEQ_LENGTH = 128
    BATCH_SIZE = 64
    VOCAB_SIZE = 30000
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # 1. Load data
    dataset = load_hf_datasets()
    
    # 2. Tokenization
    tokenizer = train_tokenizer(dataset, VOCAB_SIZE)
    # Tokenizer.encode expects a single string, so encode each document and
    # flatten the ids into one long token stream.
    encoded_text = []
    for example in dataset:
        encoded_text.extend(tokenizer.encode(example["text"]).ids)
    
    # 3. DataLoader
    train_dataset = TextDataset(encoded_text, SEQ_LENGTH)
    dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # 4. Model
    model = TransformerModel(VOCAB_SIZE).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    
    # 5. Training
    for epoch in range(10):
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            optimizer.zero_grad()
            logits = model(batch_x)  # (batch, seq, vocab)
            loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    
    # 6. Text generation
    def generate(prompt, max_length=100, temperature=0.7):
        model.eval()
        tokens = tokenizer.encode(prompt).ids
        for _ in range(max_length):
            with torch.no_grad():
                logits = model(torch.tensor([tokens[-SEQ_LENGTH:]]).to(DEVICE))
            probs = torch.softmax(logits[0, -1] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            tokens.append(next_token)
        return tokenizer.decode(tokens)
    
    print(generate("The meaning of life is"))

if __name__ == "__main__":
    main()