File size: 3,807 Bytes
50dca14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Input validation and sanitization utilities.
All user-supplied strings pass through here before entering business logic.
"""
from __future__ import annotations

import ipaddress
import re
import socket
from typing import Any, Optional
from urllib.parse import urlparse

import bleach


# Allow-lists passed to bleach.clean() in sanitize_string(). Both are empty,
# so no HTML tags or attributes are permitted (none for our use case).
ALLOWED_TAGS: list[str] = []
ALLOWED_ATTRS: dict = {}

# Deny-patterns: a selector is rejected when the pattern matches anywhere in it.
# CSS: any of < > " ' or the substring "javascript:" (case-insensitive).
# NOTE(review): this also rejects the child combinator (`ul > li`) and quoted
# attribute values (`a[href="x"]`) -- confirm that restriction is intended.
_CSS_UNSAFE_PATTERN = re.compile(r"[<>\"']|javascript:", re.IGNORECASE)
# XPath: < or > or the substring "javascript:" (case-insensitive); quotes pass.
# NOTE(review): this also rejects XPath comparisons such as `[price > 10]` --
# confirm that restriction is intended.
_XPATH_UNSAFE_PATTERN = re.compile(r"[<>]|javascript:", re.IGNORECASE)


def sanitize_string(value: Any, max_length: int = 500) -> str:
    """Strip HTML from *value*, trim surrounding whitespace, and cap its length.

    Args:
        value: Anything convertible with ``str()``. Any falsy value
            (``None``, ``""``, ``0``, ``False``, empty containers) yields ``""``.
        max_length: Maximum number of characters kept after cleaning.

    Returns:
        The cleaned text, at most ``max_length`` characters long, with no
        leading or trailing whitespace.
    """
    if not value:
        return ""
    cleaned = bleach.clean(
        str(value), tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS, strip=True
    )
    # Trim *before* truncating so leading whitespace does not consume the
    # length budget; trim again because the cut can expose trailing whitespace.
    return cleaned.strip()[:max_length].strip()


def _is_internal_host(host: str) -> bool:
    """Return True if *host* is localhost or a non-public IP address literal."""
    host = host.rstrip(".").lower()
    if host == "localhost" or host.endswith(".localhost"):
        return True

    # Drop an IPv6 zone identifier ("fe80::1%eth0") before parsing.
    literal = host.partition("%")[0]
    try:
        addr = ipaddress.ip_address(literal)
    except ValueError:
        # Legacy IPv4 spellings ("127.1", "0x7f000001", "2130706433") are not
        # accepted by ipaddress but are accepted by many HTTP clients.
        try:
            addr = ipaddress.IPv4Address(socket.inet_aton(literal))
        except (OSError, ValueError):
            return False  # An ordinary DNS name, not an address literal.

    # Judge IPv4-mapped IPv6 ("::ffff:127.0.0.1") by the embedded IPv4 address.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped

    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_reserved
        or addr.is_unspecified
    )


def validate_url(url: str) -> tuple[bool, str]:
    """Validate that *url* is an http(s) URL pointing at a non-internal host.

    Args:
        url: The user-supplied URL. Surrounding whitespace is ignored.

    Returns:
        ``(True, "")`` when the URL is acceptable, otherwise
        ``(False, error_message)``.

    Note:
        SSRF protection here covers ``localhost`` names and IP address
        literals (loopback, private, link-local, multicast, reserved,
        unspecified). A DNS name that *resolves* to an internal address is
        not detected, because no name resolution is performed.
    """
    if not url:
        return False, "URL is required."
    if not isinstance(url, str):
        # Non-string payload values would otherwise crash on .strip().
        return False, "Invalid URL format."

    try:
        parsed = urlparse(url.strip())
        host = parsed.hostname or ""
    except ValueError:
        # urlparse raises ValueError for malformed input such as "http://[bad".
        return False, "Invalid URL format."

    if parsed.scheme not in ("http", "https"):
        return False, "URL must use http or https."
    if not parsed.netloc or not host:
        return False, "URL must include a valid domain."
    if _is_internal_host(host):
        return False, "Requests to internal addresses are not allowed."
    return True, ""


def validate_css_selector(selector: Optional[str]) -> tuple[bool, str]:
    """Check an optional CSS selector against the unsafe-character pattern.

    Args:
        selector: The user-supplied selector. Any falsy value (``None``,
            ``""``) means "no selector" and is accepted.

    Returns:
        ``(True, "")`` when the selector is absent or acceptable, otherwise
        ``(False, error_message)``.
    """
    if not selector:
        return True, ""
    if not isinstance(selector, str):
        # A non-string payload value would make the regex search raise TypeError.
        return False, "CSS selector must be a string."
    if _CSS_UNSAFE_PATTERN.search(selector):
        return False, "CSS selector contains unsafe characters."
    return True, ""


def validate_xpath(xpath: Optional[str]) -> tuple[bool, str]:
    """Check an optional XPath expression against the unsafe-character pattern.

    Args:
        xpath: The user-supplied expression. Any falsy value (``None``,
            ``""``) means "no expression" and is accepted.

    Returns:
        ``(True, "")`` when the expression is absent or acceptable, otherwise
        ``(False, error_message)``.
    """
    if not xpath:
        return True, ""
    if not isinstance(xpath, str):
        # A non-string payload value would make the regex search raise TypeError.
        return False, "XPath must be a string."
    if _XPATH_UNSAFE_PATTERN.search(xpath):
        return False, "XPath contains unsafe characters."
    return True, ""


def validate_job_data(data: dict) -> tuple[bool, dict]:
    """Validate all fields for a scrape job.

    Args:
        data: The job payload. Keys read: ``url``, ``css_selector``,
            ``xpath_selector``, ``extraction_type`` (default ``"text"``),
            ``scrape_type`` (default ``"static"``), ``max_pages`` (default 1)
            and ``delay_seconds`` (default 1.0).

    Returns:
        ``(is_valid, error_dict)`` where ``error_dict`` maps each invalid
        field name to its error message and is empty when the job is valid.
    """
    errors: dict[str, str] = {}

    # Fields handled by dedicated validators that return (is_valid, message).
    delegated = (
        ("url", validate_url(data.get("url", ""))),
        ("css_selector", validate_css_selector(data.get("css_selector"))),
        ("xpath_selector", validate_xpath(data.get("xpath_selector"))),
    )
    for field, (ok, message) in delegated:
        if not ok:
            errors[field] = message

    valid_extractions = ("text", "images", "links", "attributes", "table", "json_ld", "full_html")
    if data.get("extraction_type", "text") not in valid_extractions:
        errors["extraction_type"] = f"Must be one of: {', '.join(valid_extractions)}"

    if data.get("scrape_type", "static") not in ("static", "dynamic"):
        errors["scrape_type"] = "Must be 'static' or 'dynamic'."

    raw_pages = data.get("max_pages", 1)
    try:
        # int() would silently truncate 2.5 to 2; reject non-integral floats
        # (this also covers inf and nan, for which is_integer() is False).
        if isinstance(raw_pages, float) and not raw_pages.is_integer():
            raise ValueError("max_pages is not integral")
        max_pages = int(raw_pages)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: infinite values from types such as Decimal.
        errors["max_pages"] = "max_pages must be an integer."
    else:
        if not 1 <= max_pages <= 200:
            errors["max_pages"] = "max_pages must be between 1 and 200."

    try:
        delay = float(data.get("delay_seconds", 1.0))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large to convert to float (e.g. 10**400).
        errors["delay_seconds"] = "delay_seconds must be a number."
    else:
        # nan fails both comparisons and is therefore rejected here as well.
        if not 0 <= delay <= 60:
            errors["delay_seconds"] = "delay_seconds must be between 0 and 60."

    return len(errors) == 0, errors