|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
from model import MiniGPT |
|
|
from datasets import load_dataset |
|
|
from dataloader import TinyLLMDataset |
|
|
from torch.utils.data import DataLoader |
|
|
import os |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from tokenizer import load_tokenizer |
|
|
from utils import print_gpu_memory |
|
|
import time |
|
|
from torch.optim.lr_scheduler import OneCycleLR |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----- Hyperparameters ----------------------------------------------------
block_size = 128      # maximum context length (tokens) fed to the model
batch_size = 32       # sequences per optimization step
max_iters = 100000    # step budget used to size the OneCycleLR schedule below
eval_interval = 100   # generate a text sample every N steps in the training loop
learning_rate = 1e-3  # peak LR for AdamW / OneCycleLR
embed_dim = 256       # transformer embedding width
n_heads = 32          # attention heads (presumably 256/32 = 8 dims per head — depends on MiniGPT)
n_layers = 20         # transformer depth
# Train on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
# ----- Corpus & tokenizer -------------------------------------------------
# French TinyStories corpus from the Hugging Face hub; the train split's
# "french-tinystories" column holds the raw story strings.
dt = load_dataset("iproskurina/TinyStories-French")


texts = dt["train"]["french-tinystories"]


# Tokenizer persisted as JSON. `encode`/`decode` map between text and
# token-id sequences; `pad_token_id` marks padding (excluded from the loss
# via ignore_index in the training loop below).
stoi, itos, encode, decode, pad_token_id = load_tokenizer("tokenizer_wtw_tinystories.json")


vocab_size = len(stoi)
|
|
|
|
|
|
|
|
# ----- Optional resume ----------------------------------------------------
# If a "best" checkpoint exists, reload it and continue counting steps from
# where the previous run stopped. The model/optimizer weights themselves are
# restored further below, once the model has been built.
resume_path = "checkpoints/model_step_best.pt"

if os.path.exists(resume_path):
    # map_location ensures a checkpoint saved on GPU still loads when this
    # run is on a CPU-only machine (and vice versa); without it torch.load
    # tries to deserialize tensors onto their original device and crashes.
    checkpoint = torch.load(resume_path, map_location=device)
    start_iter = checkpoint["step"] + 1
    print(f"Reprise à l'étape {start_iter}")
else:
    start_iter = 0
|
|
|
|
|
# ----- Model & optimizer --------------------------------------------------
# Decoder-only transformer (see model.MiniGPT) sized by the hyperparameters
# above, placed on the selected device.
model = MiniGPT(


    vocab_size=vocab_size,


    block_size=block_size,


    embed_dim=embed_dim,


    depth=n_layers,


    heads=n_heads


).to(device)


optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


