from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer

def get_tokenizer(model_name="answerdotai/ModernBERT-base",
                  bos_token="<BOS>",
                  eos_token="<EOS>",
                  start_token="<START_ID>",
                  end_token="<END_ID>",
                  eot_token="<EOT_ID>"):
    """
    Load a tokenizer, register custom special tokens, and configure both a
    pretraining post-processor and an SFT chat template.

    Args:
        model_name: Hugging Face model id (or local path) to load from.
        bos_token: Beginning-of-sequence token (also mapped to CLS).
        eos_token: End-of-sequence token (also reused as PAD).
        start_token: Marks the start of a role header in the chat template.
        end_token: Marks the end of a role header in the chat template.
        eot_token: End-of-turn marker appended after user messages.

    Returns:
        A fast tokenizer with the special tokens added, PAD/CLS remapped,
        a "<BOS> $A <EOS>" single-sequence post-processor installed, and
        ``chat_template`` set for SFT formatting.

    Note:
        New tokens are added here, so any model paired with this tokenizer
        must have its embeddings resized
        (``model.resize_token_embeddings(len(tokenizer))``).
    """

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Define our special tokens, missing in ModernBERT
    special_tokens = {
        "bos_token": bos_token,
        "eos_token": eos_token,
        "additional_special_tokens": [start_token, end_token, eot_token],
    }

    # Add them to tokenizer
    tokenizer.add_special_tokens(special_tokens)

    # Reuse EOS as PAD and BOS as CLS so no extra embedding slots are needed
    tokenizer.pad_token = eos_token
    tokenizer.cls_token = bos_token

    # Template processing for tokenizing pretraining data: every single
    # sequence is wrapped as "<BOS> $A <EOS>".
    # Use the public `backend_tokenizer` accessor rather than the private
    # `_tokenizer` attribute (same underlying object, stable API).
    tokenizer.backend_tokenizer.post_processor = TemplateProcessing(
        single=f"{bos_token} $A {eos_token}",
        special_tokens=[
            (bos_token, tokenizer.bos_token_id),
            (eos_token, tokenizer.eos_token_id)
        ]
    )

    # Chat template for SFT. Rendered layout per message:
    #   <START_ID>role<END_ID>\ncontent<EOT_ID or EOS>
    # User turns end with the EOT marker; assistant turns end with EOS.
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "{{ bos_token if loop.first else '' }}"
        f"{{{{ '{start_token}' + message['role'] + '{end_token}' }}}}\n"
        "{{ message['content'] }}"
        f"{{{{ '{eot_token}' if message['role'] == 'user' else eos_token }}}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        f"{{{{ '{start_token}' + 'assistant' + '{end_token}' }}}}"
        "{% endif %}"
    )

    return tokenizer

if __name__ == "__main__":
    # Smoke tests: pretraining wrapping, SFT chat formatting, and the
    # generation prompt. Requires network access to fetch the tokenizer.
    tok = get_tokenizer()

    print("--------------Pre Train Test-----------------")
    text = "Hello World"
    ids = tok(text, padding=True, return_tensors="pt")["input_ids"][0]
    # Fix: the keyword is `skip_special_tokens`; the previous
    # `skip_special=False` was an unknown kwarg and silently ignored.
    decoded = tok.decode(ids, skip_special_tokens=False)
    print("Text:", text)
    print("ids:", ids)
    print("decoded:", decoded)

    print("\n--------------SFT Test-----------------")
    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help you?"},
    ]
    # Fix: add_special_tokens must be False here — the chat template already
    # emits BOS/EOS, and the TemplateProcessing post-processor would
    # otherwise wrap the rendered conversation in a second BOS/EOS pair.
    encoded = tok.apply_chat_template(messages, tokenize=True, add_special_tokens=False)
    decoded = tok.decode(encoded, skip_special_tokens=False)
    print("Messages:", messages)
    print("Encoded ids:", encoded)
    print("Decoded:", decoded)

    print("\n--------------Generation Prompt Test-----------------")
    messages = [
        {"role": "user", "content": "Hello!"},
    ]

    encoded = tok.apply_chat_template(messages,
                                      tokenize=True,
                                      add_special_tokens=False,
                                      add_generation_prompt=True
                                      )

    decoded = tok.decode(encoded, skip_special_tokens=False)
    print("Messages:", messages)
    print("Encoded ids:", encoded)
    print("Decoded:", decoded)