"""
Input validation and sanitization for tools and API endpoints.
Validates and sanitizes all inputs to ensure data quality and security.
"""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, validator, ValidationError
import logging
# Module-level logger; validate_tool_args reports validation failures through it.
logger = logging.getLogger(__name__)
class SearchArgs(BaseModel):
    """Validated search arguments"""

    # Required search text, 1-500 characters.
    query: str = Field(
        ...,
        min_length=1,
        max_length=500,
        description="Search query",
    )
    # Which corpus to search; restricted by validate_search_type below.
    search_type: str = Field(
        default="all",
        description="Search type: all, products, documentation",
    )
    # Result count, bounded to 1..50 inclusive.
    top_k: int = Field(
        default=5,
        ge=1,
        le=50,
        description="Number of results",
    )

    @validator('search_type')
    def validate_search_type(cls, value):
        """Return the search type unchanged when it is one of the allowed values."""
        if value in ("all", "products", "documentation"):
            return value
        raise ValueError(f"Invalid search_type: {value}")
class QueryArgs(BaseModel):
    """Validated query arguments"""

    # Required question text, 1-1000 characters.
    question: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Question",
    )
    # Optional source count; when given it must lie in 1..50 inclusive.
    top_k: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Number of sources",
    )
class ProductAnalysisArgs(BaseModel):
    """Validated product analysis arguments"""

    # Required product name, 1-200 characters.
    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
        description="Product name",
    )
    # Category label, at most 100 characters; defaults to "general".
    category: Optional[str] = Field(
        default="general",
        max_length=100,
        description="Product category",
    )
    # Free-text description, at most 2000 characters; defaults to empty.
    description: Optional[str] = Field(
        default="",
        max_length=2000,
        description="Product description",
    )
    # Non-negative price; omitted by default.
    current_price: Optional[float] = Field(
        default=None,
        ge=0,
        description="Current price",
    )
class ReviewAnalysisArgs(BaseModel):
    """Validated review analysis arguments"""

    # Required list holding between 1 and 100 reviews.
    reviews: list = Field(
        ...,
        min_items=1,
        max_items=100,
        description="List of reviews",
    )
    # Display name of the product, at most 200 characters.
    product_name: Optional[str] = Field(
        default="Product",
        max_length=200,
        description="Product name",
    )

    @validator('reviews')
    def validate_reviews(cls, value):
        """Require every review to be a string of at most 5000 characters."""
        for item in value:
            if not isinstance(item, str):
                raise ValueError(f"Review must be string, got {type(item)}")
            if len(item) > 5000:
                raise ValueError("Review exceeds 5000 characters")
        # Hand back a fresh list containing the checked items in their original order.
        return list(value)
class ListingGenerationArgs(BaseModel):
    """Validated listing generation arguments"""

    # Required product name, 1-200 characters.
    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
        description="Product name",
    )
    # Required list holding between 1 and 20 features.
    features: list = Field(
        ...,
        min_items=1,
        max_items=20,
        description="Product features",
    )
    # Intended audience, at most 200 characters.
    target_audience: Optional[str] = Field(
        default="general consumers",
        max_length=200,
    )
    # Tone of the listing; restricted by validate_style below.
    style: Optional[str] = Field(
        default="professional",
        description="Tone style",
    )

    @validator('features')
    def validate_features(cls, value):
        """Require every feature to be a string of at most 200 characters."""
        for item in value:
            if not isinstance(item, str):
                raise ValueError(f"Feature must be string, got {type(item)}")
            if len(item) > 200:
                raise ValueError("Feature exceeds 200 characters")
        # Hand back a fresh list containing the checked items in their original order.
        return list(value)

    @validator('style')
    def validate_style(cls, value):
        """Return the style unchanged when it is one of the supported tones."""
        if value in ("luxury", "budget", "professional", "casual"):
            return value
        raise ValueError(f"Invalid style: {value}")
