""" Tests for Chat Formatter. """ import pytest import json from pathlib import Path from TouchGrass.data.chat_formatter import ChatFormatter, format_chat_qwen, validate_sample class TestChatFormatter: """Test suite for ChatFormatter.""" def setup_method(self): """Set up test fixtures.""" self.formatter = ChatFormatter() def test_formatter_initialization(self): """Test that formatter initializes correctly.""" assert hasattr(self.formatter, "format_sample") assert hasattr(self.formatter, "format_dataset") assert hasattr(self.formatter, "save_dataset") assert hasattr(self.formatter, "create_splits") def test_format_single_sample(self): """Test formatting a single valid sample.""" sample = { "messages": [ {"role": "system", "content": "You are a music assistant."}, {"role": "user", "content": "How do I play a C chord?"}, {"role": "assistant", "content": "Place your fingers on the 1st, 2nd, and 3rd strings at the 1st fret."} ] } formatted = self.formatter.format_sample(sample) assert "text" in formatted assert isinstance(formatted["text"], str) # Should contain system, user, assistant markers text = formatted["text"] assert "system" in text assert "user" in text assert "assistant" in text def test_format_sample_without_system(self): """Test formatting a sample without system message.""" sample = { "messages": [ {"role": "user", "content": "What is a scale?"}, {"role": "assistant", "content": "A scale is a sequence of notes in ascending or descending order."} ] } formatted = self.formatter.format_sample(sample) assert "text" in formatted # Should still work without system assert "user" in formatted["text"] assert "assistant" in formatted["text"] def test_format_sample_multiple_turns(self): """Test formatting a sample with multiple conversation turns.""" sample = { "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Question 1"}, {"role": "assistant", "content": "Answer 1"}, {"role": "user", "content": "Follow-up question"}, {"role": "assistant", "content": "Follow-up answer"} ] } formatted = self.formatter.format_sample(sample) text = formatted["text"] # Should have multiple user/assistant pairs assert text.count("user") >= 2 assert text.count("assistant") >= 2 def test_validate_sample_valid(self): """Test sample validation with valid sample.""" sample = { "messages": [ {"role": "system", "content": "Test system"}, {"role": "user", "content": "Test user"}, {"role": "assistant", "content": "Test assistant"} ] } is_valid, error = validate_sample(sample) assert is_valid is True assert error is None def test_validate_sample_missing_role(self): """Test sample validation with missing role.""" sample = { "messages": [ {"content": "Missing role field"}, ] } is_valid, error = validate_sample(sample) assert is_valid is False assert "role" in error.lower() def test_validate_sample_missing_content(self): """Test sample validation with missing content.""" sample = { "messages": [ {"role": "user"}, ] } is_valid, error = validate_sample(sample) assert is_valid is False assert "content" in error.lower() def test_validate_sample_invalid_role(self): """Test sample validation with invalid role.""" sample = { "messages": [ {"role": "invalid", "content": "Test"} ] } is_valid, error = validate_sample(sample) assert is_valid is False assert "role" in error.lower() def test_validate_sample_empty_messages(self): """Test sample validation with empty messages list.""" sample = {"messages": []} is_valid, error = validate_sample(sample) assert is_valid is False assert "empty" in error.lower() or "message" in error.lower() def test_format_dataset(self): """Test formatting a full dataset.""" dataset = [ { "messages": [ {"role": "system", "content": "System 1"}, {"role": "user", "content": "User 1"}, {"role": "assistant", "content": "Assistant 1"} ] }, { "messages": [ {"role": "system", "content": "System 2"}, {"role": "user", "content": "User 2"}, {"role": "assistant", "content": "Assistant 2"} ] } ] formatted = self.formatter.format_dataset(dataset) assert len(formatted) == 2 for item in formatted: assert "text" in item assert isinstance(item["text"], str) def test_save_dataset_jsonl(self, tmp_path): """Test saving formatted dataset as JSONL.""" formatted = [ {"text": "Sample 1"}, {"text": "Sample 2"}, {"text": "Sample 3"} ] output_path = tmp_path / "test_output.jsonl" self.formatter.save_dataset(formatted, str(output_path), format="jsonl") assert output_path.exists() # Verify content with open(output_path, 'r', encoding='utf-8') as f: lines = f.readlines() assert len(lines) == 3 for line in lines: data = json.loads(line) assert "text" in data def test_save_dataset_json(self, tmp_path): """Test saving formatted dataset as JSON.""" formatted = [ {"text": "Sample 1"}, {"text": "Sample 2"} ] output_path = tmp_path / "test_output.json" self.formatter.save_dataset(formatted, str(output_path), format="json") assert output_path.exists() with open(output_path, 'r', encoding='utf-8') as f: data = json.load(f) assert isinstance(data, list) assert len(data) == 2 def test_create_splits(self): """Test train/val split creation.""" dataset = [{"text": f"Sample {i}"} for i in range(100)] train, val = self.formatter.create_splits(dataset, val_size=0.2) assert len(train) == 80 assert len(val) == 20 # Check no overlap train_ids = [id(d) for d in train] val_ids = [id(d) for d in val] assert len(set(train_ids) & set(val_ids)) == 0 def test_create_splits_with_seed(self): """Test that splits are reproducible with seed.""" dataset = [{"text": f"Sample {i}"} for i in range(100)] train1, val1 = self.formatter.create_splits(dataset, val_size=0.2, seed=42) train2, val2 = self.formatter.create_splits(dataset, val_size=0.2, seed=42) # Should be identical assert [d["text"] for d in train1] == [d["text"] for d in train2] assert [d["text"] for d in val1] == [d["text"] for d in val2] def test_format_preserves_original(self): """Test that formatting doesn't modify original samples.""" original = { "messages": [ {"role": "user", "content": "Original question"}, {"role": "assistant", "content": "Original answer"} ], "category": "test" } formatted = self.formatter.format_sample(original) # Original should be unchanged assert "category" in original assert "messages" in original assert len(original["messages"]) == 2 def test_qwen_format_system_first(self): """Test that Qwen format places system message first.""" sample = { "messages": [ {"role": "user", "content": "User message"}, {"role": "system", "content": "System message"}, {"role": "assistant", "content": "Assistant message"} ] } formatted = self.formatter.format_sample(sample) text = formatted["text"] # System should appear before user in the formatted text system_pos = text.find("system") user_pos = text.find("user") assert system_pos < user_pos def test_format_with_special_tokens(self): """Test formatting with special music tokens.""" sample = { "messages": [ {"role": "system", "content": "You are a [GUITAR] assistant."}, {"role": "user", "content": "How do I play a [CHORD]?"}, {"role": "assistant", "content": "Use [TAB] notation."} ] } formatted = self.formatter.format_sample(sample) text = formatted["text"] # Special tokens should be preserved assert "[GUITAR]" in text assert "[CHORD]" in text assert "[TAB]" in text def test_empty_content_handling(self): """Test handling of empty message content.""" sample = { "messages": [ {"role": "system", "content": ""}, {"role": "user", "content": "Valid question"}, {"role": "assistant", "content": "Valid answer"} ] } is_valid, error = validate_sample(sample) # Empty system content might be allowed or not depending on policy # Here we just check it's handled assert is_valid in [True, False] def test_large_dataset_processing(self): """Test processing a larger dataset.""" dataset = [ { "messages": [ {"role": "system", "content": f"System {i}"}, {"role": "user", "content": f"Question {i}"}, {"role": "assistant", "content": f"Answer {i}"} ] } for i in range(500) ] formatted = self.formatter.format_dataset(dataset) assert len(formatted) == 500 for item in formatted: assert "text" in item assert len(item["text"]) > 0 def test_format_consistency(self): """Test that same input produces same output.""" sample = { "messages": [ {"role": "system", "content": "Test"}, {"role": "user", "content": "Question"}, {"role": "assistant", "content": "Answer"} ] } formatted1 = self.formatter.format_sample(sample) formatted2 = self.formatter.format_sample(sample) assert formatted1["text"] == formatted2["text"] def test_unicode_handling(self): """Test handling of unicode characters.""" sample = { "messages": [ {"role": "system", "content": "You are a music assistant. 🎵"}, {"role": "user", "content": "Café au lait? 🎸"}, {"role": "assistant", "content": "That's a great question! 🎹"} ] } formatted = self.formatter.format_sample(sample) assert "🎵" in formatted["text"] assert "🎸" in formatted["text"] assert "🎹" in formatted["text"] assert "Café" in formatted["text"] if __name__ == "__main__": pytest.main([__file__, "-v"])