File size: 4,088 Bytes
b93a6a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Tool schema validator using Pydantic models.

This module validates tool call parameters against defined schemas
before execution to prevent errors and security issues.
"""

import logging
from typing import Dict, Any
from pydantic import BaseModel, Field, ValidationError

logger = logging.getLogger(__name__)


# Pydantic models for tool parameters

class AddTaskParams(BaseModel):
    """Parameters for add_task tool."""
    user_id: int = Field(..., gt=0, description="User ID")
    title: str = Field(..., min_length=1, max_length=255, description="Task title")
    description: str = Field(default="", max_length=1000, description="Task description")


class ListTasksParams(BaseModel):
    """Parameters for list_tasks tool."""
    user_id: int = Field(..., gt=0, description="User ID")
    filter: str = Field(default="all", pattern="^(all|pending|completed)$", description="Task filter")
    limit: int = Field(default=50, ge=1, le=100, description="Maximum tasks to return")


class CompleteTaskParams(BaseModel):
    """Parameters for complete_task tool."""
    user_id: int = Field(..., gt=0, description="User ID")
    task_id: int = Field(default=None, gt=0, description="Task ID")
    task_title: str = Field(default=None, min_length=1, description="Task title for matching")


class DeleteTaskParams(BaseModel):
    """Parameters for delete_task tool."""
    user_id: int = Field(..., gt=0, description="User ID")
    task_id: int = Field(default=None, gt=0, description="Task ID")
    task_title: str = Field(default=None, min_length=1, description="Task title for matching")


class UpdateTaskParams(BaseModel):
    """Parameters for update_task tool."""
    user_id: int = Field(..., gt=0, description="User ID")
    task_id: int = Field(default=None, gt=0, description="Task ID")
    task_title: str = Field(default=None, min_length=1, description="Current task title for matching")
    new_title: str = Field(default=None, min_length=1, max_length=255, description="New task title")
    new_description: str = Field(default=None, max_length=1000, description="New task description")


class ToolValidator:
    """Validates tool call parameters against schemas."""

    # Map tool names to their parameter models
    TOOL_SCHEMAS = {
        "add_task": AddTaskParams,
        "list_tasks": ListTasksParams,
        "complete_task": CompleteTaskParams,
        "delete_task": DeleteTaskParams,
        "update_task": UpdateTaskParams
    }

    @classmethod
    def validate(cls, tool_name: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate tool parameters against schema.

        Args:
            tool_name: Name of the tool
            parameters: Parameters to validate

        Returns:
            Validated parameters dictionary

        Raises:
            ValueError: If tool name is unknown
            ValidationError: If parameters are invalid
        """
        if tool_name not in cls.TOOL_SCHEMAS:
            raise ValueError(f"Unknown tool: {tool_name}")

        schema = cls.TOOL_SCHEMAS[tool_name]

        try:
            validated = schema(**parameters)
            logger.info(f"Tool parameters validated: {tool_name}")
            return validated.model_dump()
        except ValidationError as e:
            logger.error(f"Tool parameter validation failed: {tool_name} - {str(e)}")
            raise

    @classmethod
    def validate_tool_call(cls, tool_call: Dict[str, Any]) -> bool:
        """
        Validate that a tool call has the required structure.

        Args:
            tool_call: Tool call dictionary

        Returns:
            True if valid structure, False otherwise
        """
        if not isinstance(tool_call, dict):
            return False

        if "name" not in tool_call or "parameters" not in tool_call:
            return False

        if not isinstance(tool_call["name"], str):
            return False

        if not isinstance(tool_call["parameters"], dict):
            return False

        return True


# Singleton instance
tool_validator = ToolValidator()