# mini_gpt_transformer/train.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from model import MiniGPT
from datasets import load_dataset
from dataloader import TinyLLMDataset
from torch.utils.data import DataLoader
import os
from torch.nn.utils.rnn import pad_sequence
from tokenizer import load_tokenizer
from utils import print_gpu_memory
import time
from torch.optim.lr_scheduler import OneCycleLR
# ----------- Hyperparameters -----------
block_size = 128      # context length: number of tokens the model attends to
batch_size = 32       # number of sequences per batch
max_iters = 100000    # total number of training iterations (scheduler horizon)
eval_interval = 100   # evaluation/generation frequency, in steps
learning_rate = 1e-3  # peak learning rate (alternative: 5e-5)
embed_dim = 256
n_heads = 32
n_layers = 20
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#dt = load_dataset("CATIE-AQ/wikipedia_fr_2022_250K")
#texts = dt["train"]["text"]
dt = load_dataset("iproskurina/TinyStories-French")
texts = dt["train"]["french-tinystories"]
stoi, itos, encode, decode, pad_token_id = load_tokenizer("tokenizer_wtw_tinystories.json")
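# Round-trip sanity check (a sketch, commented out; it assumes encode/decode are
# inverse word-level mappings defined in tokenizer.py, which is not shown here):
# sample = "Il était une fois une petite fille"
# assert decode(encode(sample)) == sample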
vocab_size = len(stoi)
resume_path = "checkpoints/model_step_best.pt"
if os.path.exists(resume_path):
    # map_location ensures the checkpoint loads even if it was saved on another device
    checkpoint = torch.load(resume_path, map_location=device)
    start_iter = checkpoint["step"] + 1
    print(f"Resuming from step {start_iter}")
else:
    checkpoint = None
    start_iter = 0
# ---------- Create the model once the vocab is ready ----------
model = MiniGPT(
    vocab_size=vocab_size,
    block_size=block_size,
    embed_dim=embed_dim,
    depth=n_layers,
    heads=n_heads,
).to(device)
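# Assumed MiniGPT interface (defined in model.py, which is not shown in this file):
#   model(idx)                           # idx: LongTensor (B, T) -> logits (B, T, vocab_size)
#   model.generate(idx, max_new_tokens)  # autoregressive sampling from a prompt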
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# ---------- Then load the weights if resuming ----------
if checkpoint is not None:
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
def collate_fn(batch):
    # Pad variable-length sequences to the longest one in the batch; padded
    # positions are later excluded from the loss via ignore_index=pad_token_id.
    xs, ys = zip(*batch)
    xs_padded = pad_sequence(xs, batch_first=True, padding_value=pad_token_id)
    ys_padded = pad_sequence(ys, batch_first=True, padding_value=pad_token_id)
    return xs_padded, ys_padded
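# Example of the padding behavior (a sketch; pad_token_id comes from the tokenizer above):
# a = torch.tensor([1, 2, 3]); b = torch.tensor([4, 5])
# collate_fn([(a, a), (b, b)])  # both pairs are right-padded to length 3 with pad_token_id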
list_of_sentences = texts[:10000]
split_idx = int(0.9 * len(list_of_sentences))
train_sentences = list_of_sentences[:split_idx]
val_sentences = list_of_sentences[split_idx:]
train_dataset = TinyLLMDataset(train_sentences, block_size, encode)
val_dataset = TinyLLMDataset(val_sentences, block_size, encode)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)
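# Optional shape check (an assumption: TinyLLMDataset is expected to yield
# (input_ids, target_ids) LongTensor pairs with targets shifted one token to
# the right, each at most block_size long; see dataloader.py):
# xb_dbg, yb_dbg = next(iter(train_loader))
# print(xb_dbg.shape, yb_dbg.shape)  # expected: (batch_size, T) with T <= block_size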
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total >= 1e9:
        return f"{total/1e9:.2f}B"
    elif total >= 1e6:
        return f"{total/1e6:.2f}M"
    elif total >= 1e3:
        return f"{total/1e3:.2f}K"
    return str(total)
print("Nombre de paramètres du modèle :", count_parameters(model))
# ----------- Learning rate scheduler -----------
scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    total_steps=max_iters,
)
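# OneCycleLR warms the LR up to max_lr and then anneals it over exactly
# total_steps calls to scheduler.step(); stepping past total_steps raises an
# error, so the epoch loop below must not exceed max_iters optimizer steps.
# The current LR can be inspected at any point with:
# print(scheduler.get_last_lr()[0])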
# ----------- Training loop -----------
num_epochs = 10
global_step = start_iter
best_loss = checkpoint["loss"] if checkpoint is not None else float("inf")
os.makedirs("checkpoints", exist_ok=True)  # make sure the save directory exists
for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")
    for xb, yb in train_loader:
        start_time_total = time.time()
        xb = xb.to(device)
        yb = yb.to(device)
        model.train()
        #print_gpu_memory("Train ")
        start_time = time.time()
        logits = model(xb)
        forward_time = time.time() - start_time
        #print_gpu_memory("Logits")
        start_time = time.time()
        # Flatten (B, T, C) -> (B*T, C) for token-level cross-entropy;
        # padded positions are skipped via ignore_index.
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), yb.view(B * T), ignore_index=pad_token_id)
        loss_time = time.time() - start_time
        #print_gpu_memory("Loss ")
        start_time = time.time()
        optimizer.zero_grad()
        #print_gpu_memory("Zero G")
        loss.backward()
        backward_time = time.time() - start_time
        #print_gpu_memory("Back w")
        start_time = time.time()
        optimizer.step()
        scheduler.step()
        step_time = time.time() - start_time
        #print_gpu_memory("Opt st")
        # Note: CUDA kernels run asynchronously, so these wall-clock timings are
        # only indicative unless torch.cuda.synchronize() is called before each read.
        total_time = time.time() - start_time_total
        print(f"[Step {global_step}] Loss = {loss.item():.4f} | total: {total_time:.3f}s | forward: {forward_time:.3f}s | loss: {loss_time:.3f}s | backward: {backward_time:.3f}s | step: {step_time:.3f}s")
        if global_step % eval_interval == 0:
            print(f"[Epoch {epoch+1} | Step {global_step}] Loss = {loss.item():.4f}")
            model.eval()
            context = torch.zeros((1, 1), dtype=torch.long, device=device)
            with torch.no_grad():
                generated = model.generate(context, max_new_tokens=500)[0].tolist()
            print("\n--- Generated ---")
            print(decode(generated))
            print("-----------------\n")
        else:
            print(f"[Epoch {epoch+1} | Step {global_step}] Loss = {loss.item():.4f}")
        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save({
                'step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'vocab': {'stoi': stoi, 'itos': itos}
            }, "checkpoints/model_step_best.pt")
        global_step += 1
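# Note: val_loader is built above but never consumed in the training loop.
# A minimal validation pass could look like the sketch below (commented out;
# it assumes the same (xb, yb) batch format as train_loader):
# @torch.no_grad()
# def estimate_val_loss():
#     model.eval()
#     losses = []
#     for xv, yv in val_loader:
#         xv, yv = xv.to(device), yv.to(device)
#         lv = model(xv)
#         b, t, c = lv.shape
#         losses.append(F.cross_entropy(lv.view(b * t, c), yv.view(b * t),
#                                       ignore_index=pad_token_id).item())
#     model.train()
#     return sum(losses) / max(len(losses), 1)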