# Restore weights and optimizer state when resuming; `checkpoint` was loaded
# by the resume check above (this second existence test mirrors that one).
# NOTE(review): the OneCycleLR scheduler created below is never restored, so
# a resumed run restarts the LR cycle from scratch — confirm that is intended.
if os.path.exists(resume_path):


    model.load_state_dict(checkpoint["model_state_dict"])


    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Collate (input, target) pairs into padded batch tensors.

    Both sides are right-padded to the longest sequence in the batch using
    `pad_token_id`; padded target positions are later excluded from the loss
    through `ignore_index`.
    """
    inputs, targets = [], []
    for x, y in batch:
        inputs.append(x)
        targets.append(y)
    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=pad_token_id)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=pad_token_id)
    return padded_inputs, padded_targets
|
|
|
|
|
|
|
|
|
|
|
# Keep the first 10k stories and carve out a 90/10 train/validation split.
corpus = texts[:10000]
n_train = int(0.9 * len(corpus))
train_sentences = corpus[:n_train]
val_sentences = corpus[n_train:]

# Wrap each split in the project dataset (handles tokenization via `encode`).
train_dataset = TinyLLMDataset(train_sentences, block_size, encode)
val_dataset = TinyLLMDataset(val_sentences, block_size, encode)

# Shuffle only the training data; drop ragged final batches on both sides.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate_fn)
|
|
|
|
|
|
|
|
def count_parameters(model):
    """Return the trainable-parameter count as a human-readable string.

    Counts require-grad parameters only and formats with a B/M/K suffix
    (two decimals); counts below 1000 are returned as a plain integer string.
    """
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    for scale, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if total >= scale:
            return f"{total / scale:.2f}{suffix}"
    return str(total)
|
|
|
|
|
# Log the model size before training starts.
print("Nombre de paramètres du modèle :", count_parameters(model))


# One-cycle learning-rate schedule: ramps up to `learning_rate`, then anneals.
# NOTE(review): total_steps is sized to max_iters (100k), but the loop below
# runs roughly num_epochs * len(train_loader) steps; if that is far fewer, the
# LR cycle never completes — confirm the intended step budget.
scheduler = OneCycleLR(


    optimizer,


    max_lr=learning_rate,


    total_steps=max_iters,


)
|
|
|
|
|
|
|
|
|
|
|
# ----- Training loop ------------------------------------------------------
# Runs `num_epochs` passes over `train_loader`, timing each phase of the step
# (forward / loss / backward / optimizer) and checkpointing whenever the
# current training-batch loss improves on the best seen so far.
# NOTE(review): `val_loader` is built above but never used — "best model"
# selection is based on a single noisy training batch; consider evaluating
# on the validation set instead.
num_epochs = 10
global_step = start_iter
best_loss = float("inf")

# Create the checkpoint directory up front so the first torch.save cannot
# fail with FileNotFoundError on a fresh clone.
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")

    for xb, yb in train_loader:
        step_start = time.time()
        xb = xb.to(device)
        yb = yb.to(device)
        # Re-enable train mode every step: the sampling branch below switches
        # the model to eval mode.
        model.train()

        # Forward pass.
        t0 = time.time()
        logits = model(xb)  # expected (B, T, vocab_size) — unpacked below
        forward_time = time.time() - t0

        # Cross-entropy over the flattened (B*T) positions; padded positions
        # are excluded from the loss via ignore_index.
        t0 = time.time()
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), yb.view(B * T), ignore_index=pad_token_id)
        loss_time = time.time() - t0

        # Backward pass.
        t0 = time.time()
        optimizer.zero_grad()
        loss.backward()
        backward_time = time.time() - t0

        # Parameter update + LR-schedule tick.
        t0 = time.time()
        optimizer.step()
        scheduler.step()
        step_time = time.time() - t0

        total_time = time.time() - step_start
        # Fetch the scalar once; .item() synchronizes with the GPU.
        loss_val = loss.item()
        print(f"[Step {global_step}] Perte = {loss_val:.4f} | total: {total_time:.3f}s | forward: {forward_time:.3f}s | loss: {loss_time:.3f}s | backward: {backward_time:.3f}s | step: {step_time:.3f}s")

        # Periodically sample from the model to eyeball generation quality.
        # (The original also re-printed the loss here and in an else branch —
        # redundant with the per-step line above, so those prints are gone.)
        if global_step % eval_interval == 0:
            model.eval()
            with torch.no_grad():  # no gradients needed for sampling
                context = torch.zeros((1, 1), dtype=torch.long, device=device)
                generated = model.generate(context, max_new_tokens=500)[0].tolist()
            print("\n--- Généré ---")
            print(decode(generated))
            print("--------------\n")

        # Checkpoint on improvement of the (per-batch) training loss.
        if loss_val < best_loss:
            best_loss = loss_val
            torch.save({
                'step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Saved so a resumed run can also restore the LR schedule;
                # extra keys are harmless to the existing resume code.
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss_val,
                'vocab': {'stoi': stoi, 'itos': itos}
            }, "checkpoints/model_step_best.pt")

        global_step += 1
|
|
|