"""
Input validation and sanitization for tools and API endpoints.

Validates and sanitizes all inputs to ensure data quality and security.
"""

import logging
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, ValidationError, validator

logger = logging.getLogger(__name__)

# Allowed values for the enumerated string fields validated below.
_SEARCH_TYPES = ("all", "products", "documentation")
_LISTING_STYLES = ("luxury", "budget", "professional", "casual")


def _validate_string_items(items, noun, max_len, length_noun=None):
    """Check that every element of *items* is a ``str`` of at most *max_len* chars.

    Shared by the list-valued validators below so they all report errors in
    the same format.

    Args:
        items: The field value to check (an iterable of candidate strings).
        noun: Label used in the "must be string" error message.
        max_len: Maximum allowed length of each element, in characters.
        length_noun: Label used in the "exceeds" error message; defaults to
            *noun*.

    Returns:
        A new list containing the same elements, in order.

    Raises:
        ValueError: On the first element that is not a ``str`` or that is
            longer than *max_len* characters.
    """
    checked = list(items)
    for item in checked:
        if not isinstance(item, str):
            raise ValueError(f"{noun} must be string, got {type(item)}")
        if len(item) > max_len:
            raise ValueError(f"{length_noun or noun} exceeds {max_len} characters")
    return checked


class SearchArgs(BaseModel):
    """Validated search arguments (``knowledge_search`` tool)."""

    query: str = Field(..., min_length=1, max_length=500, description="Search query")
    search_type: str = Field(default="all", description="Search type: all, products, documentation")
    top_k: int = Field(default=5, ge=1, le=50, description="Number of results")

    @validator('search_type')
    def validate_search_type(cls, v):
        """Reject any search_type not listed in ``_SEARCH_TYPES``."""
        if v not in _SEARCH_TYPES:
            raise ValueError(f"Invalid search_type: {v}")
        return v


class QueryArgs(BaseModel):
    """Validated query arguments (``product_query`` tool)."""

    question: str = Field(..., min_length=1, max_length=1000, description="Question")
    top_k: Optional[int] = Field(default=None, ge=1, le=50, description="Number of sources")


class ProductAnalysisArgs(BaseModel):
    """Validated product analysis arguments (``analyze_product`` tool)."""

    product_name: str = Field(..., min_length=1, max_length=200, description="Product name")
    category: Optional[str] = Field(default="general", max_length=100, description="Product category")
    description: Optional[str] = Field(default="", max_length=2000, description="Product description")
    current_price: Optional[float] = Field(default=None, ge=0, description="Current price")


class ReviewAnalysisArgs(BaseModel):
    """Validated review analysis arguments (``analyze_reviews`` tool)."""

    reviews: list = Field(..., min_items=1, max_items=100, description="List of reviews")
    product_name: Optional[str] = Field(default="Product", max_length=200, description="Product name")

    @validator('reviews')
    def validate_reviews(cls, v):
        """Ensure every review is a string of at most 5000 characters."""
        return _validate_string_items(v, "Review", 5000)


class ListingGenerationArgs(BaseModel):
    """Validated listing generation arguments (``generate_listing`` tool)."""

    product_name: str = Field(..., min_length=1, max_length=200, description="Product name")
    features: list = Field(..., min_items=1, max_items=20, description="Product features")
    target_audience: Optional[str] = Field(default="general consumers", max_length=200)
    # NOTE(review): annotated Optional, but if pydantic hands an explicit None
    # to validate_style it is rejected, since None is not in _LISTING_STYLES.
    style: Optional[str] = Field(default="professional", description="Tone style")

    @validator('features')
    def validate_features(cls, v):
        """Ensure every feature is a string of at most 200 characters."""
        return _validate_string_items(v, "Feature", 200)

    @validator('style')
    def validate_style(cls, v):
        """Reject any style not listed in ``_LISTING_STYLES`` (case-sensitive)."""
        if v not in _LISTING_STYLES:
            raise ValueError(f"Invalid style: {v}")
        return v


