import os
from collections import Counter

from datasets import load_dataset, Dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

from src.tokenizer import ChessTokenizer
from src.model import ChessConfig, ChessForCausalLM


def build_vocab(dataset, max_vocab=1200):
    """Build a move-level vocabulary from the most frequent moves in the dataset."""
    counter = Counter()
    for game in dataset:
        moves = game["text"].split()
        counter.update(moves)
    special = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
    vocab_tokens = special + [t for t, _ in counter.most_common(max_vocab - len(special))]
    return {tok: i for i, tok in enumerate(vocab_tokens)}


def encode_game(game, tokenizer, max_len=256):
    """Encode one game as a fixed-length id sequence: [BOS] moves [EOS], padded to max_len."""
    moves = game["text"].split()
    tokens = ["[BOS]"] + moves + ["[EOS]"]
    tokens = tokens[:max_len]  # long games are truncated (this may drop [EOS])
    ids = [tokenizer._convert_token_to_id(t) for t in tokens]
    ids += [tokenizer.pad_token_id] * (max_len - len(ids))
    return ids


def main():
    print("Loading MICRO dataset...")
    # Tiny 10-game slice, just enough for a quick smoke-test run.
    raw_ds = load_dataset("dlouapre/lichess_2025-01_1M", split="train[:10]")

    print("Building tokenizer...")
    vocab = build_vocab(raw_ds)
    tokenizer = ChessTokenizer(vocab)

    config = ChessConfig(vocab_size=tokenizer.vocab_size)
    model = ChessForCausalLM(config)

    print("Tokenizing dataset...")
    input_ids = [encode_game(g, tokenizer) for g in raw_ds]
    ds = Dataset.from_dict({"input_ids": input_ids})

    train_output = "./my_model"
    args = TrainingArguments(
        output_dir=train_output,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_strategy="no",
        report_to="none",
        use_cpu=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds,
        # mlm=False yields causal-LM labels; pad tokens are masked out of the loss.
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    print("🚀 Starting MICRO training...")
    trainer.train()

    print("💾 Saving model locally...")
    final_path = os.path.join(train_output, "final_model")
    # Corrected save: fall back to torch serialization instead of safetensors.
    model.save_pretrained(final_path, safe_serialization=False)
    tokenizer.save_pretrained(final_path)

    print("✅ Training complete.")


if __name__ == "__main__":
    main()