File size: 2,701 Bytes
f669547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

from transformers import PreTrainedTokenizer
import json
import os

class ChessTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer that maps chess move strings to vocabulary ids.

    The vocabulary is read from a JSON file holding two mappings,
    ``token_to_id`` and ``id_to_token``. Encoding splits the input on
    whitespace, looks each piece up in ``token_to_id`` and wraps the result
    with the ``[BOS]`` / ``[EOS]`` ids when those tokens are in the vocabulary.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file="vocab.json", **kwargs):
        """Load the vocabulary and register the special tokens.

        Args:
            vocab_file: Path to the JSON vocabulary file.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            ValueError: If ``vocab_file`` does not exist.
        """
        if not os.path.exists(vocab_file):
            raise ValueError(f"Vocabulary file {vocab_file} not found.")

        with open(vocab_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token_to_id = data["token_to_id"]
        # JSON object keys are always strings; restore the integer ids.
        self.id_to_token = {int(k): v for k, v in data["id_to_token"].items()}

        # NOTE(review): these attributes are assigned before the base-class
        # initializer runs; how that interacts with the special-token
        # machinery depends on the installed transformers version - confirm.
        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"

        # ``.get`` yields None when a special token is absent from the vocab.
        self.bos_token_id = self.token_to_id.get("[BOS]")
        self.eos_token_id = self.token_to_id.get("[EOS]")
        self.unk_token_id = self.token_to_id.get("[UNK]")

        super().__init__(
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            unk_token="[UNK]",
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the ``token_to_id`` mapping."""
        return len(self.token_to_id)

    def get_vocab(self):
        """Return the ``token_to_id`` mapping (the internal dict, not a copy)."""
        return self.token_to_id

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, or ``self.unk_token_id`` if unknown."""
        return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index):
        """Return the token stored for ``index``, or ``"[UNK]"`` if unknown."""
        return self.id_to_token.get(index, "[UNK]")

    def __call__(self, text, **kwargs):
        """Encode one move string, or a batch of them, into token ids.

        Args:
            text: A whitespace-separated string of moves, or a list/tuple of
                such strings (nested sequences are encoded recursively).
            **kwargs: Only ``max_length`` is read (default 256); ``None``
                disables truncation. Other keys are ignored.

        Returns:
            ``{"input_ids": ids}`` where ``ids`` is a list of ids for a single
            string, or a list of such lists for a batch.
        """
        # Batch input: encode every element with the same options.
        if isinstance(text, (list, tuple)):
            return {"input_ids": [self.__call__(t, **kwargs)["input_ids"] for t in text]}

        ids = [self._convert_token_to_id(move) for move in text.split()]

        # Wrap the sequence with [BOS] / [EOS] when the vocabulary has them.
        if self.bos_token_id is not None:
            ids = [self.bos_token_id] + ids
        if self.eos_token_id is not None:
            ids = ids + [self.eos_token_id]

        # Truncation keeps the head of the sequence, so a trailing [EOS] is
        # dropped when the sequence is too long. An explicit max_length=None
        # means "do not truncate" instead of failing on the comparison.
        max_len = kwargs.get("max_length", 256)
        if max_len is not None and len(ids) > max_len:
            ids = ids[:max_len]

        return {"input_ids": ids}

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``vocab.json`` and ``tokenizer_config.json`` to a directory.

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A tuple with the paths of the two files written.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_path = os.path.join(save_directory, "vocab.json")
        config_path = os.path.join(save_directory, "tokenizer_config.json")

        with open(vocab_path, "w", encoding="utf-8") as f:
            # Integer keys of id_to_token are serialized as strings by json;
            # __init__ converts them back when loading.
            json.dump({"token_to_id": self.token_to_id, "id_to_token": self.id_to_token}, f)
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump({"model_type": "chess_transformer"}, f)

        return (vocab_path, config_path)

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Build a tokenizer from the ``vocab.json`` found in ``path``.

        Args:
            path: Directory expected to contain ``vocab.json``.
            **kwargs: Forwarded to the constructor.

        Returns:
            A new instance. When ``path`` has no ``vocab.json`` the
            constructor's default ``vocab_file`` is used instead.
        """
        vocab_path = os.path.join(path, "vocab.json")
        if os.path.exists(vocab_path):
            return cls(vocab_file=vocab_path, **kwargs)
        return cls(**kwargs)