Spaces:
Sleeping
Sleeping
| """ | |
| Input validation and sanitization for tools and API endpoints. | |
| Validates and sanitizes all inputs to ensure data quality and security. | |
| """ | |
| from typing import Any, Dict, Optional | |
| from pydantic import BaseModel, Field, validator, ValidationError | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class SearchArgs(BaseModel):
    """Validated search arguments.

    Fields:
        query: Search text, 1 to 500 characters. Required.
        search_type: One of "all", "products" or "documentation"; defaults to "all".
        top_k: Number of results, between 1 and 50; defaults to 5.
    """

    query: str = Field(..., min_length=1, max_length=500, description="Search query")
    search_type: str = Field(default="all", description="Search type: all, products, documentation")
    top_k: int = Field(default=5, ge=1, le=50, description="Number of results")

    # Must be registered with @validator: pydantic ignores plain methods, so
    # without the decorator any search_type string would be accepted.
    @validator("search_type")
    def validate_search_type(cls, v):
        """Reject a search_type outside the supported set.

        Raises:
            ValueError: If v is not "all", "products" or "documentation".
        """
        if v not in ("all", "products", "documentation"):
            raise ValueError(f"Invalid search_type: {v}")
        return v
class QueryArgs(BaseModel):
    """Validated query arguments.

    Fields:
        question: The question text, 1 to 1000 characters. Required.
        top_k: Optional number of sources, between 1 and 50 when given; defaults to None.
    """

    question: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Question",
    )
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of sources",
    )
