"""
Tests for Chat Formatter.
"""
import copy
import json
from pathlib import Path

import pytest

from TouchGrass.data.chat_formatter import ChatFormatter, format_chat_qwen, validate_sample
class TestChatFormatter:
    """Test suite for ChatFormatter and the module-level ``validate_sample``.

    A fresh ``ChatFormatter`` is created for every test by ``setup_method``.
    Samples are dicts with a ``"messages"`` list of ``{"role", "content"}``
    dicts; formatted samples are expected to expose a ``"text"`` string.
    """

    def setup_method(self):
        """Set up test fixtures."""
        self.formatter = ChatFormatter()

    def test_formatter_initialization(self):
        """Test that formatter initializes correctly."""
        expected_methods = (
            "format_sample",
            "format_dataset",
            "save_dataset",
            "create_splits",
        )
        for method_name in expected_methods:
            assert hasattr(self.formatter, method_name), method_name

    def test_format_single_sample(self):
        """Test formatting a single valid sample."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant."},
                {"role": "user", "content": "How do I play a C chord?"},
                {
                    "role": "assistant",
                    "content": (
                        "Place your fingers on the 1st, 2nd, and 3rd strings "
                        "at the 1st fret."
                    ),
                },
            ]
        }

        formatted = self.formatter.format_sample(sample)

        assert "text" in formatted
        assert isinstance(formatted["text"], str)
        # Should contain system, user, assistant markers
        text = formatted["text"]
        for marker in ("system", "user", "assistant"):
            assert marker in text

    def test_format_sample_without_system(self):
        """Test formatting a sample without system message."""
        sample = {
            "messages": [
                {"role": "user", "content": "What is a scale?"},
                {
                    "role": "assistant",
                    "content": (
                        "A scale is a sequence of notes in ascending or "
                        "descending order."
                    ),
                },
            ]
        }

        formatted = self.formatter.format_sample(sample)

        assert "text" in formatted
        # Should still work without system
        assert "user" in formatted["text"]
        assert "assistant" in formatted["text"]

    def test_format_sample_multiple_turns(self):
        """Test formatting a sample with multiple conversation turns."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Question 1"},
                {"role": "assistant", "content": "Answer 1"},
                {"role": "user", "content": "Follow-up question"},
                {"role": "assistant", "content": "Follow-up answer"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # Should have multiple user/assistant pairs
        assert text.count("user") >= 2
        assert text.count("assistant") >= 2
        # The system prompt itself contains the word "assistant", so the
        # counts above cannot prove that no turn was dropped; check that the
        # content of every turn made it into the output.
        for message in sample["messages"][1:]:
            assert message["content"] in text

    def test_validate_sample_valid(self):
        """Test sample validation with valid sample."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test system"},
                {"role": "user", "content": "Test user"},
                {"role": "assistant", "content": "Test assistant"},
            ]
        }

        is_valid, error = validate_sample(sample)

        assert is_valid is True
        assert error is None

    def test_validate_sample_missing_role(self):
        """Test sample validation with missing role."""
        sample = {"messages": [{"content": "Missing role field"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_missing_content(self):
        """Test sample validation with missing content."""
        sample = {"messages": [{"role": "user"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "content" in error.lower()

    def test_validate_sample_invalid_role(self):
        """Test sample validation with invalid role."""
        sample = {"messages": [{"role": "invalid", "content": "Test"}]}

        is_valid, error = validate_sample(sample)

        assert is_valid is False
        assert "role" in error.lower()

    def test_validate_sample_empty_messages(self):
        """Test sample validation with empty messages list."""
        is_valid, error = validate_sample({"messages": []})

        assert is_valid is False
        lowered = error.lower()
        assert "empty" in lowered or "message" in lowered

    def test_format_dataset(self):
        """Test formatting a full dataset."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": "System 1"},
                    {"role": "user", "content": "User 1"},
                    {"role": "assistant", "content": "Assistant 1"},
                ]
            },
            {
                "messages": [
                    {"role": "system", "content": "System 2"},
                    {"role": "user", "content": "User 2"},
                    {"role": "assistant", "content": "Assistant 2"},
                ]
            },
        ]

        formatted = self.formatter.format_dataset(dataset)

        assert len(formatted) == 2
        for item in formatted:
            assert "text" in item
            assert isinstance(item["text"], str)

    def test_save_dataset_jsonl(self, tmp_path):
        """Test saving formatted dataset as JSONL."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        output_path = tmp_path / "test_output.jsonl"

        self.formatter.save_dataset(formatted, str(output_path), format="jsonl")

        assert output_path.exists()
        # Verify content: one JSON object per line
        with open(output_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 3
        for line in lines:
            record = json.loads(line)
            assert "text" in record

    def test_save_dataset_json(self, tmp_path):
        """Test saving formatted dataset as JSON."""
        formatted = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
        ]
        output_path = tmp_path / "test_output.json"

        self.formatter.save_dataset(formatted, str(output_path), format="json")

        assert output_path.exists()
        with open(output_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        assert isinstance(data, list)
        assert len(data) == 2

    def test_create_splits(self):
        """Test train/val split creation."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]

        train, val = self.formatter.create_splits(dataset, val_size=0.2)

        assert len(train) == 80
        assert len(val) == 20
        # Check no overlap. Every sample has a unique text, so comparing texts
        # (rather than object identities) still catches leakage even if the
        # splitter returns copies of the input dicts.
        train_texts = {d["text"] for d in train}
        val_texts = {d["text"] for d in val}
        assert train_texts.isdisjoint(val_texts)

    def test_create_splits_with_seed(self):
        """Test that splits are reproducible with seed."""
        dataset = [{"text": f"Sample {i}"} for i in range(100)]

        train1, val1 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)
        train2, val2 = self.formatter.create_splits(dataset, val_size=0.2, seed=42)

        # Should be identical
        assert [d["text"] for d in train1] == [d["text"] for d in train2]
        assert [d["text"] for d in val1] == [d["text"] for d in val2]

    def test_format_preserves_original(self):
        """Test that formatting doesn't modify original samples."""
        original = {
            "messages": [
                {"role": "user", "content": "Original question"},
                {"role": "assistant", "content": "Original answer"},
            ],
            "category": "test",
        }
        # Deep snapshot so in-place edits to nested message dicts are caught,
        # not just added/removed top-level keys.
        snapshot = copy.deepcopy(original)

        self.formatter.format_sample(original)

        # Original should be unchanged
        assert original == snapshot

    def test_qwen_format_system_first(self):
        """Test that Qwen format places system message first."""
        sample = {
            "messages": [
                {"role": "user", "content": "User message"},
                {"role": "system", "content": "System message"},
                {"role": "assistant", "content": "Assistant message"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        system_pos = text.find("system")
        user_pos = text.find("user")
        # str.find returns -1 when the marker is absent, which would make the
        # ordering comparison below pass vacuously; require both markers.
        assert system_pos != -1
        assert user_pos != -1
        # System should appear before user in the formatted text
        assert system_pos < user_pos

    def test_format_with_special_tokens(self):
        """Test formatting with special music tokens."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a [GUITAR] assistant."},
                {"role": "user", "content": "How do I play a [CHORD]?"},
                {"role": "assistant", "content": "Use [TAB] notation."},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # Special tokens should be preserved
        for token in ("[GUITAR]", "[CHORD]", "[TAB]"):
            assert token in text

    def test_empty_content_handling(self):
        """Test handling of empty message content."""
        sample = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": "Valid question"},
                {"role": "assistant", "content": "Valid answer"},
            ]
        }

        is_valid, error = validate_sample(sample)

        # Empty system content might be allowed or not depending on policy,
        # so either verdict is accepted -- but the result must be well-formed
        # and consistent with what the other validation tests expect:
        # no error when valid, a non-empty message when invalid.
        assert isinstance(is_valid, bool)
        if is_valid:
            assert error is None
        else:
            assert isinstance(error, str) and error

    def test_large_dataset_processing(self):
        """Test processing a larger dataset."""
        dataset = [
            {
                "messages": [
                    {"role": "system", "content": f"System {i}"},
                    {"role": "user", "content": f"Question {i}"},
                    {"role": "assistant", "content": f"Answer {i}"},
                ]
            }
            for i in range(500)
        ]

        formatted = self.formatter.format_dataset(dataset)

        assert len(formatted) == 500
        for item in formatted:
            assert "text" in item
            assert len(item["text"]) > 0

    def test_format_consistency(self):
        """Test that same input produces same output."""
        sample = {
            "messages": [
                {"role": "system", "content": "Test"},
                {"role": "user", "content": "Question"},
                {"role": "assistant", "content": "Answer"},
            ]
        }

        formatted1 = self.formatter.format_sample(sample)
        formatted2 = self.formatter.format_sample(sample)

        assert formatted1["text"] == formatted2["text"]

    def test_unicode_handling(self):
        """Test handling of unicode characters."""
        sample = {
            "messages": [
                {"role": "system", "content": "You are a music assistant. 🎵"},
                {"role": "user", "content": "Café au lait? 🎸"},
                {"role": "assistant", "content": "That's a great question! 🎹"},
            ]
        }

        text = self.formatter.format_sample(sample)["text"]

        # Emoji and accented characters must survive formatting
        for fragment in ("🎵", "🎸", "🎹", "Café"):
            assert fragment in text
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file directly reports test failures to the shell / CI instead of
    # always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))