class PricingArgs(BaseModel):
    """Validated pricing recommendation arguments (``price_recommendation`` tool)."""

    product_name: str = Field(..., min_length=1, max_length=200)
    cost: float = Field(..., ge=0.01, description="Product cost")
    category: Optional[str] = Field(default="general", max_length=100)
    target_margin: Optional[float] = Field(default=50, ge=0, le=500, description="Target profit margin %")


class CompetitorAnalysisArgs(BaseModel):
    """Validated competitor analysis arguments (``competitor_analysis`` tool)."""

    product_name: str = Field(..., min_length=1, max_length=200)
    category: Optional[str] = Field(default="general", max_length=100)
    key_competitors: Optional[list] = Field(default=None, max_items=10, description="Competitor names")

    @validator('key_competitors')
    def validate_competitors(cls, v):
        """Pass None through; otherwise ensure each name is a string of at most 200 chars."""
        if v is None:
            return v
        return _validate_string_items(v, "Competitor", 200, length_noun="Competitor name")
def validate_tool_args(tool_name: str, arguments: Dict[str, Any]) -> tuple[bool, Any, Optional[str]]:
    """
    Validate tool arguments against the argument model registered for a tool.

    Args:
        tool_name: Name of the tool (e.g. "knowledge_search").
        arguments: Raw keyword arguments for the tool. ``None`` is treated as
            an empty mapping, so missing required fields are reported as a
            normal validation error.

    Returns:
        Tuple of (is_valid, validated_args, error_message). On success,
        validated_args is the model's ``.dict()`` and error_message is None.
        On failure, validated_args is None and error_message explains why.
    """
    # Dispatch table: tool name -> argument model. Built per call so the
    # function has no import-time dependency on the model definitions.
    models = {
        "knowledge_search": SearchArgs,
        "product_query": QueryArgs,
        "analyze_product": ProductAnalysisArgs,
        "analyze_reviews": ReviewAnalysisArgs,
        "generate_listing": ListingGenerationArgs,
        "price_recommendation": PricingArgs,
        "competitor_analysis": CompetitorAnalysisArgs,
    }
    # Guard against unhashable names so they are reported as unknown tools
    # instead of raising TypeError from the dict lookup.
    model = models.get(tool_name) if isinstance(tool_name, str) else None
    if model is None:
        return False, None, f"Unknown tool: {tool_name}"

    if arguments is None:
        arguments = {}

    try:
        validated = model(**arguments).dict()
    except ValidationError as e:
        error_msg = f"Validation error: {e.errors()}"
        logger.warning("%s validation failed: %s", tool_name, error_msg)
        return False, None, error_msg
    except Exception as e:
        # Boundary handler: report any other failure (e.g. a non-mapping
        # `arguments`) as a result tuple, but keep the traceback in the log.
        error_msg = f"Unexpected validation error: {str(e)}"
        logger.exception("%s validation error: %s", tool_name, error_msg)
        return False, None, error_msg
    return True, validated, None


def sanitize_string(s: str, max_length: int = 5000) -> str:
    """
    Sanitize string input by trimming surrounding whitespace and capping length.

    Whitespace is stripped *before* truncating, so leading/trailing blanks do
    not consume the length budget. The result is right-trimmed again in case
    the cut lands on whitespace. Non-string input (including None) yields "".

    Args:
        s: Input string.
        max_length: Maximum allowed length of the result, in characters.

    Returns:
        Sanitized string of at most max_length characters.

    Raises:
        ValueError: If max_length is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    if not isinstance(s, str):
        return ""
    # NOTE: only whitespace is trimmed; nothing is escaped or removed, so
    # callers must still escape/parameterize before using the value in
    # HTML, SQL, or shell contexts.
    s = s.strip()
    if len(s) > max_length:
        s = s[:max_length].rstrip()
    return s