File size: 3,478 Bytes
aeb3f7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Unit tests for validation utilities."""

import pytest

from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.validation import (
    sanitize_text,
    validate_text_input,
    validate_model_name,
    validate_generation_params,
)


class TestSanitizeText:
    """Checks on ``sanitize_text``: null bytes, repeated spaces, empty input."""

    def test_sanitize_removes_null_bytes(self):
        """A NUL character embedded in the input must not appear in the output."""
        cleaned = sanitize_text("Hello\x00World")
        assert "\x00" not in cleaned

    def test_sanitize_normalizes_whitespace(self):
        """No run of two consecutive spaces may remain after sanitization."""
        # Input mixes a four-space run with a blank line; only the absence of
        # a double space is asserted here.
        cleaned = sanitize_text("Hello    World\n\nTest")
        assert "  " not in cleaned

    def test_sanitize_empty_string(self):
        """Sanitizing an empty string is expected to yield an empty string."""
        cleaned = sanitize_text("")
        assert cleaned == ""


class TestValidateTextInput:
    """Checks on ``validate_text_input`` for accepted and rejected inputs."""

    def test_valid_text(self):
        """An ordinary sentence comes back equal to its stripped form."""
        sample = "This is a valid text input."
        assert validate_text_input(sample) == sample.strip()

    def test_text_too_short(self):
        """An empty string with ``min_length=1`` raises with an "at least" message."""
        with pytest.raises(ValidationError) as caught:
            validate_text_input("", min_length=1)
        # The message is inspected only after the context manager has exited,
        # which is when the captured exception is available.
        error_message = caught.value.message
        assert "at least" in error_message

    def test_text_too_long(self):
        """Input one character over ``max_length`` raises with an "exceeds maximum" message."""
        limit = 10000
        oversized = "a" * (limit + 1)
        with pytest.raises(ValidationError) as caught:
            validate_text_input(oversized, max_length=limit)
        error_message = caught.value.message
        assert "exceeds maximum" in error_message

    def test_non_string_input(self):
        """An integer argument raises with a "must be a string" message."""
        with pytest.raises(ValidationError) as caught:
            validate_text_input(123)
        error_message = caught.value.message
        assert "must be a string" in error_message


class TestValidateModelName:
    """Checks on ``validate_model_name`` for accepted and rejected names."""

    def test_valid_model_name(self):
        """Each well-formed name is returned unchanged."""
        accepted_names = (
            "distilgpt2",
            "gpt2-medium",
            "organization/model-name",
        )
        for candidate in accepted_names:
            assert validate_model_name(candidate) == candidate

    def test_empty_model_name(self):
        """An empty name raises ``ValidationError``."""
        empty_name = ""
        with pytest.raises(ValidationError):
            validate_model_name(empty_name)

    def test_path_traversal_attempt(self):
        """A name containing a parent-directory segment raises ``ValidationError``."""
        traversal_name = "../etc/passwd"
        with pytest.raises(ValidationError):
            validate_model_name(traversal_name)


class TestValidateGenerationParams:
    """Checks on ``validate_generation_params`` for accepted and rejected values."""

    def test_valid_params(self):
        """Acceptable arguments are echoed back under their parameter keys."""
        validated = validate_generation_params(100, 1, 1.0)
        expected = {
            "max_length": 100,
            "num_sequences": 1,
            "temperature": 1.0,
        }
        for key, value in expected.items():
            assert validated[key] == value

    def test_invalid_max_length(self):
        """A zero as the first argument (max_length) raises ``ValidationError``."""
        bad_args = (0, 1, 1.0)
        with pytest.raises(ValidationError):
            validate_generation_params(*bad_args)

    def test_invalid_num_sequences(self):
        """Requesting 10 sequences (too many) raises ``ValidationError``."""
        bad_args = (100, 10, 1.0)
        with pytest.raises(ValidationError):
            validate_generation_params(*bad_args)