Spaces:
Sleeping
Sleeping
| """Unit tests for validation utilities.""" | |
| import pytest | |
| from writing_studio.core.exceptions import ValidationError | |
| from writing_studio.utils.validation import ( | |
| sanitize_text, | |
| validate_text_input, | |
| validate_model_name, | |
| validate_generation_params, | |
| ) | |
class TestSanitizeText:
    """Tests for text sanitization via ``sanitize_text``."""

    def test_sanitize_removes_null_bytes(self):
        """Test that null bytes are removed while the surrounding text survives."""
        text = "Hello\x00World"
        result = sanitize_text(text)
        assert "\x00" not in result
        # Checking only for absence of "\x00" would also accept an
        # implementation that discards the whole string; make sure the
        # real content is still there.
        assert "Hello" in result
        assert "World" in result

    def test_sanitize_normalizes_whitespace(self):
        """Test that runs of whitespace are collapsed."""
        # The input needs a run of several spaces so there is something to
        # collapse.  The check is for a *double* space: normalised output
        # still separates words with a single space, so asserting that no
        # single space remains would reject correct behaviour.
        text = "Hello    World\n\nTest"
        result = sanitize_text(text)
        assert "  " not in result

    def test_sanitize_empty_string(self):
        """Test that an empty string passes through unchanged."""
        assert sanitize_text("") == ""
class TestValidateTextInput:
    """Behavioural checks for ``validate_text_input``."""

    def test_valid_text(self):
        """Well-formed input is returned stripped of surrounding whitespace."""
        sample = "This is a valid text input."
        assert validate_text_input(sample) == sample.strip()

    def test_text_too_short(self):
        """Input shorter than ``min_length`` is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input("", min_length=1)
        assert "at least" in excinfo.value.message

    def test_text_too_long(self):
        """Input longer than ``max_length`` is rejected."""
        limit = 10000
        oversized = "a" * (limit + 1)
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input(oversized, max_length=limit)
        assert "exceeds maximum" in excinfo.value.message

    def test_non_string_input(self):
        """Non-string input is rejected with a type-related message."""
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input(123)
        assert "must be a string" in excinfo.value.message
class TestValidateModelName:
    """Behavioural checks for ``validate_model_name``."""

    def test_valid_model_name(self):
        """Accepted names, including the ``org/name`` form, come back unchanged."""
        for name in ("distilgpt2", "gpt2-medium", "organization/model-name"):
            assert validate_model_name(name) == name

    def test_empty_model_name(self):
        """An empty model name is rejected."""
        with pytest.raises(ValidationError):
            validate_model_name("")

    def test_path_traversal_attempt(self):
        """A name that tries to climb out via ``..`` is rejected."""
        with pytest.raises(ValidationError):
            validate_model_name("../etc/passwd")
class TestValidateGenerationParams:
    """Behavioural checks for ``validate_generation_params``."""

    def test_valid_params(self):
        """In-range values are returned under their parameter names."""
        params = validate_generation_params(100, 1, 1.0)
        expected = {"max_length": 100, "num_sequences": 1, "temperature": 1.0}
        for key, value in expected.items():
            assert params[key] == value

    def test_invalid_max_length(self):
        """A ``max_length`` of zero is rejected."""
        with pytest.raises(ValidationError):
            validate_generation_params(0, 1, 1.0)

    def test_invalid_num_sequences(self):
        """Requesting too many sequences (10) is rejected."""
        with pytest.raises(ValidationError):
            validate_generation_params(100, 10, 1.0)