class PricingArgs(BaseModel):
    """Validated pricing recommendation arguments"""

    # Required product name, 1-200 characters.
    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
    )
    # Required cost; must be at least 0.01.
    cost: float = Field(
        ...,
        ge=0.01,
        description="Product cost",
    )
    # Category label, at most 100 characters; defaults to "general".
    category: Optional[str] = Field(
        default="general",
        max_length=100,
    )
    # Margin percentage bounded to 0..500 inclusive; defaults to 50.
    target_margin: Optional[float] = Field(
        default=50,
        ge=0,
        le=500,
        description="Target profit margin %",
    )
class CompetitorAnalysisArgs(BaseModel):
    """Validated competitor analysis arguments"""

    # Required product name, 1-200 characters.
    product_name: str = Field(
        ...,
        min_length=1,
        max_length=200,
    )
    # Category label, at most 100 characters; defaults to "general".
    category: Optional[str] = Field(
        default="general",
        max_length=100,
    )
    # Optional list of at most 10 competitor names.
    key_competitors: Optional[list] = Field(
        default=None,
        max_items=10,
        description="Competitor names",
    )

    @validator('key_competitors')
    def validate_competitors(cls, value):
        """Require every competitor to be a string of at most 200 characters."""
        if value is not None:
            for name in value:
                if not isinstance(name, str):
                    raise ValueError(f"Competitor must be string, got {type(name)}")
                if len(name) > 200:
                    raise ValueError("Competitor name exceeds 200 characters")
            # Hand back a fresh list containing the checked names in their original order.
            return list(value)
        # A missing competitor list is passed through untouched.
        return value
def validate_tool_args(tool_name: str, arguments: Dict[str, Any]) -> tuple[bool, Any, Optional[str]]:
    """
    Check a tool's arguments against the argument model for that tool.

    Args:
        tool_name: Name of the tool
        arguments: Tool arguments, expanded as keyword arguments into the model

    Returns:
        Tuple of (is_valid, validated_args, error_message). On success
        validated_args is the model's dict() output and error_message is None;
        on failure validated_args is None and error_message describes the problem.
    """
    # Pick the argument model first; an unrecognised tool never reaches validation.
    if tool_name == "knowledge_search":
        model_cls = SearchArgs
    elif tool_name == "product_query":
        model_cls = QueryArgs
    elif tool_name == "analyze_product":
        model_cls = ProductAnalysisArgs
    elif tool_name == "analyze_reviews":
        model_cls = ReviewAnalysisArgs
    elif tool_name == "generate_listing":
        model_cls = ListingGenerationArgs
    elif tool_name == "price_recommendation":
        model_cls = PricingArgs
    elif tool_name == "competitor_analysis":
        model_cls = CompetitorAnalysisArgs
    else:
        return False, None, f"Unknown tool: {tool_name}"

    try:
        validated = model_cls(**arguments).dict()
    except ValidationError as exc:
        error_msg = f"Validation error: {exc.errors()}"
        logger.warning(f"{tool_name} validation failed: {error_msg}")
        return False, None, error_msg
    except Exception as exc:
        # Any other failure is reported to the caller rather than raised.
        error_msg = f"Unexpected validation error: {str(exc)}"
        logger.error(f"{tool_name} validation error: {error_msg}")
        return False, None, error_msg

    return True, validated, None
def sanitize_string(s: str, max_length: int = 5000) -> str:
    """
    Sanitize string input by trimming surrounding whitespace and capping length.

    Whitespace is stripped before truncation so that leading whitespace does
    not use up the length budget, and stripped again afterwards because the
    cut can land just after whitespace. No characters other than surrounding
    whitespace are removed.

    Args:
        s: Input string. Any non-string value yields an empty string.
        max_length: Maximum allowed length of the result. Negative values are
            treated as 0 instead of slicing characters off the end.

    Returns:
        Sanitized string
    """
    if not isinstance(s, str):
        return ""

    # Clamp so a negative limit cannot be interpreted as an end-relative slice.
    limit = max(max_length, 0)

    # Trim first, cut to the limit, then drop whitespace exposed by the cut.
    return s.strip()[:limit].rstrip()
|