File size: 6,515 Bytes
108d8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9eebeb3
108d8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Input validation and sanitization for tools and API endpoints.

Validates and sanitizes all inputs to ensure data quality and security.
"""

from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, validator, ValidationError
import logging

logger = logging.getLogger(__name__)


class SearchArgs(BaseModel):
    """Validated search arguments.

    Fields:
        query: Search text, 1 to 500 characters long.
        search_type: Either "all", "products" or "documentation"
            (defaults to "all").
        top_k: Number of results, from 1 to 50 (defaults to 5).
    """

    query: str = Field(
        ...,
        min_length=1,
        max_length=500,
        description="Search query",
    )
    search_type: str = Field(
        default="all",
        description="Search type: all, products, documentation",
    )
    top_k: int = Field(
        default=5,
        ge=1,
        le=50,
        description="Number of results",
    )

    @validator('search_type')
    def validate_search_type(cls, value):
        """Accept only the supported search types; raise ValueError otherwise."""
        allowed = ("all", "products", "documentation")
        if value in allowed:
            return value
        raise ValueError(f"Invalid search_type: {value}")


class QueryArgs(BaseModel):
    """Validated query arguments.

    Fields:
        question: Question text, 1 to 1000 characters long.
        top_k: Optional number of sources, from 1 to 50 when given
            (defaults to None).
    """

    question: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Question",
    )
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of sources",
    )


class ProductAnalysisArgs(BaseModel):
    """Validated product analysis arguments.

    Fields:
        product_name: Product name, 1 to 200 characters long.
        category: Optional category, up to 100 characters
            (defaults to "general").
        description: Optional description, up to 2000 characters
            (defaults to an empty string).
        current_price: Optional non-negative price (defaults to None).
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
        description="Product name",
    )
    category: Optional[str] = Field(
        default="general",
        max_length=100,
        description="Product category",
    )
    description: Optional[str] = Field(
        default="",
        max_length=2000,
        description="Product description",
    )
    current_price: Optional[float] = Field(
        default=None,
        ge=0,
        description="Current price",
    )


class ReviewAnalysisArgs(BaseModel):
    """Validated review analysis arguments.

    Fields:
        reviews: Between 1 and 100 reviews; each must be a string of at
            most 5000 characters.
        product_name: Optional product name, up to 200 characters
            (defaults to "Product").
    """

    reviews: list = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of reviews",
    )
    product_name: Optional[str] = Field(
        default="Product",
        max_length=200,
        description="Product name",
    )

    @validator('reviews')
    def validate_reviews(cls, value):
        """Check every review is a string within the length limit.

        Raises ValueError on the first offending entry; otherwise returns
        a new list holding the same reviews.
        """
        for item in value:
            if not isinstance(item, str):
                raise ValueError(f"Review must be string, got {type(item)}")
            if len(item) > 5000:
                raise ValueError("Review exceeds 5000 characters")
        return list(value)


class ListingGenerationArgs(BaseModel):
    """Validated listing generation arguments.

    Fields:
        product_name: Product name, 1 to 200 characters long.
        features: Between 1 and 20 features; each must be a string of at
            most 200 characters.
        target_audience: Optional audience, up to 200 characters
            (defaults to "general consumers").
        style: Tone, one of "luxury", "budget", "professional" or
            "casual" (defaults to "professional").
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
        description="Product name",
    )
    features: list = Field(
        ...,
        min_items=1,
        max_items=20,
        description="Product features",
    )
    target_audience: Optional[str] = Field(
        default="general consumers",
        max_length=200,
    )
    style: Optional[str] = Field(
        default="professional",
        description="Tone style",
    )

    @validator('features')
    def validate_features(cls, value):
        """Check every feature is a string within the length limit.

        Raises ValueError on the first offending entry; otherwise returns
        a new list holding the same features.
        """
        for item in value:
            if not isinstance(item, str):
                raise ValueError(f"Feature must be string, got {type(item)}")
            if len(item) > 200:
                raise ValueError("Feature exceeds 200 characters")
        return list(value)

    @validator('style')
    def validate_style(cls, value):
        """Accept only the supported tone styles; raise ValueError otherwise."""
        allowed = ("luxury", "budget", "professional", "casual")
        if value in allowed:
            return value
        raise ValueError(f"Invalid style: {value}")


class PricingArgs(BaseModel):
    """Validated pricing recommendation arguments.

    Fields:
        product_name: Product name, 1 to 200 characters long.
        cost: Product cost, at least 0.01.
        category: Optional category, up to 100 characters
            (defaults to "general").
        target_margin: Optional target profit margin in percent, from 0
            to 500 (defaults to 50).
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
    )
    cost: float = Field(
        ...,
        ge=0.01,
        description="Product cost",
    )
    category: Optional[str] = Field(
        default="general",
        max_length=100,
    )
    target_margin: Optional[float] = Field(
        default=50,
        ge=0,
        le=500,
        description="Target profit margin %",
    )


