import torch

from model import MiniGPT
from tokenizer import load_tokenizer

# ------- Parameters -------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_path = "checkpoints/model_step_best.pt"  # ← replace with your checkpoint file
tokenizer_path = "tokenizer_wtw_tinystories.json"
block_size = 128
embed_dim = 128
n_heads = 16
n_layers = 16
max_new_tokens = 500

# ------- Load tokenizer -------
stoi, itos, encode, decode, pad_token_id = load_tokenizer(tokenizer_path)
vocab_size = len(stoi)

# ------- Load model -------
model = MiniGPT(
    vocab_size=vocab_size,
    block_size=block_size,
    embed_dim=embed_dim,
    depth=n_layers,
    heads=n_heads
).to(device)

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# ------- Initial context -------
prompt = "Maman"  # French prompt, matching the French TinyStories tokenizer
context_ids = encode(prompt)
context = torch.tensor([context_ids], dtype=torch.long, device=device)

# ------- Generation -------
# Inference only: disable gradient tracking to save memory and compute
with torch.no_grad():
    output_ids = model.generate(context, max_new_tokens=max_new_tokens)[0].tolist()

# ------- Decoding -------
generated_text = decode(output_ids)

print("\n--- Generated story ---\n")
print(generated_text)
print("\n------------------------\n")
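
# ------- Optional: manual sampling loop -------
# A minimal sketch of autoregressive decoding, useful if your MiniGPT
# does not implement generate() or you want to control the temperature.
# It assumes the forward pass returns logits of shape
# (batch, seq_len, vocab_size); adjust if your model also returns a loss.
def sample(model, idx, max_new_tokens, temperature=1.0):
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's maximum context length
        idx_cond = idx[:, -block_size:]
        logits = model(idx_cond)
        # Keep only the logits for the last position and rescale them
        logits = logits[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        # Sample the next token and append it to the sequence
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx

# Usage (equivalent to the generate() call above, under the assumption
# stated in the comment):
#   output_ids = sample(model, context, max_new_tokens)[0].tolist()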