# Source: Model-v3-mdaytek / train.py (uploaded by MDaytek)
# Commit e07124f (verified): "Micro test upload (fixed v3)"
import json
import os
from collections import Counter
from datasets import load_dataset, Dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from src.tokenizer import ChessTokenizer
from src.model import ChessConfig, ChessForCausalLM, print_parameter_budget
def build_vocab(dataset, max_vocab=1200):
    """Build a token -> id vocabulary from whitespace-separated game texts.

    The four special tokens always occupy ids 0-3, in the order
    ``[PAD]``, ``[BOS]``, ``[EOS]``, ``[UNK]``. The remaining slots are
    filled with the most frequent tokens found in the dataset, most
    frequent first.

    Args:
        dataset: Iterable of mappings, each with a ``"text"`` entry holding
            one game as whitespace-separated tokens.
        max_vocab: Upper bound on the vocabulary size, special tokens
            included. If it is smaller than the number of special tokens,
            only the special tokens are returned.

    Returns:
        A dict mapping each token to a unique id in ``range(len(vocab))``.
    """
    special = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]

    counter = Counter()
    for game in dataset:
        counter.update(game["text"].split())

    # A special-token string that happens to appear in the data must not be
    # counted as a regular token: it would be appended a second time, the dict
    # below would remap it to the later index, and the id range would end up
    # with a hole (and one wasted slot).
    for token in special:
        counter.pop(token, None)

    budget = max(max_vocab - len(special), 0)
    vocab_tokens = special + [tok for tok, _ in counter.most_common(budget)]
    return {tok: idx for idx, tok in enumerate(vocab_tokens)}
def encode_game(game, tokenizer, max_len=256):
    """Turn one game record into a fixed-length list of token ids.

    The game's whitespace-separated tokens are wrapped in ``[BOS]`` /
    ``[EOS]``, cut to at most ``max_len`` entries (so ``[EOS]`` is lost when
    the game is too long), mapped to ids through
    ``tokenizer._convert_token_to_id`` and right-padded with
    ``tokenizer.pad_token_id``.

    Args:
        game: Mapping with a ``"text"`` entry holding the game tokens.
        tokenizer: Object exposing ``_convert_token_to_id`` and
            ``pad_token_id``.
        max_len: Length of the returned id list.

    Returns:
        A list of ``max_len`` token ids.
    """
    sequence = ["[BOS]", *game["text"].split(), "[EOS]"][:max_len]
    encoded = list(map(tokenizer._convert_token_to_id, sequence))

    # Right-pad so every encoded game has the same length.
    missing = max_len - len(encoded)
    encoded.extend([tokenizer.pad_token_id] * missing)
    return encoded
def main():
    """Run the micro training pipeline from data loading to saving.

    Steps, in order: load the ``train[:10]`` slice of the Lichess dataset,
    build a vocabulary and ``ChessTokenizer`` from it, create a
    ``ChessForCausalLM`` sized to that vocabulary, encode every game, train
    for one epoch with ``Trainer``, then save the model and tokenizer under
    ``./my_model/final_model``.
    """
    output_root = "./my_model"

    print("Loading MICRO dataset...")
    games = load_dataset("dlouapre/lichess_2025-01_1M", split="train[:10]")

    print("Building tokenizer...")
    tokenizer = ChessTokenizer(build_vocab(games))
    model = ChessForCausalLM(ChessConfig(vocab_size=tokenizer.vocab_size))

    print("Tokenizing dataset...")
    train_ds = Dataset.from_dict(
        {"input_ids": [encode_game(game, tokenizer) for game in games]}
    )

    # No intermediate checkpoints and no external experiment reporting.
    training_args = TrainingArguments(
        output_dir=output_root,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_strategy="no",
        report_to="none",
        use_cpu=False,
    )
    # mlm=False selects the non-masked-LM mode of the collator.
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        data_collator=collator,
    )

    print("🚀 Starting MICRO training...")
    trainer.train()

    print("💾 Saving model locally...")
    save_dir = os.path.join(output_root, "final_model")
    # Corrected save call: safetensors serialization is explicitly disabled.
    model.save_pretrained(save_dir, safe_serialization=False)
    tokenizer.save_pretrained(save_dir)

    print("✅ Training complete.")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()