# ecomcp/src/core/validators.py
# Last commit by vinhnx90 (9eebeb3): "feat: Establish core infrastructure modules,
# add comprehensive documentation, and refactor UI components."
"""
Input validation and sanitization for tools and API endpoints.
Validates and sanitizes all inputs to ensure data quality and security.
"""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, validator, ValidationError
import logging
# Module-level logger; validate_tool_args reports validation failures through it.
logger = logging.getLogger(__name__)
# Argument model for the "knowledge_search" tool (see validate_tool_args).
class SearchArgs(BaseModel):
    """Validated search arguments"""

    query: str = Field(..., description="Search query", min_length=1, max_length=500)
    # Restricted to the three scopes accepted by validate_search_type below.
    search_type: str = Field(description="Search type: all, products, documentation", default="all")
    top_k: int = Field(description="Number of results", default=5, ge=1, le=50)

    @validator("search_type")
    def validate_search_type(cls, v):
        # Accept a known scope as-is; anything else is a validation error.
        if v in ("all", "products", "documentation"):
            return v
        raise ValueError(f"Invalid search_type: {v}")
# Argument model for the "product_query" tool (see validate_tool_args).
class QueryArgs(BaseModel):
    """Validated query arguments"""

    question: str = Field(default=..., description="Question", max_length=1000, min_length=1)
    # Optional; when supplied it must lie in [1, 50].
    top_k: Optional[int] = Field(description="Number of sources", default=None, le=50, ge=1)
# Argument model for the "analyze_product" tool (see validate_tool_args).
class ProductAnalysisArgs(BaseModel):
    """Validated product analysis arguments"""

    product_name: str = Field(default=..., description="Product name", max_length=200, min_length=1)
    category: Optional[str] = Field(description="Product category", default="general", max_length=100)
    description: Optional[str] = Field(description="Product description", default="", max_length=2000)
    # Price is optional but, when present, may not be negative.
    current_price: Optional[float] = Field(description="Current price", default=None, ge=0)
# Argument model for the "analyze_reviews" tool (see validate_tool_args).
class ReviewAnalysisArgs(BaseModel):
    """Validated review analysis arguments"""

    reviews: list = Field(default=..., description="List of reviews", max_items=100, min_items=1)
    product_name: Optional[str] = Field(description="Product name", default="Product", max_length=200)

    @validator("reviews")
    def validate_reviews(cls, v):
        # Single pass: each entry must be a str of at most 5000 characters;
        # the first offending entry (type checked before length) aborts.
        for entry in v:
            if isinstance(entry, str):
                if len(entry) > 5000:
                    raise ValueError("Review exceeds 5000 characters")
                continue
            raise ValueError(f"Review must be string, got {type(entry)}")
        # Hand back a fresh list holding the same (unchanged) entries.
        return list(v)
# Argument model for the "generate_listing" tool (see validate_tool_args).
class ListingGenerationArgs(BaseModel):
    """Validated listing generation arguments"""

    product_name: str = Field(default=..., description="Product name", max_length=200, min_length=1)
    features: list = Field(default=..., description="Product features", max_items=20, min_items=1)
    target_audience: Optional[str] = Field(max_length=200, default="general consumers")
    # Must be one of the tones accepted by validate_style below.
    style: Optional[str] = Field(description="Tone style", default="professional")

    @validator("features")
    def validate_features(cls, v):
        # Single pass: each feature must be a str of at most 200 characters;
        # the first offending entry (type checked before length) aborts.
        for item in v:
            if isinstance(item, str):
                if len(item) > 200:
                    raise ValueError("Feature exceeds 200 characters")
                continue
            raise ValueError(f"Feature must be string, got {type(item)}")
        return list(v)

    @validator("style")
    def validate_style(cls, v):
        # Accept a known tone as-is; anything else is a validation error.
        if v in ("luxury", "budget", "professional", "casual"):
            return v
        raise ValueError(f"Invalid style: {v}")
