# TouchGrass-7b / tests / test_chat_formatter.py
# Zandy-Wandy's picture
# Upload 39 files
# 4f0238f verified
"""
Tests for Chat Formatter.
"""
import pytest
import json
from pathlib import Path
from TouchGrass.data.chat_formatter import ChatFormatter, format_chat_qwen, validate_sample
class TestChatFormatter:
    """Test suite for ``ChatFormatter`` and the ``validate_sample`` helper.

    Each test runs against a fresh ``ChatFormatter`` created in
    ``setup_method``. Samples are dicts with a ``"messages"`` list whose
    entries carry ``"role"`` and ``"content"`` keys; formatted samples are
    expected to be dicts exposing the rendered conversation under ``"text"``.
    """

    def setup_method(self) -> None:
        """Build a fresh formatter before every test."""
        self.formatter = ChatFormatter()

    def test_formatter_initialization(self) -> None:
        """The formatter exposes the public methods the other tests rely on."""
        for method_name in (
            "format_sample",
            "format_dataset",
            "save_dataset",
            "create_splits",
        ):
            assert hasattr(self.formatter, method_name)

    def test_format_single_sample(self) -> None:
        """A system/user/assistant sample is rendered into a ``text`` string."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant."},
                {"role": "user", "content": "How do I play a C chord?"},
                {
                    "role": "assistant",
                    "content": "Place your fingers on the 1st, 2nd, and 3rd strings at the 1st fret.",
                },
            ]
        }

        formatted = self.formatter.format_sample(sample)

        assert "text" in formatted
        assert isinstance(formatted["text"], str)
        # Should contain system, user, assistant markers
        text = formatted["text"]
        assert "system" in text
        assert "user" in text
        assert "assistant" in text

    def test_format_sample_without_system(self) -> None:
        """A sample with no system message still formats."""
        sample = {
            "messages": [
                {"role": "user", "content": "What is a scale?"},
                {
                    "role": "assistant",
                    "content": "A scale is a sequence of notes in ascending or descending order.",
                },
            ]
        }

        formatted = self.formatter.format_sample(sample)

        assert "text" in formatted
        # Should still work without system
        assert "user" in formatted["text"]
        assert "assistant" in formatted["text"]

    def test_format_sample_multiple_turns(self) -> None:
        """Every turn of a multi-turn conversation shows up in the output."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Question 1"},
                {"role": "assistant", "content": "Answer 1"},
                {"role": "user", "content": "Follow-up question"},
                {"role": "assistant", "content": "Follow-up answer"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # Should have multiple user/assistant pairs
        assert text.count("user") >= 2
        assert text.count("assistant") >= 2

    def test_validate_sample_valid(self) -> None:
        """A well-formed sample validates with no error."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test system"},
                {"role": "user", "content": "Test user"},
                {"role": "assistant", "content": "Test assistant"},
            ]
        }

        is_valid, error = validate_sample(sample)

        assert is_valid is True
        assert error is None

    def test_validate_sample_missing_role(self) -> None:
        """A message without ``role`` is rejected with an error naming the role."""
        sample = {"messages": [{"content": "Missing role field"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_missing_content(self) -> None:
        """A message without ``content`` is rejected with an error naming it."""
        sample = {"messages": [{"role": "user"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "content" in error.lower()

    def test_validate_sample_invalid_role(self) -> None:
        """An unknown role value is rejected with an error naming the role."""
        sample = {"messages": [{"role": "invalid", "content": "Test"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_empty_messages(self) -> None:
        """An empty ``messages`` list is rejected."""
        sample = {"messages": []}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        reason = error.lower()
        assert "empty" in reason or "message" in reason

    def test_format_dataset(self) -> None:
        """Formatting a list of samples yields one ``text`` entry per sample."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": "System 1"},
                    {"role": "user", "content": "User 1"},
                    {"role": "assistant", "content": "Assistant 1"},
                ]
            },
            {
                "messages": [
                    {"role": "system", "content": "System 2"},
                    {"role": "user", "content": "User 2"},
                    {"role": "assistant", "content": "Assistant 2"},
                ]
            },
        ]

        formatted = self.formatter.format_dataset(dataset)

        assert len(formatted) == 2
        for item in formatted:
            assert "text" in item
            assert isinstance(item["text"], str)

    def test_save_dataset_jsonl(self, tmp_path: Path) -> None:
        """Saving as JSONL writes one JSON object per line."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        output_path = tmp_path / "test_output.jsonl"

        self.formatter.save_dataset(formatted, str(output_path), format="jsonl")

        assert output_path.exists()
        # Verify content
        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 3
        for line in lines:
            data = json.loads(line)
            assert "text" in data

    def test_save_dataset_json(self, tmp_path: Path) -> None:
        """Saving as JSON writes a single list with one entry per sample."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
        ]
        output_path = tmp_path / "test_output.json"

        self.formatter.save_dataset(formatted, str(output_path), format="json")

        assert output_path.exists()
        with open(output_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        assert isinstance(data, list)
        assert len(data) == 2

    def test_create_splits(self) -> None:
        """An 80/20 split partitions the dataset with no shared samples."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]

        train, val = self.formatter.create_splits(dataset, val_size=0.2)

        assert len(train) == 80
        assert len(val) == 20
        # Compare by the (unique) sample text rather than by id(): an identity
        # check is trivially satisfied if the splitter copies its items, and
        # would then miss a sample that leaked into both splits.
        train_texts = {item["text"] for item in train}
        val_texts = {item["text"] for item in val}
        assert train_texts.isdisjoint(val_texts)
        # Together the two splits must account for every input sample.
        assert train_texts | val_texts == {item["text"] for item in dataset}

    def test_create_splits_with_seed(self) -> None:
        """The same seed produces the same split twice."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]

        train1, val1 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)
        train2, val2 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)

        # Should be identical
        assert [d["text"] for d in train1] == [d["text"] for d in train2]
        assert [d["text"] for d in val1] == [d["text"] for d in val2]

    def test_format_preserves_original(self) -> None:
        """Formatting must not mutate the sample that was passed in."""
        import copy  # local import: only this test needs it

        original = {
            "messages": [
                {"role": "user", "content": "Original question"},
                {"role": "assistant", "content": "Original answer"},
            ],
            "category": "test",
        }
        # Deep copy so that in-place edits to the nested message dicts are
        # caught as well, not just added or removed top-level keys.
        snapshot = copy.deepcopy(original)

        self.formatter.format_sample(original)

        # Original should be unchanged
        assert "category" in original
        assert "messages" in original
        assert len(original["messages"]) == 2
        assert original == snapshot

    def test_qwen_format_system_first(self) -> None:
        """The system turn is emitted before the user turn, whatever the input order."""
        sample = {
            "messages": [
                {"role": "user", "content": "User message"},
                {"role": "system", "content": "System message"},
                {"role": "assistant", "content": "Assistant message"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # The message contents are capitalised, so these lowercase searches
        # cannot be satisfied by the content itself.
        system_pos = text.find("system")
        user_pos = text.find("user")
        # str.find returns -1 when the substring is absent; without these
        # guards a missing system marker would satisfy "-1 < user_pos".
        assert system_pos != -1, "system marker missing from formatted text"
        assert user_pos != -1, "user marker missing from formatted text"
        # System should appear before user in the formatted text
        assert system_pos < user_pos

    def test_format_with_special_tokens(self) -> None:
        """Bracketed music tokens in the content survive formatting."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a [GUITAR] assistant."},
                {"role": "user", "content": "How do I play a [CHORD]?"},
                {"role": "assistant", "content": "Use [TAB] notation."},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # Special tokens should be preserved
        for token in ("[GUITAR]", "[CHORD]", "[TAB]"):
            assert token in text

    def test_empty_content_handling(self) -> None:
        """Empty system content yields a well-formed validation result."""
        sample = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": "Valid question"},
                {"role": "assistant", "content": "Valid answer"},
            ]
        }

        is_valid, error = validate_sample(sample)

        # Empty system content might be allowed or not depending on policy, so
        # either verdict is accepted. The result must still follow the contract
        # the other validation tests rely on: a real bool, no error when valid,
        # and a non-empty message when invalid.
        assert isinstance(is_valid, bool)
        if is_valid:
            assert error is None
        else:
            assert isinstance(error, str)
            assert error

    def test_large_dataset_processing(self) -> None:
        """A 500-sample dataset formats to 500 non-empty ``text`` entries."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": f"System {i}"},
                    {"role": "user", "content": f"Question {i}"},
                    {"role": "assistant", "content": f"Answer {i}"},
                ]
            }
            for i in range(500)
        ]

        formatted = self.formatter.format_dataset(dataset)

        assert len(formatted) == 500
        for item in formatted:
            assert "text" in item
            assert len(item["text"]) > 0

    def test_format_consistency(self) -> None:
        """Formatting the same sample twice gives the same text."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test"},
                {"role": "user", "content": "Question"},
                {"role": "assistant", "content": "Answer"},
            ]
        }

        formatted1 = self.formatter.format_sample(sample)
        formatted2 = self.formatter.format_sample(sample)

        assert formatted1["text"] == formatted2["text"]

    def test_unicode_handling(self) -> None:
        """Emoji and accented characters pass through formatting unchanged."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant. 🎵"},
                {"role": "user", "content": "Café au lait? 🎸"},
                {"role": "assistant", "content": "That's a great question! 🎹"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        for fragment in ("🎵", "🎸", "🎹", "Café"):
            assert fragment in text
if __name__ == "__main__":
    # pytest.main returns the run's exit code; hand it to the interpreter so
    # that running this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))