"""
Tests for Chat Formatter.
"""

import copy
import json
from pathlib import Path

import pytest

from TouchGrass.data.chat_formatter import ChatFormatter, format_chat_qwen, validate_sample


class TestChatFormatter:
    """Test suite for ChatFormatter and the validate_sample helper."""

    def setup_method(self):
        """Create a fresh ChatFormatter before each test method."""
        self.formatter = ChatFormatter()

    def test_formatter_initialization(self):
        """Test that the formatter exposes the methods the other tests call."""
        expected_methods = (
            "format_sample",
            "format_dataset",
            "save_dataset",
            "create_splits",
        )
        for method_name in expected_methods:
            assert hasattr(self.formatter, method_name), f"missing {method_name}"

    def test_format_single_sample(self):
        """Test formatting a single valid sample."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant."},
                {"role": "user", "content": "How do I play a C chord?"},
                {
                    "role": "assistant",
                    "content": "Place your fingers on the 1st, 2nd, and 3rd strings at the 1st fret.",
                },
            ]
        }
        formatted = self.formatter.format_sample(sample)
        assert "text" in formatted
        assert isinstance(formatted["text"], str)

        text = formatted["text"]
        assert "system" in text
        assert "user" in text
        assert "assistant" in text

    def test_format_sample_without_system(self):
        """Test formatting a sample without system message."""
        sample = {
            "messages": [
                {"role": "user", "content": "What is a scale?"},
                {
                    "role": "assistant",
                    "content": "A scale is a sequence of notes in ascending or descending order.",
                },
            ]
        }
        formatted = self.formatter.format_sample(sample)
        assert "text" in formatted

        assert "user" in formatted["text"]
        assert "assistant" in formatted["text"]

    def test_format_sample_multiple_turns(self):
        """Test formatting a sample with multiple conversation turns."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Question 1"},
                {"role": "assistant", "content": "Answer 1"},
                {"role": "user", "content": "Follow-up question"},
                {"role": "assistant", "content": "Follow-up answer"},
            ]
        }
        formatted = self.formatter.format_sample(sample)
        text = formatted["text"]

        assert text.count("user") >= 2
        assert text.count("assistant") >= 2

    def test_validate_sample_valid(self):
        """Test sample validation with valid sample."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test system"},
                {"role": "user", "content": "Test user"},
                {"role": "assistant", "content": "Test assistant"},
            ]
        }
        is_valid, error = validate_sample(sample)
        assert is_valid is True
        assert error is None

    def test_validate_sample_missing_role(self):
        """Test sample validation with missing role."""
        sample = {
            "messages": [
                {"content": "Missing role field"},
            ]
        }
        is_valid, error = validate_sample(sample)
        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_missing_content(self):
        """Test sample validation with missing content."""
        sample = {
            "messages": [
                {"role": "user"},
            ]
        }
        is_valid, error = validate_sample(sample)
        assert is_valid is False
        assert "content" in error.lower()

    def test_validate_sample_invalid_role(self):
        """Test sample validation with invalid role."""
        sample = {
            "messages": [
                {"role": "invalid", "content": "Test"},
            ]
        }
        is_valid, error = validate_sample(sample)
        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_empty_messages(self):
        """Test sample validation with empty messages list."""
        sample = {"messages": []}
        is_valid, error = validate_sample(sample)
        assert is_valid is False
        assert "empty" in error.lower() or "message" in error.lower()

    def test_format_dataset(self):
        """Test formatting a full dataset."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": "System 1"},
                    {"role": "user", "content": "User 1"},
                    {"role": "assistant", "content": "Assistant 1"},
                ]
            },
            {
                "messages": [
                    {"role": "system", "content": "System 2"},
                    {"role": "user", "content": "User 2"},
                    {"role": "assistant", "content": "Assistant 2"},
                ]
            },
        ]
        formatted = self.formatter.format_dataset(dataset)
        assert len(formatted) == 2
        for item in formatted:
            assert "text" in item
            assert isinstance(item["text"], str)

    def test_save_dataset_jsonl(self, tmp_path):
        """Test saving formatted dataset as JSONL."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        output_path = tmp_path / "test_output.jsonl"
        self.formatter.save_dataset(formatted, str(output_path), format="jsonl")
        assert output_path.exists()

        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 3
        # Every line must be a standalone JSON document carrying a "text" key.
        for line in lines:
            data = json.loads(line)
            assert "text" in data

    def test_save_dataset_json(self, tmp_path):
        """Test saving formatted dataset as JSON."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
        ]
        output_path = tmp_path / "test_output.json"
        self.formatter.save_dataset(formatted, str(output_path), format="json")
        assert output_path.exists()

        with open(output_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        assert isinstance(data, list)
        assert len(data) == 2

    def test_create_splits(self):
        """Test train/val split creation."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]
        train, val = self.formatter.create_splits(dataset, val_size=0.2)
        assert len(train) == 80
        assert len(val) == 20

        # Compare by content rather than object identity: each "text" value is
        # unique in this dataset, and an identity check would pass trivially if
        # the splitter handed back copies of the samples.
        train_texts = {d["text"] for d in train}
        val_texts = {d["text"] for d in val}
        assert train_texts.isdisjoint(val_texts)

    def test_create_splits_with_seed(self):
        """Test that splits are reproducible with seed."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]
        train1, val1 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)
        train2, val2 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)

        assert [d["text"] for d in train1] == [d["text"] for d in train2]
        assert [d["text"] for d in val1] == [d["text"] for d in val2]

    def test_format_preserves_original(self):
        """Test that formatting doesn't modify original samples."""
        original = {
            "messages": [
                {"role": "user", "content": "Original question"},
                {"role": "assistant", "content": "Original answer"},
            ],
            "category": "test",
        }
        # A deep snapshot catches mutation of nested message dicts as well as
        # added or removed top-level keys.
        snapshot = copy.deepcopy(original)
        self.formatter.format_sample(original)

        assert original == snapshot

    def test_qwen_format_system_first(self):
        """Test that Qwen format places system message first."""
        sample = {
            "messages": [
                {"role": "user", "content": "User message"},
                {"role": "system", "content": "System message"},
                {"role": "assistant", "content": "Assistant message"},
            ]
        }
        formatted = self.formatter.format_sample(sample)
        text = formatted["text"]

        system_pos = text.find("system")
        user_pos = text.find("user")
        # str.find returns -1 for a missing substring, which would make the
        # ordering comparison below pass without a system marker being present.
        assert system_pos != -1
        assert user_pos != -1
        assert system_pos < user_pos

    def test_format_with_special_tokens(self):
        """Test formatting with special music tokens."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a [GUITAR] assistant."},
                {"role": "user", "content": "How do I play a [CHORD]?"},
                {"role": "assistant", "content": "Use [TAB] notation."},
            ]
        }
        formatted = self.formatter.format_sample(sample)
        text = formatted["text"]

        assert "[GUITAR]" in text
        assert "[CHORD]" in text
        assert "[TAB]" in text

    def test_empty_content_handling(self):
        """Test handling of empty message content."""
        sample = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": "Valid question"},
                {"role": "assistant", "content": "Valid answer"},
            ]
        }
        is_valid, error = validate_sample(sample)

        # This test does not decide whether empty content is accepted; it only
        # requires the (is_valid, error) pair to be coherent in the same way the
        # other validation tests expect: no error when valid, a message when not.
        assert isinstance(is_valid, bool)
        if is_valid:
            assert error is None
        else:
            assert isinstance(error, str) and error

    def test_large_dataset_processing(self):
        """Test processing a larger dataset."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": f"System {i}"},
                    {"role": "user", "content": f"Question {i}"},
                    {"role": "assistant", "content": f"Answer {i}"},
                ]
            }
            for i in range(500)
        ]
        formatted = self.formatter.format_dataset(dataset)
        assert len(formatted) == 500
        for item in formatted:
            assert "text" in item
            assert len(item["text"]) > 0

    def test_format_consistency(self):
        """Test that same input produces same output."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test"},
                {"role": "user", "content": "Question"},
                {"role": "assistant", "content": "Answer"},
            ]
        }
        formatted1 = self.formatter.format_sample(sample)
        formatted2 = self.formatter.format_sample(sample)
        assert formatted1["text"] == formatted2["text"]

    def test_unicode_handling(self):
        """Test handling of unicode characters."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant. 🎵"},
                {"role": "user", "content": "Café au lait? 🎸"},
                {"role": "assistant", "content": "That's a great question! 🎹"},
            ]
        }
        formatted = self.formatter.format_sample(sample)
        assert "🎵" in formatted["text"]
        assert "🎸" in formatted["text"]
        assert "🎹" in formatted["text"]
        assert "Café" in formatted["text"]


if __name__ == "__main__":
    # pytest.main returns the exit code of the run; raising SystemExit with it
    # lets the process status reflect test failures when this file is executed
    # directly instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))