Medium-MCP / src /validation.py
Nikhil Pravin Pise
feat: implement comprehensive improvement plan (Phases 1-5)
e98cc10
"""
Input Validation Module
Comprehensive validation for URLs, inputs, and parameters.
Protects against malicious input and ensures data integrity.
"""
from __future__ import annotations
import re
import logging
from typing import Optional, Any
from urllib.parse import urlparse, parse_qs, unquote
from dataclasses import dataclass
logger = logging.getLogger(__name__)
# =============================================================================
# CONSTANTS
# =============================================================================

# Size limits applied to untrusted input before any further processing.
MAX_URL_LENGTH = 2048   # characters in a full URL
MAX_QUERY_LENGTH = 500  # characters in a search query
MAX_TAG_LENGTH = 100    # characters in a tag
MAX_BATCH_SIZE = 20     # URLs accepted in a single batch call

# The only URL schemes that pass validate_url's scheme check.
VALID_SCHEMES = frozenset(["http", "https"])

# Host names recognised as Medium or Medium-hosted publications; subdomains of
# these entries also match in validate_medium_url.
MEDIUM_DOMAINS = frozenset(
    [
        "medium.com",
        "towardsdatascience.com",
        "betterprogramming.pub",
        "levelup.gitconnected.com",
        "javascript.plainenglish.io",
        "python.plainenglish.io",
        "blog.devgenius.io",
        "uxdesign.cc",
        "itnext.io",
        "hackernoon.com",
        "freecodecamp.org",
    ]
)

# Case-insensitive markers of script/data/local-file URLs. validate_url rejects
# a URL in which any of these appears anywhere, not only as the scheme.
DANGEROUS_PATTERNS = tuple(
    re.compile(marker, re.IGNORECASE)
    for marker in ("javascript:", "data:", "vbscript:", "file://")
)
# =============================================================================
# VALIDATION RESULTS
# =============================================================================
@dataclass()
class ValidationResult:
    """
    Outcome of a single validation check.

    Validators in this module return ``is_valid=True`` together with the
    (possibly cleaned) input in ``value`` on success, and ``is_valid=False``
    with a human-readable ``error`` on failure.
    """

    # Whether the input passed validation.
    is_valid: bool
    # Value the caller should use when valid; None when not set.
    value: Optional[str] = None
    # Reason for rejection when invalid; None when not set.
    error: Optional[str] = None
    # Flag a validator sets to report sanitization; exactly when it is set
    # differs per validator.
    sanitized: bool = False
# =============================================================================
# URL VALIDATION
# =============================================================================
def validate_url(url: str) -> ValidationResult:
    """
    Validate and sanitize a URL.

    Checks, in order:
    - Present (after stripping surrounding whitespace) and a string
    - Within MAX_URL_LENGTH
    - No ASCII control characters
    - No dangerous patterns (see DANGEROUS_PATTERNS), checked against both
      the raw and the percent-decoded URL so encoding cannot hide them
    - Parses as a URL with an http/https scheme and a non-empty host name

    Args:
        url: URL to validate.

    Returns:
        ValidationResult whose ``value`` is the whitespace-stripped URL on
        success, or whose ``error`` describes the first failed check.
    """
    if not url:
        return ValidationResult(
            is_valid=False,
            error="URL is required"
        )
    if not isinstance(url, str):
        return ValidationResult(
            is_valid=False,
            error="URL must be a string"
        )

    # Strip before re-checking emptiness so whitespace-only input is reported
    # as missing rather than as an unsupported scheme.
    url = url.strip()
    if not url:
        return ValidationResult(
            is_valid=False,
            error="URL is required"
        )

    if len(url) > MAX_URL_LENGTH:
        return ValidationResult(
            is_valid=False,
            error=f"URL exceeds maximum length of {MAX_URL_LENGTH}"
        )

    # Embedded control characters (CR/LF, NUL, TAB, DEL, ...) are never valid
    # inside a URL and enable header/log injection.
    if re.search(r"[\x00-\x1f\x7f]", url):
        return ValidationResult(
            is_valid=False,
            error="URL contains invalid characters"
        )

    # Also inspect the percent-decoded form so that e.g. "javascript%3A"
    # cannot slip past a pattern that only looks at the raw text.
    decoded = unquote(url)
    for pattern in DANGEROUS_PATTERNS:
        if pattern.search(url) or pattern.search(decoded):
            # %r escapes unusual characters so the log line cannot be forged.
            logger.warning("Dangerous URL pattern detected: %r", url[:50])
            return ValidationResult(
                is_valid=False,
                error="URL contains dangerous content"
            )

    try:
        parsed = urlparse(url)
    except ValueError as e:
        # urlparse raises ValueError for malformed input such as an
        # unterminated IPv6 literal ("http://[::1").
        return ValidationResult(
            is_valid=False,
            error=f"Invalid URL format: {e}"
        )

    if parsed.scheme.lower() not in VALID_SCHEMES:
        return ValidationResult(
            is_valid=False,
            error="URL must use http or https"
        )

    # hostname (unlike netloc) excludes userinfo and port, so inputs such as
    # "http://:80" or "http://user@" are rejected as having no domain.
    if not parsed.hostname:
        return ValidationResult(
            is_valid=False,
            error="URL must include a domain"
        )

    return ValidationResult(
        is_valid=True,
        value=url,
        sanitized=True
    )
