File size: 2,246 Bytes
e07124f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

import json
import os
from collections import Counter
from datasets import load_dataset, Dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from src.tokenizer import ChessTokenizer
from src.model import ChessConfig, ChessForCausalLM, print_parameter_budget

def build_vocab(dataset, max_vocab=1200):
    """Build a token -> id mapping from whitespace-split move strings.

    Special tokens always occupy the first ids; the remaining slots are
    filled with the most frequent moves, so the total vocabulary never
    exceeds ``max_vocab`` entries.
    """
    specials = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
    counts = Counter()
    for record in dataset:
        counts.update(record["text"].split())
    # Reserve room for the specials, then append moves by descending frequency.
    budget = max_vocab - len(specials)
    ordered = list(specials)
    for token, _freq in counts.most_common(budget):
        ordered.append(token)
    return {token: idx for idx, token in enumerate(ordered)}

def encode_game(game, tokenizer, max_len=256):
    """Encode one game's move string into a fixed-length list of token ids.

    The sequence is wrapped as ``[BOS] <moves...> [EOS]`` and padded with the
    tokenizer's pad id up to ``max_len``.

    Fix: the original truncation (``tokens[:max_len]``) silently dropped the
    ``[EOS]`` terminator for games longer than ``max_len``; we now overwrite
    the last kept token with ``[EOS]`` so every sequence is terminated.
    """
    moves = game["text"].split()
    tokens = ["[BOS]"] + moves + ["[EOS]"]
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
        tokens[-1] = "[EOS]"  # keep the end-of-sequence marker after truncation
    ids = [tokenizer._convert_token_to_id(t) for t in tokens]
    # Right-pad to the fixed length expected by the collator.
    ids += [tokenizer.pad_token_id] * (max_len - len(ids))
    return ids

def main():
    """End-to-end micro training run: load 10 games, build a tokenizer,
    train a tiny causal LM with HF Trainer, and save model + tokenizer.
    """
    print("Loading MICRO dataset...")
    # Only the first 10 rows of the train split — a smoke-test-sized sample.
    raw_ds = load_dataset("dlouapre/lichess_2025-01_1M", split="train[:10]")

    print("Building tokenizer...")
    # Vocabulary is derived from the sample itself (default cap of 1200 tokens).
    vocab = build_vocab(raw_ds)
    tokenizer = ChessTokenizer(vocab)

    # Model vocab size must match the tokenizer's so embedding rows line up.
    config = ChessConfig(vocab_size=tokenizer.vocab_size)
    model = ChessForCausalLM(config)

    print("Tokenizing dataset...")
    # Eagerly encode every game to fixed-length id lists (default max_len=256).
    input_ids = [encode_game(g, tokenizer) for g in raw_ds]
    ds = Dataset.from_dict({"input_ids": input_ids})

    train_output = "./my_model"

    args = TrainingArguments(
        output_dir=train_output,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_strategy="no",   # checkpointing disabled; we save manually below
        report_to="none",     # no wandb/tensorboard logging
        use_cpu=False         # NOTE(review): forces accelerator use — confirm a GPU/MPS is available
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds,
        # mlm=False -> labels are input_ids shifted for causal LM training.
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    print("🚀 Starting MICRO training...")
    trainer.train()

    print("💾 Saving model locally...")
    final_path = os.path.join(train_output, "final_model")
    
    # Fixed save: safe_serialization=False writes a pickle-based .bin
    # (presumably needed because the custom model/tied weights break safetensors — verify)
    model.save_pretrained(final_path, safe_serialization=False)
    tokenizer.save_pretrained(final_path)
    print("✅ Training complete.")

if __name__ == "__main__":
    main()