File size: 7,776 Bytes
29658b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import json
import os
import unittest
from typing import Any, Dict, List, Optional

from transformers import AutoTokenizer

from specforge.data.preprocessing import preprocess_conversations
from specforge.data.template import TEMPLATE_REGISTRY


class TestTemplatePreprocessing(unittest.TestCase):
    """Regression tests for chat-template preprocessing.

    Each test loads one model's tokenizer, looks up a chat template in
    ``TEMPLATE_REGISTRY``, runs ``preprocess_conversations`` on a fixed
    conversation and compares the resulting ``input_ids`` / ``loss_mask``
    with a JSON reference file stored under ``REF_DIR``.
    """

    # Configuration section
    # When True, the tests (re)write the reference files instead of comparing
    # against them.
    SAVE_REFERENCE = False
    # Directory holding the JSON reference files, located next to this module.
    REF_DIR = os.path.join(os.path.dirname(__file__), "test_references")

    @classmethod
    def setUpClass(cls):
        """Create the reference directory and build the shared test data."""
        super().setUpClass()

        # Forwarded to ``preprocess_conversations`` as its last argument.
        cls.max_length = 65535
        # ``exist_ok`` avoids the check-then-create race when several test
        # processes start at the same time.
        os.makedirs(cls.REF_DIR, exist_ok=True)

        # 1. General model test data (Qwen, DeepSeek, etc.)
        cls.standard_messages = [
            [
                {"role": "user", "content": "Who are you?"},
                {"role": "assistant", "content": "My name is Qwen2."},
                {"role": "user", "content": "How old are you?"},
                {"role": "assistant", "content": "11 years old."},
            ]
        ]

        # 2. GPT-OSS dedicated test data (including analysis and final channel)
        cls.gpt_oss_messages = [
            [
                {"role": "user", "content": "Explain Quantum Physics."},
                {
                    "role": "assistant_analysis",
                    "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.",
                },
                {
                    "role": "assistant_final",
                    "content": "Quantum physics is the study of matter and energy at the most fundamental level...",
                },
                {"role": "user", "content": "Explain Quantum Physics."},
                {"role": "assistant_final", "content": "I'm Qwen"},
            ]
        ]

        # 3. Tool-use test data
        cls.tool_use_messages = [
            [
                {
                    "role": "user",
                    "content": "What's the weather like in Beijing today?",
                },
                {
                    "role": "assistant",
                    "content": "I'll check the current weather in Beijing for you.",
                },
                {
                    "role": "tool",
                    "content": '{"location": "Beijing", "temperature": 22, "condition": "Sunny"}',
                },
                {
                    "role": "assistant",
                    "content": "The current weather in Beijing is sunny with a temperature of 22°C.",
                },
                {
                    "role": "tool",
                    "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}',
                },
                {
                    "role": "tool",
                    "content": '{"unit": "Celsius", "forecast": "Clear skies all day."}',
                },
                {
                    "role": "user",
                    "content": "Great! Can you also tell me if it will rain tomorrow?",
                },
                {
                    "role": "assistant",
                    "content": "Based on the forecast, there will be no rain tomorrow—expect clear skies all day.",
                },
            ]
        ]

    def _get_ref_path(self, template_key: str, message_label: str = "standard"):
        """Return the reference-file path for a template / message-set pair."""
        return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json")

    def _message_label(self, messages: List[List[Dict[str, str]]]) -> str:
        """Map a message set to the label used in its reference file name.

        Raises:
            ValueError: if ``messages`` equals none of the data sets built in
                ``setUpClass``.
        """
        known_sets = (
            ("standard", self.standard_messages),
            ("gpt-oss", self.gpt_oss_messages),
            ("tool-use", self.tool_use_messages),
        )
        for label, known in known_sets:
            if messages == known:
                return label
        raise ValueError("Invalid message set")

    def _run_template_test(
        self,
        model_name: str,
        template_key: str,
        messages: Optional[List[List[Dict[str, str]]]] = None,
    ):
        """Run preprocessing for one model/template and check it against its reference.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            template_key: Key looked up in ``TEMPLATE_REGISTRY``; also the
                prefix of the reference file name.
            messages: One of the message sets built in ``setUpClass``;
                ``None`` selects ``standard_messages``.

        Raises:
            ValueError: if ``messages`` is not one of the known message sets.
        """
        # Use the input message or the default standard message.
        target_messages = messages if messages is not None else self.standard_messages
        message_label = self._message_label(target_messages)
        print(f"\n>>> Running: {template_key} ({model_name}) {message_label}")

        # 1. Initialize tokenizer and template
        # NOTE(review): trust_remote_code=True allows the model repository to
        # run its own Python code; keep the model list to trusted repositories.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        chat_template = TEMPLATE_REGISTRY.get(template_key)

        # 2. Preprocess conversations
        res = preprocess_conversations(
            tokenizer, target_messages, chat_template, self.max_length
        )
        # Only the [0][0] entry of each field is recorded and compared.
        current_data = {
            "input_ids": res["input_ids"][0][0].tolist(),
            "loss_mask": res["loss_mask"][0][0].tolist(),
        }

        # 3. Debug output. Emitted before any assertion so the mask
        # visualisation is still available when the comparison below fails.
        self.debug_show_loss_mask(res, tokenizer)

        ref_path = self._get_ref_path(template_key, message_label)

        # 4. Branch logic: update reference or perform comparison
        if self.SAVE_REFERENCE:
            with open(ref_path, "w", encoding="utf-8") as f:
                json.dump(current_data, f)
            print(f" [INFO] Reference saved to {ref_path}")
            return

        try:
            with open(ref_path, "r", encoding="utf-8") as f:
                ref_data = json.load(f)
        except FileNotFoundError:
            # Name the label and path: one template key can be used with
            # several message sets, so the key alone is ambiguous.
            self.fail(
                f"Reference file not found for {template_key} ({message_label}): "
                f"{ref_path}. Set SAVE_REFERENCE=True."
            )

        context = f"{template_key} ({message_label})"
        self.assertListEqual(
            current_data["input_ids"],
            ref_data["input_ids"],
            msg=f"input_ids differ from reference for {context}",
        )
        self.assertListEqual(
            current_data["loss_mask"],
            ref_data["loss_mask"],
            msg=f"loss_mask differs from reference for {context}",
        )
        print(f" [PASS] Regression test passed for {template_key}")

    @staticmethod
    def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer):
        """Print the decoded tokens, colouring those whose loss mask is 1 in red.

        Args:
            res: Mapping with ``"input_ids"`` and ``"loss_mask"`` entries; the
                ``[0][0]`` element of each must provide ``tolist()``.
            tokenizer: Object whose ``decode`` turns a list of token ids into
                text.
        """
        input_ids = res["input_ids"][0][0].tolist()
        loss_mask = res["loss_mask"][0][0].tolist()
        red, reset = "\033[91m", "\033[0m"

        pieces = []
        for token_id, mask in zip(input_ids, loss_mask):
            # Escape newlines so the whole sequence stays on one line.
            text = tokenizer.decode([token_id]).replace("\n", "\\n")
            pieces.append(f"{red if mask == 1 else ''}{text}{reset}")

        print("-" * 30)
        print("".join(pieces))
        print("-" * 30)

    # The following are the tests. Each one corresponds to a specific model
    # and template.

    def test_deepseek(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3")

    def test_deepseek_v32(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32")

    def test_qwen3_thinking(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-thinking")

    def test_qwen3_instruct(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct")

    def test_qwen3_next_instruct(self):
        self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen")

    def test_kimi_k2_thinking(self):
        self._run_template_test("moonshotai/Kimi-K2-Thinking", "kimi-k2-thinking")

    def test_kimi_k2_instruct(self):
        self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct")

    def test_qwen3_next_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-Next-80B-A3B-Thinking", "qwen3-next-thinking"
        )

    def test_gpt_oss(self):
        self._run_template_test(
            "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages
        )

    def test_ling_flash_2_0(self):
        self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0")

    def test_qwen3_instruct_with_tools(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B", "qwen3-instruct", messages=self.tool_use_messages
        )


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()