# Argument model for the "price_recommendation" tool (see validate_tool_args).
class PricingArgs(BaseModel):
    """Validated pricing recommendation arguments"""

    product_name: str = Field(default=..., max_length=200, min_length=1)
    # Cost must be strictly positive (at least one cent).
    cost: float = Field(default=..., description="Product cost", ge=0.01)
    category: Optional[str] = Field(max_length=100, default="general")
    # Percentage in [0, 500]; the default is the integer 50.
    target_margin: Optional[float] = Field(description="Target profit margin %", default=50, le=500, ge=0)
# Argument model for the "competitor_analysis" tool (see validate_tool_args).
class CompetitorAnalysisArgs(BaseModel):
    """Validated competitor analysis arguments"""

    product_name: str = Field(default=..., max_length=200, min_length=1)
    category: Optional[str] = Field(max_length=100, default="general")
    key_competitors: Optional[list] = Field(description="Competitor names", default=None, max_items=10)

    @validator("key_competitors")
    def validate_competitors(cls, v):
        # No competitor list supplied: nothing to check.
        if v is None:
            return None
        # Single pass: each name must be a str of at most 200 characters;
        # the first offending entry (type checked before length) aborts.
        for name in v:
            if isinstance(name, str):
                if len(name) > 200:
                    raise ValueError("Competitor name exceeds 200 characters")
                continue
            raise ValueError(f"Competitor must be string, got {type(name)}")
        return list(v)
def validate_tool_args(tool_name: str, arguments: Optional[Dict[str, Any]]) -> tuple[bool, Any, Optional[str]]:
    """
    Validate tool arguments against the argument model registered for the tool.

    Args:
        tool_name: Name of the tool, e.g. "knowledge_search" or "analyze_product".
        arguments: Raw tool arguments. ``None`` is treated as an empty mapping,
            so missing required fields are reported as validation errors
            instead of an opaque ``**`` unpacking TypeError.
    Returns:
        Tuple of (is_valid, validated_args, error_message). On success
        ``validated_args`` is the model's ``.dict()`` and ``error_message`` is
        None; on failure ``validated_args`` is None and ``error_message``
        describes the problem. Unknown tool names are reported, not raised.
    """
    # Resolve the argument model first so unknown tools are rejected up front.
    if tool_name == "knowledge_search":
        model = SearchArgs
    elif tool_name == "product_query":
        model = QueryArgs
    elif tool_name == "analyze_product":
        model = ProductAnalysisArgs
    elif tool_name == "analyze_reviews":
        model = ReviewAnalysisArgs
    elif tool_name == "generate_listing":
        model = ListingGenerationArgs
    elif tool_name == "price_recommendation":
        model = PricingArgs
    elif tool_name == "competitor_analysis":
        model = CompetitorAnalysisArgs
    else:
        return False, None, f"Unknown tool: {tool_name}"

    if arguments is None:
        arguments = {}

    try:
        return True, model(**arguments).dict(), None
    except ValidationError as e:
        error_msg = f"Validation error: {e.errors()}"
        logger.warning("%s validation failed: %s", tool_name, error_msg)
        return False, None, error_msg
    except Exception as e:
        # Deliberate best-effort boundary: report rather than raise, but keep
        # the traceback in the log since this path indicates a real bug.
        error_msg = f"Unexpected validation error: {str(e)}"
        logger.exception("%s validation error: %s", tool_name, error_msg)
        return False, None, error_msg
def sanitize_string(s: str, max_length: int = 5000) -> str:
    """
    Sanitize string input.

    Strips surrounding whitespace, removes control characters (C0 controls
    other than tab, newline and carriage return, plus DEL and C1 controls),
    and caps the result at ``max_length`` characters.
    Args:
        s: Input string. Non-string values yield "".
        max_length: Maximum length of the returned string; values <= 0
            yield "".
    Returns:
        Sanitized string, never longer than ``max_length``.
    """
    import re

    if not isinstance(s, str) or max_length <= 0:
        return ""
    # Drop leading whitespace *before* truncating so padding cannot consume
    # the length budget; truncating before the regex bounds work on huge input.
    s = s.lstrip()[:max_length]
    # Remove control characters; \t (\x09), \n (\x0a) and \r (\x0d) are kept.
    s = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]", "", s)
    # Truncation/removal may expose trailing (or leading) whitespace.
    return s.strip()