class ProductAnalysisArgs(BaseModel):
    """Validated product analysis arguments.

    Fields:
        product_name: Product name, 1 to 200 characters. Required.
        category: Product category, up to 100 characters; defaults to "general".
        description: Product description, up to 2000 characters; defaults to "".
        current_price: Current price, must be >= 0 when given; defaults to None.
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
        description="Product name",
    )
    category: Optional[str] = Field(
        default="general",
        max_length=100,
        description="Product category",
    )
    description: Optional[str] = Field(
        default="",
        max_length=2000,
        description="Product description",
    )
    current_price: Optional[float] = Field(
        default=None,
        ge=0,
        description="Current price",
    )
class ReviewAnalysisArgs(BaseModel):
    """Validated review analysis arguments.

    Fields:
        reviews: List of 1 to 100 reviews; each must be a string of at most
            5000 characters. Required.
        product_name: Product name, up to 200 characters; defaults to "Product".
    """

    reviews: list = Field(..., min_items=1, max_items=100, description="List of reviews")
    product_name: Optional[str] = Field(default="Product", max_length=200, description="Product name")

    # Must be registered with @validator: pydantic ignores plain methods, so
    # without the decorator the per-review checks below would never run.
    @validator("reviews")
    def validate_reviews(cls, v):
        """Ensure every review is a string of at most 5000 characters.

        Returns:
            A new list holding the reviews unchanged.

        Raises:
            ValueError: If a review is not a string or exceeds 5000 characters.
        """
        validated = []
        for review in v:
            if not isinstance(review, str):
                raise ValueError(f"Review must be string, got {type(review)}")
            if len(review) > 5000:
                raise ValueError("Review exceeds 5000 characters")
            validated.append(review)
        return validated
class ListingGenerationArgs(BaseModel):
    """Validated listing generation arguments.

    Fields:
        product_name: Product name, 1 to 200 characters. Required.
        features: List of 1 to 20 features; each must be a string of at most
            200 characters. Required.
        target_audience: Up to 200 characters; defaults to "general consumers".
        style: One of "luxury", "budget", "professional" or "casual";
            defaults to "professional".
    """

    product_name: str = Field(..., min_length=1, max_length=200, description="Product name")
    features: list = Field(..., min_items=1, max_items=20, description="Product features")
    target_audience: Optional[str] = Field(default="general consumers", max_length=200)
    style: Optional[str] = Field(default="professional", description="Tone style")

    # Both checks below must be registered with @validator: pydantic ignores
    # plain methods, so without the decorators they would never run.
    @validator("features")
    def validate_features(cls, v):
        """Ensure every feature is a string of at most 200 characters.

        Returns:
            A new list holding the features unchanged.

        Raises:
            ValueError: If a feature is not a string or exceeds 200 characters.
        """
        validated = []
        for feature in v:
            if not isinstance(feature, str):
                raise ValueError(f"Feature must be string, got {type(feature)}")
            if len(feature) > 200:
                raise ValueError("Feature exceeds 200 characters")
            validated.append(feature)
        return validated

    @validator("style")
    def validate_style(cls, v):
        """Reject a style outside the supported set.

        Raises:
            ValueError: If v is not "luxury", "budget", "professional" or "casual".
        """
        if v not in ("luxury", "budget", "professional", "casual"):
            raise ValueError(f"Invalid style: {v}")
        return v
class PricingArgs(BaseModel):
    """Validated pricing recommendation arguments.

    Fields:
        product_name: Product name, 1 to 200 characters. Required.
        cost: Product cost, must be >= 0.01. Required.
        category: Up to 100 characters; defaults to "general".
        target_margin: Target profit margin in percent, between 0 and 500;
            defaults to 50.
    """

    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
    )
    cost: float = Field(
        ...,
        ge=0.01,
        description="Product cost",
    )
    category: Optional[str] = Field(
        default="general",
        max_length=100,
    )
    target_margin: Optional[float] = Field(
        default=50,
        ge=0,
        le=500,
        description="Target profit margin %",
    )
class CompetitorAnalysisArgs(BaseModel):
    """Validated competitor analysis arguments.

    Fields:
        product_name: Product name, 1 to 200 characters. Required.
        category: Up to 100 characters; defaults to "general".
        key_competitors: Optional list of at most 10 competitor names; each
            must be a string of at most 200 characters. Defaults to None.
    """

    product_name: str = Field(..., min_length=1, max_length=200)
    category: Optional[str] = Field(default="general", max_length=100)
    key_competitors: Optional[list] = Field(default=None, max_items=10, description="Competitor names")

    # Must be registered with @validator: pydantic ignores plain methods, so
    # without the decorator the per-name checks below would never run.
    @validator("key_competitors")
    def validate_competitors(cls, v):
        """Ensure every competitor name is a string of at most 200 characters.

        Returns:
            None when v is None, otherwise a new list holding the names unchanged.

        Raises:
            ValueError: If a name is not a string or exceeds 200 characters.
        """
        # The field is optional, so an absent list is passed through untouched.
        if v is None:
            return v
        validated = []
        for competitor in v:
            if not isinstance(competitor, str):
                raise ValueError(f"Competitor must be string, got {type(competitor)}")
            if len(competitor) > 200:
                raise ValueError("Competitor name exceeds 200 characters")
            validated.append(competitor)
        return validated
def validate_tool_args(tool_name: str, arguments: Dict[str, Any]) -> tuple[bool, Any, Optional[str]]:
    """
    Check a tool call's arguments against the argument model for that tool.

    Args:
        tool_name: Name of the tool
        arguments: Tool arguments, expanded as keyword arguments into the model

    Returns:
        Tuple of (is_valid, validated_args, error_message). On success,
        validated_args is the model's ``.dict()`` output and error_message is
        None. On failure (unknown tool, validation error, or any other
        exception), validated_args is None and error_message describes why.
    """
    try:
        # Tool name -> argument model, scanned in order with ==.
        model_table = (
            ("knowledge_search", SearchArgs),
            ("product_query", QueryArgs),
            ("analyze_product", ProductAnalysisArgs),
            ("analyze_reviews", ReviewAnalysisArgs),
            ("generate_listing", ListingGenerationArgs),
            ("price_recommendation", PricingArgs),
            ("competitor_analysis", CompetitorAnalysisArgs),
        )
        for known_name, model_cls in model_table:
            if tool_name == known_name:
                validated = model_cls(**arguments)
                break
        else:
            # No entry matched the requested tool.
            return False, None, f"Unknown tool: {tool_name}"
        return True, validated.dict(), None
    except ValidationError as exc:
        # Arguments were rejected by the model; report pydantic's error list.
        error_msg = f"Validation error: {exc.errors()}"
        logger.warning(f"{tool_name} validation failed: {error_msg}")
        return False, None, error_msg
    except Exception as exc:
        # Anything else (e.g. arguments that cannot be expanded with **).
        error_msg = f"Unexpected validation error: {exc!s}"
        logger.error(f"{tool_name} validation error: {error_msg}")
        return False, None, error_msg
def sanitize_string(s: str, max_length: int = 5000) -> str:
    """
    Sanitize string input

    Leading and trailing whitespace is stripped and the result is cut to at
    most ``max_length`` characters. No other characters are removed or
    escaped, so callers that need escaping must do it themselves.

    Args:
        s: Input string; any non-string value yields ""
        max_length: Maximum allowed length of the returned string; a negative
            value is treated as 0

    Returns:
        Sanitized string
    """
    if not isinstance(s, str):
        return ""
    # Strip before truncating so surrounding whitespace does not use up the
    # length budget.
    s = s.strip()
    # A negative limit would make the slice chop characters off the end
    # instead of limiting the length, so clamp it at zero.
    limit = max(max_length, 0)
    if len(s) > limit:
        # The cut can land right after interior whitespace; strip it again so
        # the result never ends in whitespace.
        s = s[:limit].rstrip()
    return s