"""Unit tests for validation utilities."""

import pytest

from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.validation import (
    sanitize_text,
    validate_text_input,
    validate_model_name,
    validate_generation_params,
)


class TestSanitizeText:
    """Tests for text sanitization."""

    def test_sanitize_removes_null_bytes(self):
        """Test that null bytes are removed."""
        text = "Hello\x00World"
        result = sanitize_text(text)
        assert "\x00" not in result

    def test_sanitize_normalizes_whitespace(self):
        """Test that runs of whitespace are normalized."""
        # The input deliberately contains a run of several spaces; after
        # normalization no two consecutive spaces should remain. (Asserting
        # that *no* single space remains would contradict "normalize".)
        text = "Hello    World\n\nTest"
        result = sanitize_text(text)
        assert "  " not in result

    def test_sanitize_empty_string(self):
        """Test sanitization of empty string."""
        assert sanitize_text("") == ""


class TestValidateTextInput:
    """Tests for text input validation."""

    def test_valid_text(self):
        """Test validation of valid text."""
        text = "This is a valid text input."
        result = validate_text_input(text)
        assert result == text.strip()

    def test_text_too_short(self):
        """Test validation fails for text below minimum length."""
        with pytest.raises(ValidationError) as exc:
            validate_text_input("", min_length=1)
        # Inspect the exception only after the ``with`` block: any statement
        # placed inside it after the raising call would never execute.
        assert "at least" in exc.value.message

    def test_text_too_long(self):
        """Test validation fails for text exceeding maximum length."""
        long_text = "a" * 10001
        with pytest.raises(ValidationError) as exc:
            validate_text_input(long_text, max_length=10000)
        assert "exceeds maximum" in exc.value.message

    def test_non_string_input(self):
        """Test validation fails for non-string input."""
        with pytest.raises(ValidationError) as exc:
            validate_text_input(123)
        assert "must be a string" in exc.value.message


class TestValidateModelName:
    """Tests for model name validation."""

    @pytest.mark.parametrize(
        "model_name",
        [
            "distilgpt2",
            "gpt2-medium",
            "organization/model-name",
        ],
    )
    def test_valid_model_name(self, model_name):
        """Test that valid model names are returned unchanged."""
        assert validate_model_name(model_name) == model_name

    def test_empty_model_name(self):
        """Test validation fails for empty model name."""
        with pytest.raises(ValidationError):
            validate_model_name("")

    def test_path_traversal_attempt(self):
        """Test validation fails for path traversal attempts."""
        with pytest.raises(ValidationError):
            validate_model_name("../etc/passwd")


class TestValidateGenerationParams:
    """Tests for generation parameter validation."""

    def test_valid_params(self):
        """Test validation of valid parameters."""
        result = validate_generation_params(100, 1, 1.0)
        assert result["max_length"] == 100
        assert result["num_sequences"] == 1
        assert result["temperature"] == 1.0

    def test_invalid_max_length(self):
        """Test validation fails for invalid max_length."""
        with pytest.raises(ValidationError):
            validate_generation_params(0, 1, 1.0)

    def test_invalid_num_sequences(self):
        """Test validation fails for too many sequences."""
        with pytest.raises(ValidationError):
            validate_generation_params(100, 10, 1.0)