File size: 11,337 Bytes
2a623ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b746fa
 
2a623ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""
Input validation utilities for HF EDA MCP Server.

This module provides centralized validation functions for all tool inputs,
ensuring consistent error messages and validation logic across the application.
"""

import re
from typing import Optional, List
from hf_eda_mcp.config import get_config


class ValidationError(ValueError):
    """
    Validation failure that carries optional remediation hints.

    Behaves like a regular ValueError; the extra ``suggestions`` attribute
    holds human-readable tips that can be displayed next to the message.
    """

    def __init__(self, message: str, suggestions: Optional[List[str]] = None):
        """
        Build the error.

        Args:
            message: Description of what failed validation
            suggestions: Optional hints on how to fix the input; any falsy
                value (None, empty list) is stored as a new empty list
        """
        super().__init__(message)
        # Always expose a list so consumers can iterate without a None check.
        if suggestions:
            self.suggestions = suggestions
        else:
            self.suggestions = []


def validate_dataset_id(dataset_id: str) -> str:
    """
    Validate and normalize a HuggingFace dataset identifier.

    The identifier is stripped of surrounding whitespace and then checked, in
    order, for: emptiness, misplaced slashes, overall format, and length.
    Slash mistakes are checked before the general format pattern so that the
    more specific error message is the one reported.

    Args:
        dataset_id: Dataset identifier to validate

    Returns:
        Normalized dataset_id (stripped of whitespace)

    Raises:
        ValidationError: If dataset_id is invalid with helpful error message
    """
    # Only None and the empty string count as "missing"; other falsy values
    # (0, [], False) fall through to the type check for an accurate message.
    if dataset_id is None or (isinstance(dataset_id, str) and not dataset_id):
        raise ValidationError(
            "dataset_id is required and cannot be empty",
            suggestions=[
                "Provide a valid HuggingFace dataset identifier",
                "Examples: 'imdb', 'squad', 'glue', 'username/dataset-name'",
            ],
        )

    if not isinstance(dataset_id, str):
        raise ValidationError(
            f"dataset_id must be a string, got {type(dataset_id).__name__}",
            suggestions=["Ensure dataset_id is passed as a string value"],
        )

    dataset_id = dataset_id.strip()

    if not dataset_id:
        raise ValidationError(
            "dataset_id cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty dataset identifier"],
        )

    # Check for common slash mistakes first: the format pattern below would
    # also reject them, but only with a generic message.
    if dataset_id.startswith("/") or dataset_id.endswith("/"):
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - cannot start or end with '/'",
            suggestions=["Remove leading or trailing slashes from the dataset_id"],
        )

    if "//" in dataset_id:
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - contains consecutive slashes",
            suggestions=["Use single slashes to separate username from dataset name"],
        )

    # Validate format: alphanumeric, hyphens, underscores, slashes, dots, @
    # Pattern: optional username/ followed by dataset name.
    # re.ASCII keeps \w to [a-zA-Z0-9_]; without it \w also matches non-ASCII
    # letters and digits, which contradicts the ASCII-only first character.
    pattern = r"^[a-zA-Z0-9][\w\-\.@]*(/[\w\-\.]+)?$"
    if not re.match(pattern, dataset_id, flags=re.ASCII):
        raise ValidationError(
            f"Invalid dataset_id format: '{dataset_id}'",
            suggestions=[
                "Dataset IDs should contain only letters, numbers, hyphens, underscores, dots, and slashes",
                "Valid formats: 'dataset-name' or 'username/dataset-name'",
                "Examples: 'imdb', 'squad', 'huggingface/dataset-name'",
            ],
        )

    # Reject very long dataset IDs (likely an error)
    if len(dataset_id) > 100:
        raise ValidationError(
            f"dataset_id is unusually long ({len(dataset_id)} characters)",
            suggestions=[
                "Check if the dataset_id is correct",
                "Dataset IDs are typically shorter than 100 characters",
            ],
        )

    return dataset_id


def validate_config_name(config_name: Optional[str]) -> Optional[str]:
    """
    Check a dataset configuration name and return its normalized form.

    None, the empty string, and whitespace-only strings all mean "use the
    default configuration" and yield None.

    Args:
        config_name: Configuration name to validate (can be None)

    Returns:
        The name stripped of surrounding whitespace, or None when no
        configuration was given

    Raises:
        ValidationError: If config_name is not a string, has an invalid
            format, or is longer than 50 characters
    """
    if config_name is None:
        return None

    if isinstance(config_name, str):
        config_name = config_name.strip()
    else:
        raise ValidationError(
            f"config_name must be a string or None, got {type(config_name).__name__}",
            suggestions=["Pass config_name as a string or omit it for default configuration"],
        )

    # Nothing left after stripping: fall back to the default configuration.
    if not config_name:
        return None

    # Allowed: a leading letter/digit followed by word characters, hyphens, dots
    format_ok = re.match(r"^[a-zA-Z0-9][\w\-\.]*$", config_name) is not None
    if not format_ok:
        format_hints = [
            "Configuration names should contain only letters, numbers, hyphens, underscores, and dots",
            "Examples: 'cola', 'sst2', 'plain_text'",
        ]
        raise ValidationError(
            f"Invalid config_name format: '{config_name}'",
            suggestions=format_hints,
        )

    if len(config_name) > 50:
        length_hints = [
            "Check if the config_name is correct",
            "Configuration names are typically shorter than 50 characters",
        ]
        raise ValidationError(
            f"config_name is unusually long ({len(config_name)} characters)",
            suggestions=length_hints,
        )

    return config_name


def validate_split_name(split: str) -> str:
    """
    Check a dataset split name and return its normalized form.

    Args:
        split: Split name to validate

    Returns:
        The split name stripped of surrounding whitespace and lowercased

    Raises:
        ValidationError: If split is missing, not a string, badly formatted,
            or longer than 20 characters
    """
    if not split:
        raise ValidationError(
            "split is required and cannot be empty",
            suggestions=[
                "Provide a valid split name",
                "Common splits: 'train', 'validation', 'test'",
            ],
        )

    if isinstance(split, str):
        split = split.strip().lower()
    else:
        raise ValidationError(
            f"split must be a string, got {type(split).__name__}",
            suggestions=["Ensure split is passed as a string value"],
        )

    if split == "":
        raise ValidationError(
            "split cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty split name"],
        )

    # Allowed: a leading letter/digit followed by word characters or hyphens
    if re.match(r"^[a-zA-Z0-9][\w\-]*$", split) is None:
        raise ValidationError(
            f"Invalid split name format: '{split}'",
            suggestions=[
                "Split names should contain only letters, numbers, hyphens, and underscores",
                "Common splits: 'train', 'validation', 'test', 'dev'",
            ],
        )

    # Custom split names are allowed; the well-known ones below are only used
    # to exempt them from the length limit and to build the hint text.
    common_splits = frozenset(("train", "validation", "test", "dev", "val"))
    too_long = len(split) > 20

    if too_long and split not in common_splits:
        raise ValidationError(
            f"Unusual split name: '{split}' (length: {len(split)})",
            suggestions=[
                "Check if the split name is correct",
                f"Common splits are: {', '.join(sorted(common_splits))}",
                "Some datasets may have custom split names",
            ],
        )

    return split


def validate_sample_size(num_samples: int, parameter_name: str = "num_samples") -> int:
    """
    Validate sample size parameter.

    Integer-valued floats (e.g. 10.0) are accepted and converted to int.
    Booleans are rejected even though bool is a subclass of int, because a
    True/False value is never a meaningful sample count.

    Args:
        num_samples: Number of samples to validate
        parameter_name: Name of the parameter (for error messages)

    Returns:
        Validated num_samples as an int

    Raises:
        ValidationError: If num_samples is not an integer, is not positive,
            or exceeds the ``max_sample_size`` value returned by get_config()
    """
    # bool passes isinstance(x, int), so it has to be excluded explicitly.
    if isinstance(num_samples, bool) or not isinstance(num_samples, int):
        # Check if it's a float that's actually an integer
        if isinstance(num_samples, float) and num_samples.is_integer():
            num_samples = int(num_samples)
        else:
            raise ValidationError(
                f"{parameter_name} must be an integer, got {type(num_samples).__name__}",
                suggestions=[
                    f"Provide {parameter_name} as an integer value",
                    f"Example: {parameter_name}=100",
                ],
            )

    if num_samples <= 0:
        raise ValidationError(
            f"{parameter_name} must be positive, got {num_samples}",
            suggestions=[
                f"Provide a positive integer for {parameter_name}",
                f"Example: {parameter_name}=10 or {parameter_name}=1000",
            ],
        )

    # Get max sample size from config
    config = get_config()
    max_sample_size = config.max_sample_size

    if num_samples > max_sample_size:
        raise ValidationError(
            f"{parameter_name} ({num_samples}) exceeds maximum allowed ({max_sample_size})",
            suggestions=[
                f"Reduce {parameter_name} to {max_sample_size} or less",
                f"Current maximum is configured as {max_sample_size}",
                "For larger samples, consider using streaming or batch processing",
            ],
        )

    return num_samples


def validate_indices(indices: List[int]) -> List[int]:
    """
    Validate a list of indices for sampling.

    Every element must be a non-negative integer. Booleans are rejected even
    though bool is a subclass of int, since True/False are not meaningful
    row positions.

    Args:
        indices: List of indices to validate

    Returns:
        The same indices list, unchanged

    Raises:
        ValidationError: If indices is missing or empty, is not a list,
            contains a non-integer or negative element, or has more elements
            than the ``max_sample_size`` value returned by get_config()
    """
    # Only None and an empty list count as "missing". Testing the type before
    # truthiness keeps other inputs (tuples, strings, array-like objects)
    # on the "must be a list" path instead of relying on their truth value.
    if indices is None or (isinstance(indices, list) and not indices):
        raise ValidationError(
            "indices list is required and cannot be empty",
            suggestions=[
                "Provide a non-empty list of indices",
                "Example: indices=[0, 1, 2, 10, 20]",
            ],
        )

    if not isinstance(indices, list):
        raise ValidationError(
            f"indices must be a list, got {type(indices).__name__}",
            suggestions=[
                "Provide indices as a list of integers",
                "Example: indices=[0, 1, 2]",
            ],
        )

    # Validate each index
    for i, idx in enumerate(indices):
        # bool passes isinstance(x, int), so it has to be excluded explicitly.
        if isinstance(idx, bool) or not isinstance(idx, int):
            raise ValidationError(
                f"All indices must be integers, got {type(idx).__name__} at position {i}",
                suggestions=[
                    "Ensure all indices are integer values",
                    "Example: indices=[0, 1, 2] (not [0.5, 1.2])",
                ],
            )

        if idx < 0:
            raise ValidationError(
                f"All indices must be non-negative, got {idx} at position {i}",
                suggestions=[
                    "Provide only non-negative indices (0 or greater)",
                    "Example: indices=[0, 1, 2, 10]",
                ],
            )

    # Check for reasonable list size
    config = get_config()
    max_sample_size = config.max_sample_size

    if len(indices) > max_sample_size:
        raise ValidationError(
            f"Too many indices requested ({len(indices)}), maximum is {max_sample_size}",
            suggestions=[
                f"Reduce the number of indices to {max_sample_size} or less",
                "Consider using regular sampling instead of specific indices",
            ],
        )

    return indices


def format_validation_error(error: ValidationError) -> str:
    """
    Render a validation error as user-facing text.

    The error message comes first; when the error carries suggestions they
    follow under a "Suggestions:" heading, one bullet per line.

    Args:
        error: ValidationError to format

    Returns:
        The message alone when there are no suggestions, otherwise the
        message followed by the bulleted suggestion list
    """
    headline = str(error)
    hints = error.suggestions

    if not hints:
        return headline

    # Each bullet brings its own leading newline, so a plain join is enough.
    bullets = "".join(f"\n  - {suggestion}" for suggestion in hints)
    return headline + "\n\nSuggestions:" + bullets