File size: 4,154 Bytes
aeb3f7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Input validation utilities."""

import re
from typing import Optional

from writing_studio.core.config import settings
from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.logging import logger


def sanitize_text(text: str) -> str:
    """
    Clean up raw input text.

    Null bytes are dropped first, then every run of whitespace (line breaks
    included) is collapsed to a single space, and finally the result is
    trimmed at both ends.

    Args:
        text: Raw text to clean; any falsy value yields an empty string

    Returns:
        The cleaned text
    """
    if text:
        # Null bytes go first so whitespace on either side of them merges
        # into a single run during the collapse step.
        without_nulls = text.replace("\x00", "")
        single_spaced = re.sub(r"\s+", " ", without_nulls)
        return single_spaced.strip()
    return ""


def validate_text_input(
    text: str, max_length: Optional[int] = None, min_length: int = 1
) -> str:
    """
    Validate and sanitize text input.

    The text is passed through ``sanitize_text`` first, so both length
    limits are checked against the sanitized form.

    Args:
        text: Input text to validate
        max_length: Maximum allowed length (default: ``settings.max_text_length``).
            Only ``None`` selects the default; an explicit value such as 0
            is honored as given.
        min_length: Minimum allowed length

    Returns:
        Validated and sanitized text

    Raises:
        ValidationError: If ``text`` is not a string, or its sanitized length
            is below ``min_length`` or above the maximum
    """
    if not isinstance(text, str):
        raise ValidationError("Input must be a string", {"type": type(text).__name__})

    text = sanitize_text(text)
    length = len(text)

    # Check minimum length
    if length < min_length:
        raise ValidationError(
            f"Text must be at least {min_length} characters",
            {"length": length, "min_length": min_length},
        )

    # Compare against None instead of using `or`: `or` would silently replace
    # an explicit falsy limit (e.g. 0) with the configured default.
    max_len = settings.max_text_length if max_length is None else max_length
    if length > max_len:
        logger.warning(f"Text exceeds maximum length: {length} > {max_len}")
        raise ValidationError(
            f"Text exceeds maximum length of {max_len} characters",
            {"length": length, "max_length": max_len},
        )

    return text


def validate_model_name(model_name: str) -> str:
    """
    Check that a HuggingFace-style model identifier is well formed.

    Surrounding whitespace is stripped before any check. Accepted names
    start with an ASCII letter or digit, followed by word characters,
    ``-``, ``.`` or ``/`` (e.g. ``organization/model-name`` or
    ``model-name``). Names containing ``..`` or starting with ``/`` are
    rejected to block path traversal.

    Args:
        model_name: Model identifier to check

    Returns:
        The stripped model identifier

    Raises:
        ValidationError: If the value is not a string, is empty after
            stripping, has an invalid format, or looks like a path
            traversal attempt
    """
    if not isinstance(model_name, str):
        raise ValidationError("Model name must be a string", {"type": type(model_name).__name__})

    candidate = model_name.strip()
    if candidate == "":
        raise ValidationError("Model name cannot be empty")

    format_ok = re.match(r"^[a-zA-Z0-9][\w\-./]*$", candidate) is not None
    if not format_ok:
        raise ValidationError("Invalid model name format", {"model_name": candidate})

    # Guard against names that could escape a model directory.
    looks_like_traversal = ".." in candidate or candidate.startswith("/")
    if looks_like_traversal:
        raise ValidationError(
            "Model name contains invalid characters", {"model_name": candidate}
        )

    return candidate


def _is_positive_int(value) -> bool:
    """Return True if ``value`` is an int (but not a bool) that is at least 1."""
    # bool is a subclass of int; True/False are not meaningful counts or lengths.
    return isinstance(value, int) and not isinstance(value, bool) and value >= 1


def validate_generation_params(
    max_length: int,
    num_sequences: int,
    temperature: float = 1.0,
    *,
    max_model_length: Optional[int] = None,
    max_sequences: int = 5,
) -> dict:
    """
    Validate text generation parameters.

    Every parameter is checked before anything is raised, so the error
    details describe all invalid fields at once.

    Args:
        max_length: Maximum generation length; a positive integer no larger
            than ``max_model_length``
        num_sequences: Number of sequences to generate; a positive integer no
            larger than ``max_sequences``
        temperature: Sampling temperature; a finite number greater than 0
        max_model_length: Upper bound for ``max_length``
            (default: ``settings.max_model_length``)
        max_sequences: Upper bound for ``num_sequences`` (default: 5)

    Returns:
        Dict with the validated ``max_length``, ``num_sequences`` and
        ``temperature`` values

    Raises:
        ValidationError: If any parameter is invalid; the details dict maps
            each offending parameter name to a message
    """
    errors = {}

    length_limit = (
        settings.max_model_length if max_model_length is None else max_model_length
    )

    # Upper bounds are only compared after the type check passes, so
    # non-numeric input (e.g. None or "10") is reported as a validation
    # error instead of escaping as a TypeError from the comparison.
    if not _is_positive_int(max_length):
        errors["max_length"] = "Must be a positive integer"
    elif max_length > length_limit:
        errors["max_length"] = f"Exceeds maximum of {length_limit}"

    if not _is_positive_int(num_sequences):
        errors["num_sequences"] = "Must be a positive integer"
    elif num_sequences > max_sequences:
        errors["num_sequences"] = f"Cannot exceed {max_sequences} sequences"

    # The chained comparison is False for NaN, +/-inf and every value <= 0,
    # so only finite positive temperatures are accepted.
    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or not 0 < temperature < float("inf")
    ):
        errors["temperature"] = "Must be a positive number"

    if errors:
        raise ValidationError("Invalid generation parameters", errors)

    return {
        "max_length": max_length,
        "num_sequences": num_sequences,
        "temperature": temperature,
    }