| |
|
| | import json |
| | import os |
| | from collections import Counter |
| | from datasets import load_dataset, Dataset |
| | from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling |
| | from src.tokenizer import ChessTokenizer |
| | from src.model import ChessConfig, ChessForCausalLM, print_parameter_budget |
| |
|
def build_vocab(dataset, max_vocab=1200):
    """Build a token->id vocabulary from whitespace-split game texts.

    Reserves the four special tokens at the lowest ids, then fills the
    remaining slots with the most frequent move tokens (ties keep first-seen
    order, per Counter.most_common).
    """
    special = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
    freq = Counter()
    for record in dataset:
        freq.update(record["text"].split())
    # Cap the vocabulary so specials + frequent moves never exceed max_vocab.
    frequent = [token for token, _ in freq.most_common(max_vocab - len(special))]
    return {token: idx for idx, token in enumerate(special + frequent)}
| |
|
def encode_game(game, tokenizer, max_len=256):
    """Encode one game's move text into a fixed-length list of token ids.

    Wraps the moves in [BOS]/[EOS], truncates to max_len tokens (which can
    drop [EOS] for long games), and right-pads with the pad id.
    """
    wrapped = ["[BOS]"] + game["text"].split() + ["[EOS]"]
    del wrapped[max_len:]
    ids = [tokenizer._convert_token_to_id(tok) for tok in wrapped]
    padding = [tokenizer.pad_token_id] * (max_len - len(ids))
    return ids + padding
| |
|
def main():
    """Run a micro end-to-end pass: load 10 games, train a tiny model, save it.

    Pipeline: fetch a 10-game slice of the Lichess dataset, derive a
    vocabulary/tokenizer from it, tokenize every game to fixed-length ids,
    run one Trainer epoch, then write model + tokenizer under
    ./my_model/final_model.
    """
    print("Loading MICRO dataset...")
    raw_ds = load_dataset("dlouapre/lichess_2025-01_1M", split="train[:10]")

    print("Building tokenizer...")
    tokenizer = ChessTokenizer(build_vocab(raw_ds))

    # Model size is driven entirely by the tokenizer's vocabulary.
    config = ChessConfig(vocab_size=tokenizer.vocab_size)
    model = ChessForCausalLM(config)

    print("Tokenizing dataset...")
    encoded = {"input_ids": [encode_game(game, tokenizer) for game in raw_ds]}
    ds = Dataset.from_dict(encoded)

    train_output = "./my_model"

    args = TrainingArguments(
        output_dir=train_output,
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_strategy="no",
        report_to="none",
        use_cpu=False,
    )

    # mlm=False -> causal-LM collation (labels are shifted input_ids).
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds,
        data_collator=collator,
    )

    print("🚀 Starting MICRO training...")
    trainer.train()

    print("💾 Saving model locally...")
    final_path = os.path.join(train_output, "final_model")

    # NOTE(review): safe_serialization=False writes a pickle-based checkpoint
    # (pytorch_model.bin) instead of safetensors — presumably deliberate here.
    model.save_pretrained(final_path, safe_serialization=False)
    tokenizer.save_pretrained(final_path)
    print("✅ Training complete.")
| |
|
# Script entry point: run the micro training pipeline when executed directly.
if __name__ == "__main__":
    main()
| |
|