def validate_medium_url(url: str) -> ValidationResult:
    """
    Validate that a URL points at a known Medium domain.

    Runs validate_url first, then requires the URL's host name to equal one of
    MEDIUM_DOMAINS or be a subdomain of one (so "www." and publication
    subdomains are accepted). Only the host is checked; the path is not
    inspected, so this does not guarantee the URL is an article.

    Args:
        url: URL to validate.

    Returns:
        ValidationResult whose ``value`` is the cleaned URL from validate_url
        on success; otherwise the first error encountered.
    """
    # First do basic URL validation
    result = validate_url(url)
    if not result.is_valid:
        return result

    # Work with the cleaned URL validate_url returned, not the raw argument,
    # so surrounding whitespace cannot break parsing or leak into the result.
    clean_url = result.value or url

    # hostname is lower-cased and excludes port/userinfo, so "medium.com:443"
    # matches. Matching on exact or ".<domain>" suffix (rather than deleting
    # "www." substrings) prevents spoofs such as "medium.cwww.om".
    # A trailing dot (fully-qualified form) is ignored.
    host = (urlparse(clean_url).hostname or "").rstrip(".")
    is_medium = any(
        host == md or host.endswith(f".{md}")
        for md in MEDIUM_DOMAINS
    )

    if not is_medium:
        return ValidationResult(
            is_valid=False,
            error="URL is not a recognized Medium domain"
        )

    return ValidationResult(
        is_valid=True,
        value=clean_url,
        sanitized=True
    )
def validate_batch_urls(urls: list[str]) -> tuple[list[str], list[dict[str, str]]]:
    """
    Run validate_url over a batch of URLs.

    Args:
        urls: URLs to check; at most MAX_BATCH_SIZE entries.

    Returns:
        A ``(valid_urls, errors)`` pair. ``valid_urls`` holds the cleaned form
        of each accepted URL in input order; ``errors`` holds one
        ``{"url": ..., "error": ...}`` dict per rejected URL. An empty or
        oversized batch yields no valid URLs and a single error whose "url"
        is "".
    """
    if not urls:
        return [], [{"url": "", "error": "No URLs provided"}]
    if len(urls) > MAX_BATCH_SIZE:
        return [], [{"url": "", "error": f"Batch size exceeds maximum of {MAX_BATCH_SIZE}"}]

    accepted = []
    rejected = []
    for candidate in urls:
        outcome = validate_url(candidate)
        if not outcome.is_valid:
            rejected.append({"url": candidate, "error": outcome.error or "Invalid URL"})
            continue
        accepted.append(outcome.value or candidate)

    return accepted, rejected
# =============================================================================
# QUERY VALIDATION
# =============================================================================
def validate_search_query(query: str) -> ValidationResult:
    """
    Validate and sanitize a search query.

    The query is stripped and must not exceed MAX_QUERY_LENGTH. The characters
    < > " ' are then removed, and the remaining text (stripped again) must
    contain at least 2 characters.

    Args:
        query: Search query to validate.

    Returns:
        ValidationResult with the sanitized query as ``value``; ``sanitized``
        is True when sanitization changed the stripped input.
    """
    if not query:
        return ValidationResult(
            is_valid=False,
            error="Search query is required"
        )

    query = query.strip()

    if len(query) > MAX_QUERY_LENGTH:
        return ValidationResult(
            is_valid=False,
            error=f"Query exceeds maximum length of {MAX_QUERY_LENGTH}"
        )

    # Remove potential injection characters *before* the minimum-length check,
    # so a query made only of such characters (e.g. "<>") cannot pass as
    # valid with an empty or 1-character value.
    sanitized = re.sub(r'[<>"\']', '', query).strip()

    if len(sanitized) < 2:
        return ValidationResult(
            is_valid=False,
            error="Query must be at least 2 characters"
        )

    return ValidationResult(
        is_valid=True,
        value=sanitized,
        sanitized=sanitized != query
    )
