File size: 14,309 Bytes
7a60a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import json
import os
import unittest
from typing import Any, Dict, List, Optional

from transformers import AutoTokenizer

from specforge.data.preprocessing import preprocess_conversations
from specforge.data.template import TEMPLATE_REGISTRY


class TestTemplatePreprocessing(unittest.TestCase):
    """Regression tests for chat-template preprocessing.

    Every test loads a tokenizer, looks up a chat template in
    ``TEMPLATE_REGISTRY``, runs ``preprocess_conversations`` on a fixed
    conversation and compares the first sample's ``input_ids`` and
    ``loss_mask`` with a JSON reference file stored under ``REF_DIR``.
    """

    # Configuration section
    # When True, the reference files are (re)written instead of compared.
    SAVE_REFERENCE = False
    # Directory holding the "<template>_<label>_ref.json" reference files.
    REF_DIR = os.path.join(os.path.dirname(__file__), "test_references")

    @classmethod
    def setUpClass(cls):
        """Create the reference directory and build the shared conversations."""
        cls.max_length = 65535
        # exist_ok avoids the check-then-create race when several test
        # processes start at the same time.
        os.makedirs(cls.REF_DIR, exist_ok=True)

        # 1. General model test data (Qwen, DeepSeek, etc.)
        cls.standard_messages = [
            [
                {"role": "user", "content": "Who are you?"},
                {"role": "assistant", "content": "My name is Qwen2."},
                {"role": "user", "content": "How old are you?"},
                {"role": "assistant", "content": "11 years old."},
            ]
        ]

        # 2. GPT-OSS dedicated test data (including analysis and final channels)
        cls.gpt_oss_messages = [
            [
                {"role": "user", "content": "Explain Quantum Physics."},
                {
                    "role": "assistant_analysis",
                    "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.",
                },
                {
                    "role": "assistant_final",
                    "content": "Quantum physics is the study of matter and energy at the most fundamental level...",
                },
                {"role": "user", "content": "Explain Quantum Physics."},
                {"role": "assistant_final", "content": "I'm Qwen"},
            ]
        ]

        # 3. Tool-use conversation
        cls.tool_use_messages = [
            [
                # First turn: user asks about the weather
                {"role": "user", "content": "我想知道今天北京和上海的天气怎么样?"},
                # Assistant decides to call tools
                {
                    "role": "assistant",
                    "content": "我来帮您查询北京和上海的天气情况。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "北京", "date": "today"},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "上海", "date": "today"},
                            },
                        },
                    ],
                },
                # Tool responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "temperature": 25, "condition": "晴朗", "humidity": "45%"}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "temperature": 28, "condition": "多云", "humidity": "65%"}',
                },
                # Assistant summarizes the tool results
                {
                    "role": "assistant",
                    "content": "根据查询结果,北京今天晴朗,25°C;上海多云,28°C。两地都比较适合出行。",
                },
            ]
        ]

        # 4. Reasoning multi-turn conversation
        cls.reasoning_multi_turn_messages = [
            [
                {
                    "role": "user",
                    "content": "Can you recommend a good restaurant in Shanghai?",
                },
                {
                    "role": "assistant",
                    "content": "Sure! I think I can help with that.",
                    "reasoning_content": "If a user is looking for a restaurant in Shanghai, they can go to the Peace Hotel.",
                },
                {
                    "role": "user",
                    "content": "Where is the Peace Hotel?",
                },
                {
                    "role": "assistant",
                    "content": "The Peace Hotel is located at the intersection of Nanjing East Road and the Bund.",
                    "reasoning_content": "Let me think. The Peace Hotel is located at the intersection of Nanjing East Road and the Bund.",
                },
            ]
        ]

        # 5. Complete multi-turn conversation with reasoning, tool_calls, and tool responses
        cls.complete_reasoning_tool_conversation = [
            [
                # First turn: user asks about the weather
                {"role": "user", "content": "我想知道今天北京和上海的天气怎么样?"},
                # Assistant reasons and decides to call tools
                {
                    "role": "assistant",
                    "content": "我来帮您查询北京和上海的天气情况。",
                    "reasoning_content": "用户想知道两个城市的天气:北京和上海。我需要分别调用 get_weather 工具两次,一次查询北京,一次查询上海。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "北京", "date": "today"},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "上海", "date": "today"},
                            },
                        },
                    ],
                },
                # Tool responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "temperature": 25, "condition": "晴朗", "humidity": "45%"}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "temperature": 28, "condition": "多云", "humidity": "65%"}',
                },
                # Assistant summarizes with reasoning
                {
                    "role": "assistant",
                    "content": "根据查询结果,北京今天晴朗,25°C;上海多云,28°C。两地都比较适合出行。",
                    "reasoning_content": "我已经获取了两个城市的天气数据。北京天气更好,晴朗且温度适宜;上海稍微热一些且多云。我可以给用户一个简洁的总结。",
                },
                # Second turn: user asks a follow-up question
                {"role": "user", "content": "那明天呢?会下雨吗?"},
                # Assistant checks the forecast
                {
                    "role": "assistant",
                    "content": "让我查询一下明天的天气预报。",
                    "reasoning_content": "用户想知道明天是否会下雨,我需要查询两个城市的天气预报。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather_forecast",
                                "arguments": {"location": "北京", "days": 1},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather_forecast",
                                "arguments": {"location": "上海", "days": 1},
                            },
                        },
                    ],
                },
                # Tool forecast responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "tomorrow": {"condition": "小雨", "temperature": 22, "rain_probability": 70}}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "tomorrow": {"condition": "晴", "temperature": 29, "rain_probability": 10}}',
                },
                # Final assistant response
                {
                    "role": "assistant",
                    "content": "明天北京有小雨,记得带伞;上海晴天,适合外出。",
                    "reasoning_content": "北京明天有70%概率下雨,需要提醒用户带伞;上海天气很好,不需要特别准备。",
                },
            ]
        ]

    def _get_ref_path(self, template_key: str, message_label: str = "standard"):
        """Return the reference JSON path for a template key and message label."""
        return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json")

    def _resolve_message_label(self, target_messages: List[List[Dict[str, Any]]]) -> str:
        """Map a known message set to the label used in reference file names.

        Candidates are compared by equality, in a fixed order; the first match
        wins.

        Raises:
            ValueError: if ``target_messages`` equals none of the known sets.
        """
        labelled_sets = (
            ("standard", self.standard_messages),
            ("gpt-oss", self.gpt_oss_messages),
            ("tool-use", self.tool_use_messages),
            ("reasoning-multi-turn", self.reasoning_multi_turn_messages),
            (
                "multi-turn-tool-calls-with-reasoning",
                self.complete_reasoning_tool_conversation,
            ),
        )
        for label, known_messages in labelled_sets:
            if target_messages == known_messages:
                return label
        raise ValueError("Invalid message set")

    def _run_template_test(
        self,
        model_name: str,
        template_key: str,
        messages: Optional[List[List[Dict[str, Any]]]] = None,
    ):
        """Run preprocessing for one model/template pair and check the result.

        Args:
            model_name: identifier passed to ``AutoTokenizer.from_pretrained``.
            template_key: key looked up in ``TEMPLATE_REGISTRY``.
            messages: one of the message sets built in ``setUpClass``;
                ``None`` selects ``standard_messages``.

        Raises:
            ValueError: if ``messages`` is not one of the known message sets.
            AssertionError: if the reference file is missing or the current
                ``input_ids`` / ``loss_mask`` differ from it.
        """
        # Use the input message or the default standard message.
        target_messages = messages if messages is not None else self.standard_messages
        message_label = self._resolve_message_label(target_messages)
        print(f"\n>>> Running: {template_key} ({model_name}) {message_label}")

        # 1. Initialize tokenizer and template
        # NOTE(review): trust_remote_code=True lets the tokenizer repository run
        # its own Python code; only use model names from trusted sources.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        chat_template = TEMPLATE_REGISTRY.get(template_key)

        # 2. Preprocess conversations
        res = preprocess_conversations(
            tokenizer, target_messages, chat_template, self.max_length
        )
        # Extract current result (first sample only)
        current_data = {
            "input_ids": res["input_ids"][0][0].tolist(),
            "loss_mask": res["loss_mask"][0][0].tolist(),
        }

        # 3. Debug output. Printed before the comparison so the colored view is
        # also part of the output when an assertion below fails.
        self.debug_show_loss_mask(res, tokenizer)

        ref_path = self._get_ref_path(template_key, message_label)
        # 4. Branch logic: update reference or perform comparison
        if self.SAVE_REFERENCE:
            with open(ref_path, "w", encoding="utf-8") as f:
                json.dump(current_data, f)
            print(f" [INFO] Reference saved to {ref_path}")
            return

        if not os.path.exists(ref_path):
            # Include label and path: the same template key can be used with
            # several message sets, each with its own reference file.
            self.fail(
                f"Reference file not found for {template_key} "
                f"({message_label}): {ref_path}. Set SAVE_REFERENCE=True."
            )

        with open(ref_path, "r", encoding="utf-8") as f:
            ref_data = json.load(f)

        for field in ("input_ids", "loss_mask"):
            self.assertListEqual(
                current_data[field],
                ref_data[field],
                msg=f"{field} mismatch for {template_key} ({message_label})",
            )
        print(f" [PASS] Regression test passed for {template_key}")

    @staticmethod
    def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer):
        """Print the first sample token by token, coloring masked-in tokens red.

        Args:
            res: mapping with ``"input_ids"`` and ``"loss_mask"`` entries; the
                element at ``[0][0]`` of each must provide ``tolist()``.
            tokenizer: object whose ``decode`` turns a list of ids into text.
        """
        input_ids = res["input_ids"][0][0].tolist()
        loss_mask = res["loss_mask"][0][0].tolist()
        RED, RESET = "\033[91m", "\033[0m"
        print("-" * 30)
        for tid, m in zip(input_ids, loss_mask):
            # Escape newlines so the whole sample stays on one output line.
            txt = tokenizer.decode([tid]).replace("\n", "\\n")
            print(f"{RED if m == 1 else ''}{txt}{RESET}", end="")
        print("\n" + "-" * 30)

    # The following are the tests. Each one corresponds to a specific model
    # and template.

    def test_deepseek(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3")

    def test_deepseek_v32(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32")

    def test_qwen3_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B",
            "qwen3-thinking",
            messages=self.reasoning_multi_turn_messages,
        )

    def test_qwen3_instruct(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct")

    def test_qwen3_next_instruct(self):
        self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen")

    def test_kimi_k2_thinking(self):
        self._run_template_test(
            "moonshotai/Kimi-K2-Thinking",
            "kimi-k2-thinking",
            messages=self.reasoning_multi_turn_messages,
        )

    def test_kimi_k2_instruct(self):
        self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct")

    def test_qwen3_next_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "qwen3-next-thinking",
            messages=self.complete_reasoning_tool_conversation,
        )

    def test_gpt_oss(self):
        self._run_template_test(
            "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages
        )

    def test_ling_flash_2_0(self):
        self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0")

    def test_qwen3_instruct_with_tools(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B",
            "qwen3-instruct",
            messages=self.tool_use_messages,
        )

    def test_qwen35_instruct(self):
        self._run_template_test(
            "Qwen/Qwen3.5-35B-A3B",
            "qwen3.5",
            messages=self.complete_reasoning_tool_conversation,
        )


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()