# Source: WritingStudio/tests/unit/test_validation.py
# (uploaded by jmisak as part of "Upload 41 files", commit aeb3f7c)
"""Unit tests for validation utilities."""
import pytest
from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.validation import (
sanitize_text,
validate_text_input,
validate_model_name,
validate_generation_params,
)
class TestSanitizeText:
    """Tests for text sanitization via ``sanitize_text``."""

    def test_sanitize_removes_null_bytes(self):
        """Null bytes must not survive sanitization."""
        text = "Hello\x00World"
        result = sanitize_text(text)
        assert "\x00" not in result

    def test_sanitize_normalizes_whitespace(self):
        """Runs of spaces are collapsed so no double space remains."""
        # The input must really contain a run of spaces; otherwise the
        # assertion cannot tell a normalizing sanitizer from a no-op.
        text = "Hello    World\n\n\nTest"
        result = sanitize_text(text)
        assert "  " not in result
        # Guard against a degenerate sanitizer that satisfies the check
        # above by discarding content.
        for word in ("Hello", "World", "Test"):
            assert word in result

    def test_sanitize_empty_string(self):
        """Sanitizing the empty string yields the empty string."""
        assert sanitize_text("") == ""
class TestValidateTextInput:
    """Behavioral checks for ``validate_text_input``."""

    def test_valid_text(self):
        """Acceptable text comes back equal to its stripped form."""
        sample = "This is a valid text input."
        assert validate_text_input(sample) == sample.strip()

    def test_text_too_short(self):
        """Text shorter than ``min_length`` raises ``ValidationError``."""
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input("", min_length=1)
        assert "at least" in excinfo.value.message

    def test_text_too_long(self):
        """Text one character past ``max_length`` raises ``ValidationError``."""
        limit = 10000
        oversized = "a" * (limit + 1)
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input(oversized, max_length=limit)
        assert "exceeds maximum" in excinfo.value.message

    def test_non_string_input(self):
        """A non-``str`` argument raises ``ValidationError``."""
        with pytest.raises(ValidationError) as excinfo:
            validate_text_input(123)
        assert "must be a string" in excinfo.value.message
class TestValidateModelName:
    """Behavioral checks for ``validate_model_name``."""

    def test_valid_model_name(self):
        """Plain, hyphenated, and org-qualified names are returned unchanged."""
        for name in ("distilgpt2", "gpt2-medium", "organization/model-name"):
            assert validate_model_name(name) == name

    def test_empty_model_name(self):
        """An empty name raises ``ValidationError``."""
        with pytest.raises(ValidationError):
            validate_model_name("")

    def test_path_traversal_attempt(self):
        """A path-traversal style name raises ``ValidationError``."""
        with pytest.raises(ValidationError):
            validate_model_name("../etc/passwd")
class TestValidateGenerationParams:
    """Behavioral checks for ``validate_generation_params``."""

    def test_valid_params(self):
        """In-range arguments are echoed back under their named keys."""
        params = validate_generation_params(100, 1, 1.0)
        expected = {"max_length": 100, "num_sequences": 1, "temperature": 1.0}
        for key, value in expected.items():
            assert params[key] == value

    def test_invalid_max_length(self):
        """A ``max_length`` of zero raises ``ValidationError``."""
        with pytest.raises(ValidationError):
            validate_generation_params(0, 1, 1.0)

    def test_invalid_num_sequences(self):
        """Asking for 10 sequences raises ``ValidationError``."""
        with pytest.raises(ValidationError):
            validate_generation_params(100, 10, 1.0)