class CompetitorAnalysisArgs(BaseModel):
    """Validated competitor analysis arguments.

    Fields:
        product_name: Product name, 1 to 200 characters long.
        category: Optional category, up to 100 characters
            (defaults to "general").
        key_competitors: Optional list of at most 10 competitor names;
            each must be a string of at most 200 characters
            (defaults to None).
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
    )
    category: Optional[str] = Field(
        default="general",
        max_length=100,
    )
    key_competitors: Optional[list] = Field(
        default=None,
        max_items=10,
        description="Competitor names",
    )

    @validator('key_competitors')
    def validate_competitors(cls, value):
        """Check every competitor name is a string within the length limit.

        None is passed through untouched. Raises ValueError on the first
        offending entry; otherwise returns a new list with the same names.
        """
        if value is None:
            return None
        for name in value:
            if not isinstance(name, str):
                raise ValueError(f"Competitor must be string, got {type(name)}")
            if len(name) > 200:
                raise ValueError("Competitor name exceeds 200 characters")
        return list(value)


def validate_tool_args(tool_name: str, arguments: Dict[str, Any]) -> tuple[bool, Any, Optional[str]]:
    """
    Check a tool's arguments against the schema registered for that tool.

    Args:
        tool_name: Name of the tool whose schema should be applied
        arguments: Raw keyword arguments supplied for the tool

    Returns:
        Tuple of (is_valid, validated_args, error_message). On success
        validated_args is the model converted to a dict and error_message
        is None; on any failure validated_args is None and error_message
        describes the problem.
    """
    # Step 1: pick the schema class; unknown tools are rejected up front.
    if tool_name == "knowledge_search":
        schema = SearchArgs
    elif tool_name == "product_query":
        schema = QueryArgs
    elif tool_name == "analyze_product":
        schema = ProductAnalysisArgs
    elif tool_name == "analyze_reviews":
        schema = ReviewAnalysisArgs
    elif tool_name == "generate_listing":
        schema = ListingGenerationArgs
    elif tool_name == "price_recommendation":
        schema = PricingArgs
    elif tool_name == "competitor_analysis":
        schema = CompetitorAnalysisArgs
    else:
        return False, None, f"Unknown tool: {tool_name}"

    # Step 2: build the model; every failure is reported via the tuple
    # rather than raised to the caller.
    try:
        model = schema(**arguments)
        return True, model.dict(), None
    except ValidationError as exc:
        error_msg = f"Validation error: {exc.errors()}"
        logger.warning(f"{tool_name} validation failed: {error_msg}")
    except Exception as exc:
        error_msg = f"Unexpected validation error: {str(exc)}"
        logger.error(f"{tool_name} validation error: {error_msg}")
    return False, None, error_msg


def sanitize_string(s: str, max_length: int = 5000) -> str:
    """
    Sanitize string input by trimming whitespace and capping its length.

    Only surrounding whitespace is removed; no other characters are
    filtered or escaped.

    Args:
        s: Input string; any non-string value yields an empty string
        max_length: Maximum allowed length of the result; negative
            values are treated as 0

    Returns:
        Sanitized string, never longer than max_length and with no
        leading or trailing whitespace
    """
    if not isinstance(s, str):
        return ""

    # Strip before truncating so leading whitespace does not consume the
    # length budget and push real content past the cut-off.
    s = s.strip()

    # A negative limit would otherwise slice from the end of the string.
    limit = max(max_length, 0)
    if len(s) > limit:
        # Truncation can expose interior whitespace at the new end.
        s = s[:limit].rstrip()

    return s