def validate_tag(tag: str) -> ValidationResult:
    """
    Validate and normalize a Medium tag.

    The tag is stripped and lower-cased. Every run of characters other than
    ASCII letters and digits is then collapsed into a single hyphen, and
    leading/trailing hyphens are removed. For example, "Machine Learning"
    becomes "machine-learning" and "a--b" becomes "a-b".

    Args:
        tag: Tag to validate.

    Returns:
        ValidationResult with the normalized tag as ``value``. ``sanitized``
        is True when normalization changed the stripped, lower-cased input.
        The result is invalid when the tag is empty, longer than
        MAX_TAG_LENGTH, or contains no letters or digits.
    """
    if not tag:
        return ValidationResult(
            is_valid=False,
            error="Tag is required"
        )

    tag = tag.strip().lower()
    # Whitespace-only input is missing, not malformed.
    if not tag:
        return ValidationResult(
            is_valid=False,
            error="Tag is required"
        )

    if len(tag) > MAX_TAG_LENGTH:
        return ValidationResult(
            is_valid=False,
            error=f"Tag exceeds maximum length of {MAX_TAG_LENGTH}"
        )

    # Normalize every tag, not only ones with foreign characters, so that
    # hyphen-only or hyphen-padded input ("---", "-a-", "a--b") is cleaned up
    # (or rejected) instead of being accepted verbatim.
    normalized = re.sub(r'[^a-z0-9]+', '-', tag).strip('-')

    if not normalized:
        return ValidationResult(
            is_valid=False,
            error="Tag must contain only letters, numbers, and hyphens"
        )

    return ValidationResult(
        is_valid=True,
        value=normalized,
        sanitized=normalized != tag
    )
# =============================================================================
# NUMERIC VALIDATION
# =============================================================================
def validate_positive_int(
    value: Any,
    name: str = "value",
    min_val: int = 1,
    max_val: int = 100,
) -> ValidationResult:
    """
    Validate an integer within inclusive bounds.

    Accepts ints, integral floats (e.g. 5.0) and anything else ``int()`` can
    convert, such as numeric strings. Rejects bools, non-integral floats
    (which ``int()`` would silently truncate), NaN and infinity.

    Args:
        value: Value to validate.
        name: Parameter name used in error messages.
        min_val: Minimum allowed value (inclusive).
        max_val: Maximum allowed value (inclusive).

    Returns:
        ValidationResult whose ``value`` is the integer rendered as a string
        on success.
    """
    not_integer = f"{name} must be an integer"

    # bool is a subclass of int, but True/False are not meaningful quantities.
    if isinstance(value, bool):
        return ValidationResult(is_valid=False, error=not_integer)

    try:
        int_value = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError comes from int(float("inf")); ValueError from
        # int(float("nan")) or non-numeric strings.
        return ValidationResult(is_valid=False, error=not_integer)

    # int() truncates floats (int(3.7) == 3); only accept integral ones.
    if isinstance(value, float) and int_value != value:
        return ValidationResult(is_valid=False, error=not_integer)

    if int_value < min_val:
        return ValidationResult(
            is_valid=False,
            error=f"{name} must be at least {min_val}"
        )
    if int_value > max_val:
        return ValidationResult(
            is_valid=False,
            error=f"{name} must be at most {max_val}"
        )

    return ValidationResult(
        is_valid=True,
        value=str(int_value)
    )
# =============================================================================
# POST ID VALIDATION
# =============================================================================
def validate_post_id(post_id: str) -> ValidationResult:
    """
    Validate a Medium post ID.

    A post ID is accepted when, after stripping surrounding whitespace, it is
    8-16 hexadecimal characters (case-insensitive).

    Args:
        post_id: Post ID to validate.

    Returns:
        ValidationResult whose ``value`` is the lower-cased post ID on success.
    """
    if not post_id:
        return ValidationResult(
            is_valid=False,
            error="Post ID is required"
        )

    post_id = post_id.strip()
    # Whitespace-only input is missing, not the wrong length.
    if not post_id:
        return ValidationResult(
            is_valid=False,
            error="Post ID is required"
        )

    if not 8 <= len(post_id) <= 16:
        return ValidationResult(
            is_valid=False,
            error="Post ID must be 8-16 characters"
        )

    normalized = post_id.lower()
    # fullmatch anchors both ends explicitly (no "$"-before-newline quirk).
    if not re.fullmatch(r'[0-9a-f]+', normalized):
        return ValidationResult(
            is_valid=False,
            error="Post ID must be hexadecimal"
        )

    return ValidationResult(
        is_valid=True,
        value=normalized
    )