import itertools
import os
import tempfile
import unittest

import torch
from datasets import Dataset
from transformers import AutoTokenizer

from specforge.data.preprocessing import build_eagle3_dataset
from specforge.utils import safe_conversations_generator
|
|
| |
# ANSI escape sequences: switch the terminal foreground to bright red / reset it.
RED = "\033[91m"
RESET = "\033[0m"


def print_with_loss_mask(tokenizer, input_ids, loss_mask, title=""):
    """Print the decoded sequence with ``loss_mask == 1`` spans in red.

    Consecutive tokens that share the same mask value are decoded with a
    single ``tokenizer.decode`` call and printed back to back; spans whose
    mask value equals 1 are wrapped in ``RED`` / ``RESET``.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        input_ids: Tensor of token ids; flattened before use.
        loss_mask: Tensor of per-token mask values; flattened before use.
            Needs at least one value per token in ``input_ids``; surplus
            values are ignored.
        title: Heading printed between the two separator lines.

    Raises:
        IndexError: If ``loss_mask`` has fewer elements than ``input_ids``.
    """
    # Convert once to plain Python lists instead of calling .item() per token.
    token_ids = input_ids.flatten().tolist()
    mask_values = loss_mask.flatten().tolist()
    if len(mask_values) < len(token_ids):
        raise IndexError(
            f"loss_mask has {len(mask_values)} values "
            f"but input_ids has {len(token_ids)} tokens"
        )

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"{title}")
    print(separator)

    # groupby yields one run per stretch of equal mask values, in order.
    # zip stops at the shorter input, i.e. at the last token id.
    runs = itertools.groupby(
        zip(token_ids, mask_values), key=lambda pair: pair[1]
    )
    for mask_value, run in runs:
        text = tokenizer.decode(
            [token_id for token_id, _ in run], skip_special_tokens=False
        )
        print(f"{RED}{text}{RESET}" if mask_value == 1 else text, end="")

    # Terminate the text line; for an empty sequence this is a blank line.
    print()
    print(separator)
|
|
|
|
| |
# Tool definitions in function-calling schema form: each entry has
# "type": "function" and a "function" object carrying a name, a description
# and a JSON-Schema style "parameters" object.
TOOLS = [
    # get_weather(location, unit?) -- only "location" is required.
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit of temperature",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        },
    },
    # search_web(query, num_results?) -- only "query" is required.
    {
        "type": "function",
        "function": {
            "name": "search_web",
            "description": "Search the web for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query",
                    },
                    "num_results": {
                        "type": "integer",
                        "description": "Number of results to return",
                    },
                },
                "required": ["query"],
            },
        },
    },
]
|
|
| |
# Sample multi-turn tool-use conversation (message contents are in Chinese):
# a user question about the weather in Beijing and Shanghai, an assistant turn
# issuing two get_weather tool calls, one tool result per call, and a final
# assistant summary.
# Presumably mirrors the sample stored in data/tool_use_conversation.jsonl --
# TODO confirm.
# NOTE(review): the tool calls pass a "date" argument that the get_weather
# schema in TOOLS does not declare -- confirm this is intentional.
TOOL_USE_CONVERSATION = [
    {"role": "user", "content": "我想知道今天北京和上海的天气怎么样?"},
    {
        "role": "assistant",
        "content": "我来帮您查询北京和上海的天气情况。",
        "tool_calls": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": {"location": "北京", "date": "today"},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": {"location": "上海", "date": "today"},
                },
            },
        ],
    },
    # Tool outputs are JSON-encoded strings, one message per tool call.
    {
        "role": "tool",
        "content": '{"location": "北京", "temperature": 25, "condition": "晴朗", "humidity": "45%"}',
    },
    {
        "role": "tool",
        "content": '{"location": "上海", "temperature": 28, "condition": "多云", "humidity": "65%"}',
    },
    {
        "role": "assistant",
        "content": "根据查询结果,北京今天晴朗,25°C;上海多云,28°C。两地都比较适合出行。",
    },
]
|
|
|
|
class TestBuildEagle3Dataset(unittest.TestCase):
    """Smoke test for ``build_eagle3_dataset`` on a tool-use conversation.

    The sample is read from ``data/tool_use_conversation.jsonl`` next to this
    file and processed with the tokenizer and chat template configured in
    ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer once and share it across the test methods."""
        cls.model_name = "Qwen/Qwen3.5-35B-A3B"
        cls.template_key = "qwen3.5"
        # NOTE(review): loading by model id presumably needs Hugging Face Hub
        # access or a warm local cache -- confirm for offline CI.
        cls.tokenizer = AutoTokenizer.from_pretrained(
            cls.model_name, trust_remote_code=True
        )
        cls.max_length = 65535

    def test_build_eagle3_dataset_basic(self):
        """Build the dataset from one tool-use sample and check its loss mask."""
        data_file = os.path.join(
            os.path.dirname(__file__), "data", "tool_use_conversation.jsonl"
        )
        # Everything that touches the datasets stays inside the context so the
        # temporary cache directory exists for as long as they are in use.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dataset = Dataset.from_generator(
                generator=safe_conversations_generator,
                gen_kwargs={"file_path": data_file},
                cache_dir=tmp_dir,
                keep_in_memory=True,
            )
            result_dataset = build_eagle3_dataset(
                dataset=dataset,
                tokenizer=self.tokenizer,
                chat_template=self.template_key,
                max_length=self.max_length,
                shuffle_seed=42,
                num_proc=1,
                cache_dir=None,
                cache_key=None,
            )

            # The processed dataset must expose the expected columns and keep
            # the single input sample.
            for column in ("input_ids", "loss_mask", "attention_mask"):
                self.assertIn(column, result_dataset.column_names)
            self.assertEqual(len(result_dataset), 1)

            sample = result_dataset[0]
            input_ids = sample["input_ids"].squeeze()
            loss_mask = sample["loss_mask"].squeeze()
            # Fail with a clear message here rather than inside the print
            # helper if the mask does not line up with the tokens.
            self.assertEqual(input_ids.shape, loss_mask.shape)

            # Dump the decoded text for visual inspection of the mask.
            print_with_loss_mask(
                self.tokenizer,
                input_ids,
                loss_mask,
                title="[build_eagle3_dataset] Full text (RED = loss_mask=1):",
            )

            # At least one token must be selected by the loss mask.
            assistant_indices = torch.where(loss_mask == 1)[0]
            self.assertGreater(
                len(assistant_indices), 0, "No assistant tokens found"
            )
|
|
|
|
# Allow running this file directly; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|
|