# Hanrui / SpecForge / tests / test_data / test_parsers.py
# Lekr0: "Add files using upload-large-folder tool" (commit 7a60a87, verified)
import json
import os
import unittest
from typing import Any, Dict, List, Optional
from transformers import AutoTokenizer
from specforge.data.preprocessing import preprocess_conversations
from specforge.data.template import TEMPLATE_REGISTRY
class TestTemplatePreprocessing(unittest.TestCase):
    """Regression tests for chat-template preprocessing.

    Each test loads a model tokenizer, looks up a chat template in
    ``TEMPLATE_REGISTRY``, runs ``preprocess_conversations`` on a fixed
    conversation set and compares the resulting ``input_ids`` / ``loss_mask``
    against a JSON reference file stored in ``REF_DIR``.
    """

    # Configuration section
    # Set to True to (re)write the reference JSON files instead of comparing
    # against them.
    SAVE_REFERENCE = False
    # Directory holding the "<template_key>_<message_label>_ref.json" files.
    REF_DIR = os.path.join(os.path.dirname(__file__), "test_references")

    @classmethod
    def setUpClass(cls):
        """Initialize standard test data"""
        cls.max_length = 65535
        # exist_ok avoids the check-then-create race when several test
        # processes start at the same time.
        os.makedirs(cls.REF_DIR, exist_ok=True)
        # 1. General model test data (Qwen, DeepSeek, etc.)
        cls.standard_messages = [
            [
                {"role": "user", "content": "Who are you?"},
                {"role": "assistant", "content": "My name is Qwen2."},
                {"role": "user", "content": "How old are you?"},
                {"role": "assistant", "content": "11 years old."},
            ]
        ]
        # 2. GPT-OSS Dedicated Test Data (Including Analysis and Final Channel)
        cls.gpt_oss_messages = [
            [
                {"role": "user", "content": "Explain Quantum Physics."},
                {
                    "role": "assistant_analysis",
                    "content": "The user wants a summary of quantum physics. I should cover wave-particle duality and uncertainty principle.",
                },
                {
                    "role": "assistant_final",
                    "content": "Quantum physics is the study of matter and energy at the most fundamental level...",
                },
                {"role": "user", "content": "Explain Quantum Physics."},
                {"role": "assistant_final", "content": "I'm Qwen"},
            ]
        ]
        # 3. Tool-use conversation
        cls.tool_use_messages = [
            [
                # First turn: User asks about weather
                {"role": "user", "content": "我想知道今天北京和上海的天气怎么样?"},
                # Assistant thinks and decides to call tools
                {
                    "role": "assistant",
                    "content": "我来帮您查询北京和上海的天气情况。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "北京", "date": "today"},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "上海", "date": "today"},
                            },
                        },
                    ],
                },
                # Tool responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "temperature": 25, "condition": "晴朗", "humidity": "45%"}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "temperature": 28, "condition": "多云", "humidity": "65%"}',
                },
                # Assistant summarizes with reasoning
                {
                    "role": "assistant",
                    "content": "根据查询结果,北京今天晴朗,25°C;上海多云,28°C。两地都比较适合出行。",
                },
            ]
        ]
        # 4. Reasoning multi-turn conversation
        cls.reasoning_multi_turn_messages = [
            [
                {
                    "role": "user",
                    "content": "Can you recommend a good restaurant in Shanghai?",
                },
                {
                    "role": "assistant",
                    "content": "Sure! I think I can help with that.",
                    "reasoning_content": "If a user is looking for a restaurant in Shanghai, they can go to the Peace Hotel.",
                },
                {
                    "role": "user",
                    "content": "Where is the Peace Hotel?",
                },
                {
                    "role": "assistant",
                    "content": "The Peace Hotel is located at the intersection of Nanjing East Road and the Bund.",
                    "reasoning_content": "Let me think. The Peace Hotel is located at the intersection of Nanjing East Road and the Bund.",
                },
            ]
        ]
        # 5. Complete multi-turn conversation with reasoning, tool_calls, and tool responses
        cls.complete_reasoning_tool_conversation = [
            [
                # First turn: User asks about weather
                {"role": "user", "content": "我想知道今天北京和上海的天气怎么样?"},
                # Assistant thinks and decides to call tools
                {
                    "role": "assistant",
                    "content": "我来帮您查询北京和上海的天气情况。",
                    "reasoning_content": "用户想知道两个城市的天气:北京和上海。我需要分别调用 get_weather 工具两次,一次查询北京,一次查询上海。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "北京", "date": "today"},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": {"location": "上海", "date": "today"},
                            },
                        },
                    ],
                },
                # Tool responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "temperature": 25, "condition": "晴朗", "humidity": "45%"}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "temperature": 28, "condition": "多云", "humidity": "65%"}',
                },
                # Assistant summarizes with reasoning
                {
                    "role": "assistant",
                    "content": "根据查询结果,北京今天晴朗,25°C;上海多云,28°C。两地都比较适合出行。",
                    "reasoning_content": "我已经获取了两个城市的天气数据。北京天气更好,晴朗且温度适宜;上海稍微热一些且多云。我可以给用户一个简洁的总结。",
                },
                # Second turn: User asks follow-up question
                {"role": "user", "content": "那明天呢?会下雨吗?"},
                # Assistant checks forecast
                {
                    "role": "assistant",
                    "content": "让我查询一下明天的天气预报。",
                    "reasoning_content": "用户想知道明天是否会下雨,我需要查询两个城市的天气预报。",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather_forecast",
                                "arguments": {"location": "北京", "days": 1},
                            },
                        },
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather_forecast",
                                "arguments": {"location": "上海", "days": 1},
                            },
                        },
                    ],
                },
                # Tool forecast responses
                {
                    "role": "tool",
                    "content": '{"location": "北京", "tomorrow": {"condition": "小雨", "temperature": 22, "rain_probability": 70}}',
                },
                {
                    "role": "tool",
                    "content": '{"location": "上海", "tomorrow": {"condition": "晴", "temperature": 29, "rain_probability": 10}}',
                },
                # Final assistant response
                {
                    "role": "assistant",
                    "content": "明天北京有小雨,记得带伞;上海晴天,适合外出。",
                    "reasoning_content": "北京明天有70%概率下雨,需要提醒用户带伞;上海天气很好,不需要特别准备。",
                },
            ]
        ]

    def _get_ref_path(self, template_key: str, message_label: str = "standard"):
        """Return the reference JSON path for a template / message-set pair."""
        return os.path.join(self.REF_DIR, f"{template_key}_{message_label}_ref.json")

    def _infer_message_label(self, messages: List[List[Dict[str, Any]]]) -> str:
        """Map one of the built-in message sets to its reference-file label.

        Matching is by value (``==``), so an equal copy of a built-in set
        resolves to the same label. Sets are checked in the order listed.

        Raises:
            ValueError: if ``messages`` equals none of the built-in sets.
        """
        known_sets = (
            (self.standard_messages, "standard"),
            (self.gpt_oss_messages, "gpt-oss"),
            (self.tool_use_messages, "tool-use"),
            (self.reasoning_multi_turn_messages, "reasoning-multi-turn"),
            (
                self.complete_reasoning_tool_conversation,
                "multi-turn-tool-calls-with-reasoning",
            ),
        )
        for candidate, label in known_sets:
            if messages == candidate:
                return label
        raise ValueError("Invalid message set")

    def _run_template_test(
        self,
        model_name: str,
        template_key: str,
        messages: Optional[List[List[Dict[str, Any]]]] = None,
        message_label: Optional[str] = None,
    ):
        """Encapsulate common test and regression validation logic

        Args:
            model_name: model id passed to ``AutoTokenizer.from_pretrained``.
            template_key: key looked up in ``TEMPLATE_REGISTRY``.
            messages: batch of conversations; defaults to
                ``self.standard_messages``.
            message_label: label used in the reference file name. When
                omitted it is inferred from ``messages``, which must then be
                one of the built-in message sets.

        Raises:
            ValueError: if no label is given and ``messages`` is not a
                built-in message set.
        """
        # Use the input message or the default standard message.
        target_messages = messages if messages is not None else self.standard_messages
        if message_label is None:
            message_label = self._infer_message_label(target_messages)
        print(f"\n>>> Running: {template_key} ({model_name}) {message_label}")

        # 1. Initialize tokenizer and template
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        chat_template = TEMPLATE_REGISTRY.get(template_key)

        # 2. Preprocess conversations
        res = preprocess_conversations(
            tokenizer, target_messages, chat_template, self.max_length
        )
        # Extract current result. Only res[...][0][0] is compared; presumably
        # that is the first conversation's token row — TODO confirm against
        # preprocess_conversations' return format.
        current_data = {
            "input_ids": res["input_ids"][0][0].tolist(),
            "loss_mask": res["loss_mask"][0][0].tolist(),
        }
        ref_path = self._get_ref_path(template_key, message_label)

        # Render the loss mask before comparing, so the visualization is also
        # available when one of the regression assertions below fails.
        self.debug_show_loss_mask(res, tokenizer)

        # 3. Branch logic: update reference or perform comparison
        if self.SAVE_REFERENCE:
            with open(ref_path, "w", encoding="utf-8") as f:
                json.dump(current_data, f)
            print(f" [INFO] Reference saved to {ref_path}")
            return

        if not os.path.exists(ref_path):
            self.fail(
                f"Reference file not found for {template_key}: {ref_path}. "
                "Set SAVE_REFERENCE=True."
            )
        with open(ref_path, "r", encoding="utf-8") as f:
            ref_data = json.load(f)
        self.assertListEqual(current_data["input_ids"], ref_data["input_ids"])
        self.assertListEqual(current_data["loss_mask"], ref_data["loss_mask"])
        print(f" [PASS] Regression test passed for {template_key}")

    @staticmethod
    def debug_show_loss_mask(res: Dict[str, Any], tokenizer: AutoTokenizer):
        """Print the decoded tokens of res[...][0][0], highlighting in red
        (ANSI escape codes) every token whose loss_mask value is 1."""
        input_ids = res["input_ids"][0][0].tolist()
        loss_mask = res["loss_mask"][0][0].tolist()
        RED, RESET = "\033[91m", "\033[0m"
        print("-" * 30)
        for tid, m in zip(input_ids, loss_mask):
            txt = tokenizer.decode([tid])
            # Escape newlines so the whole sequence prints on one line.
            txt = txt.replace("\n", "\\n")
            print(f"{RED if m == 1 else ''}{txt}{RESET}", end="")
        print("\n" + "-" * 30)

    ## The Following are tests. Each test corresponds to a specific model and template.
    def test_deepseek(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3", "deepseek-v3")

    def test_deepseek_v32(self):
        self._run_template_test("deepseek-ai/DeepSeek-V3.2", "deepseek-v32")

    def test_qwen3_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B",
            "qwen3-thinking",
            messages=self.reasoning_multi_turn_messages,
        )

    def test_qwen3_instruct(self):
        self._run_template_test("Qwen/Qwen3-0.6B", "qwen3-instruct")

    def test_qwen3_next_instruct(self):
        self._run_template_test("Qwen/Qwen3-Next-80B-A3B-Instruct", "qwen")

    def test_kimi_k2_thinking(self):
        self._run_template_test(
            "moonshotai/Kimi-K2-Thinking",
            "kimi-k2-thinking",
            messages=self.reasoning_multi_turn_messages,
        )

    def test_kimi_k2_instruct(self):
        self._run_template_test("moonshotai/Kimi-K2-Instruct", "kimi-k2-instruct")

    def test_qwen3_next_thinking(self):
        self._run_template_test(
            "Qwen/Qwen3-Next-80B-A3B-Thinking",
            "qwen3-next-thinking",
            messages=self.complete_reasoning_tool_conversation,
        )

    def test_gpt_oss(self):
        self._run_template_test(
            "openai/gpt-oss-120b", "gpt-oss", messages=self.gpt_oss_messages
        )

    def test_ling_flash_2_0(self):
        self._run_template_test("inclusionAI/Ling-flash-2.0", "ling-flash-2.0")

    def test_qwen3_instruct_with_tools(self):
        self._run_template_test(
            "Qwen/Qwen3-0.6B",
            "qwen3-instruct",
            messages=self.tool_use_messages,
        )

    def test_qwen35_instruct(self):
        self._run_template_test(
            "Qwen/Qwen3.5-35B-A3B",
            "qwen3.5",
            messages=self.complete_reasoning_tool_conversation,
        )
# Entry point when the file is executed directly rather than collected by a
# test runner: discover and run the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main(module="__main__")