# mini_gpt_transformer/train.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from model import MiniGPT
from datasets import load_dataset
from dataloader import TinyLLMDataset
from torch.utils.data import DataLoader
import os
from torch.nn.utils.rnn import pad_sequence
from tokenizer import load_tokenizer
from utils import print_gpu_memory
import time
from torch.optim.lr_scheduler import OneCycleLR



# ----------- Hyperparameters -----------
block_size = 128      # context length (maximum sequence length)
batch_size = 32       # number of sequences per batch
max_iters = 100000    # training-iteration horizon used by the LR scheduler
eval_interval = 100   # how often to sample a generation during training
learning_rate = 1e-3  # peak LR (alternative value: 5e-5)
embed_dim = 256
n_heads = 32
n_layers = 20
device = 'cuda' if torch.cuda.is_available() else 'cpu'
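
# Note on the head split (observation, not a change): with embed_dim = 256 and
# n_heads = 32, each attention head is only 256 / 32 = 8 dimensions wide, which
# is unusually narrow; 8 or 16 heads would be a more conventional split at this
# embedding width.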

#dt = load_dataset("CATIE-AQ/wikipedia_fr_2022_250K")
#texts = dt["train"]["text"]
dt = load_dataset("iproskurina/TinyStories-French")
texts = dt["train"]["french-tinystories"] 

stoi, itos, encode, decode, pad_token_id = load_tokenizer("tokenizer_wtw_tinystories.json")
vocab_size = len(stoi) 


resume_path = "checkpoints/model_step_best.pt"
if os.path.exists(resume_path):
    # map_location keeps the load working when the checkpoint was saved on a
    # different device than the current one
    checkpoint = torch.load(resume_path, map_location=device)
    start_iter = checkpoint["step"] + 1
    print(f"Resuming from step {start_iter}")
else:
    start_iter = 0
# ---------- Build the model once the vocabulary is ready ----------
model = MiniGPT(
    vocab_size=vocab_size,
    block_size=block_size,
    embed_dim=embed_dim,
    depth=n_layers,
    heads=n_heads
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# ---------- Then load the saved weights when resuming ----------
if os.path.exists(resume_path):
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])



def collate_fn(batch):
    # Pad variable-length (x, y) pairs to the longest sequence in the batch;
    # padded target positions are masked out via ignore_index in the loss below.
    xs, ys = zip(*batch)
    xs_padded = pad_sequence(xs, batch_first=True, padding_value=pad_token_id)
    ys_padded = pad_sequence(ys, batch_first=True, padding_value=pad_token_id)
    return xs_padded, ys_padded



list_of_sentences = texts[:10000]
split_idx = int(0.9 * len(list_of_sentences))
train_sentences = list_of_sentences[:split_idx]
val_sentences = list_of_sentences[split_idx:]
train_dataset = TinyLLMDataset(train_sentences, block_size, encode)
val_dataset = TinyLLMDataset(val_sentences, block_size, encode)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)
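
# The val_loader above is built but never consulted during the training loop.
# A minimal evaluation helper, sketched under the assumption that the mean
# cross-entropy over the validation split is the intended metric:
@torch.no_grad()
def estimate_val_loss(model):
    model.eval()
    losses = []
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), yb.view(B * T), ignore_index=pad_token_id)
        losses.append(loss.item())
    model.train()
    return sum(losses) / max(len(losses), 1)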


def count_parameters(model):
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total >= 1e9:
        return f"{total/1e9:.2f}B"
    elif total >= 1e6:
        return f"{total/1e6:.2f}M"
    elif total >= 1e3:
        return f"{total/1e3:.2f}K"
    return str(total)

print("Nombre de paramètres du modèle :", count_parameters(model))


# ----------- Learning rate scheduler -----------
scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,      
    total_steps=max_iters,     
)
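
# OneCycleLR state is not stored in the checkpoint, so on resume the schedule
# restarts from step 0. A blunt realignment, assuming replaying the elapsed
# steps is acceptable (PyTorch warns that scheduler.step() precedes
# optimizer.step() here; that warning is expected in this one case):
if start_iter > 0:
    for _ in range(min(start_iter, max_iters - 1)):
        scheduler.step()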


# ----------- Training loop -----------
num_epochs = 10
global_step = start_iter
best_loss = float('inf')
for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")

    for xb, yb in train_loader:
        start_time_total = time.time()
        xb = xb.to(device)
        yb = yb.to(device)
        model.train()
        #print_gpu_memory("Train ")
        
        start_time = time.time()
        logits = model(xb)
        forward_time = time.time() - start_time
        
        #print_gpu_memory("Logits")
        
        start_time = time.time()
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), yb.view(B * T), ignore_index=pad_token_id)
        loss_time = time.time() - start_time
        
        #print_gpu_memory("Loss  ")
        
        start_time = time.time()
        optimizer.zero_grad()
        #print_gpu_memory("Zero G")
        loss.backward()
        backward_time = time.time() - start_time
        
        #print_gpu_memory("Back w")
        
        start_time = time.time()
        optimizer.step()
        scheduler.step()
        step_time = time.time() - start_time
        
        #print_gpu_memory("Opt st")
        end_time_total = time.time()

        total_time = end_time_total - start_time_total
        print(f"[Epoch {epoch + 1} | Step {global_step}] Loss = {loss.item():.4f} | total: {total_time:.3f}s | forward: {forward_time:.3f}s | loss: {loss_time:.3f}s | backward: {backward_time:.3f}s | step: {step_time:.3f}s")


            
        # Every eval_interval steps, sample from the model to eyeball progress
        # (the per-step loss is already printed above, so no need to repeat it)
        if global_step % eval_interval == 0:
            model.eval()
            context = torch.zeros((1, 1), dtype=torch.long, device=device)
            with torch.no_grad():
                generated = model.generate(context, max_new_tokens=500)[0].tolist()
            print("\n--- Sample ---")
            print(decode(generated))
            print("--------------\n")


        # Note: this tracks the best single-batch training loss, which is noisy;
        # validation loss would be a steadier criterion for "best".
        if loss.item() < best_loss:
            best_loss = loss.item()
            os.makedirs("checkpoints", exist_ok=True)  # ensure the target directory exists
            torch.save({
                'step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'vocab': {'stoi': stoi, 'itos': itos}
            }, "checkpoints/model_step_best.pt")

        global_step += 1
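
# After training, report the held-out loss using the estimate_val_loss sketch
# defined above (hypothetical helper, not part of the original script).
val_loss = estimate_val_loss(model)
print(f"Final validation loss: {val_loss:.4f}")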