"""
URL data source.

This module provides data loading from HTTP/HTTPS URLs with security
protections against SSRF attacks.
"""

import ipaddress
import json
import logging
import os
import socket
import tempfile
from typing import Any, Dict, Iterator, List, Optional
from urllib.parse import urlparse

from potato.data_sources.base import DataSource, SourceConfig

logger = logging.getLogger(__name__)

# Default limits
DEFAULT_MAX_SIZE_MB = 100
DEFAULT_TIMEOUT_SECONDS = 30

# Private IP ranges to block (SSRF protection)
PRIVATE_IP_RANGES = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
    ipaddress.ip_network("127.0.0.0/8"),
    ipaddress.ip_network("169.254.0.0/16"),
    ipaddress.ip_network("::1/128"),
    ipaddress.ip_network("fc00::/7"),
    ipaddress.ip_network("fe80::/10"),
    # Unspecified addresses: connecting to 0.0.0.0 / :: is commonly routed
    # to the local host, so treat them like loopback.
    ipaddress.ip_network("0.0.0.0/8"),
    ipaddress.ip_network("::/128"),
]


def is_private_ip(ip_str: str) -> bool:
    """
    Check if an IP address is in a private range.

    The address is matched against ``PRIVATE_IP_RANGES``. IPv4-mapped IPv6
    addresses (``::ffff:a.b.c.d``) are unwrapped first so that the embedded
    IPv4 address is what gets checked, and an IPv6 zone identifier
    (``fe80::1%eth0``) is ignored.

    Args:
        ip_str: Textual IP address (IPv4 or IPv6)

    Returns:
        True if the address falls in a blocked range. False otherwise,
        including when ``ip_str`` cannot be parsed as an IP address.
    """
    candidate = ip_str
    if isinstance(candidate, str):
        # getaddrinfo may return scoped addresses such as "fe80::1%eth0";
        # ipaddress rejects the zone suffix on older Python versions.
        candidate = candidate.split("%", 1)[0]

    try:
        ip = ipaddress.ip_address(candidate)
    except ValueError:
        return False

    # "::ffff:127.0.0.1" reaches 127.0.0.1 on dual-stack hosts, so the
    # embedded IPv4 address is the one that must be validated.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped

    return any(ip in network for network in PRIVATE_IP_RANGES)


def resolve_and_validate_url(url: str, block_private_ips: bool = True) -> tuple:
    """
    Validate a URL and resolve it, checking for SSRF vulnerabilities.

    Returns the validated URL and a list of validated (non-private) IP
    addresses that can be used for IP-pinned connections, preventing DNS
    rebinding attacks.

    Args:
        url: The URL to validate
        block_private_ips: Whether to block private/internal IPs. When False,
            the hostname is not resolved and the returned IP list is empty.

    Returns:
        Tuple of (validated_url, list_of_validated_ips)

    Raises:
        ValueError: If the URL is invalid, cannot be resolved, or points to
            a blocked IP
    """
    parsed = urlparse(url)

    # Only allow http/https
    if parsed.scheme not in ('http', 'https'):
        raise ValueError(
            f"Invalid URL scheme '{parsed.scheme}'. Only http/https allowed."
        )

    if not parsed.netloc:
        raise ValueError("Invalid URL: missing host")

    # Extract hostname (without port)
    hostname = parsed.hostname
    if not hostname:
        raise ValueError("Invalid URL: missing hostname")

    validated_ips = []

    if block_private_ips:
        # Resolve hostname to every address it maps to; a single private
        # address among them is enough to reject the URL.
        try:
            addr_info = socket.getaddrinfo(hostname, None)
        except socket.gaierror as e:
            raise ValueError(
                f"Could not resolve hostname '{hostname}': {e}"
            ) from e

        for info in addr_info:
            ip = info[4][0]
            if is_private_ip(ip):
                raise ValueError(
                    f"URL host '{hostname}' resolves to private IP {ip}. "
                    f"Access to private networks is not allowed."
                )
            validated_ips.append(ip)

    return url, validated_ips


class URLSource(DataSource):
    """
    Data source for HTTP/HTTPS URLs.

    Supports fetching data from remote URLs with:
    - SSRF protection (blocks private IPs)
    - Custom headers for authentication
    - Size limits and timeouts
    - Content-type validation
    - Caching integration

    Configuration:
        type: url
        url: "https://example.com/data.jsonl"  # Required
        headers:  # Optional custom headers
            Authorization: "Bearer ${API_TOKEN}"
        max_size_mb: 100  # Optional size limit
        timeout_seconds: 30  # Optional request timeout
        block_private_ips: true  # Optional SSRF protection
        allowed_domains:  # Optional hostname allowlist (checked by validate_config)
            - example.com

    Supported content types:
        - application/json, application/x-ndjson
        - text/csv, text/tab-separated-values
        - application/x-jsonlines
    """

    def __init__(self, config: SourceConfig):
        """Initialize the URL source from ``config.config``."""
        super().__init__(config)

        settings = config.config
        self._url = settings.get("url", "")
        # A key that is present but empty in the config (e.g. "headers:")
        # arrives as None; fall back to the defaults instead of crashing later.
        self._headers = settings.get("headers") or {}
        max_size_mb = settings.get("max_size_mb", DEFAULT_MAX_SIZE_MB)
        if max_size_mb is None:
            max_size_mb = DEFAULT_MAX_SIZE_MB
        self._max_size_bytes = max_size_mb * 1024 * 1024
        self._timeout = settings.get("timeout_seconds", DEFAULT_TIMEOUT_SECONDS)
        self._block_private_ips = settings.get("block_private_ips", True)
        self._allowed_domains = settings.get("allowed_domains")

        # Cached data
        self._cached_data: Optional[List[Dict]] = None
        self._content_type: Optional[str] = None

    def get_source_id(self) -> str:
        """Get unique identifier."""
        return self._source_id

    def validate_config(self) -> List[str]:
        """
        Validate source configuration.

        Returns:
            List of human-readable error messages; empty when the
            configuration is valid.
        """
        errors = []

        if not self._url:
            errors.append("'url' is required for URL source")
            return errors

        try:
            parsed = urlparse(self._url)
            if parsed.scheme not in ('http', 'https'):
                errors.append(
                    f"Invalid URL scheme '{parsed.scheme}'. Only http/https allowed."
                )
            if not parsed.netloc:
                errors.append("Invalid URL: missing host")
            elif not parsed.hostname:
                # e.g. "http://:80/path" has a netloc but no usable host
                errors.append("Invalid URL: missing hostname")

            # Check domain allowlist if configured
            if self._allowed_domains:
                allowed = self._allowed_domains
                if isinstance(allowed, str):
                    # A bare string would otherwise be treated as a
                    # substring test by the `in` operator.
                    allowed = [allowed]
                # urlparse lowercases the hostname, so the allowlist must be
                # compared case-insensitively too.
                allowed = {str(domain).strip().lower() for domain in allowed}
                hostname = parsed.hostname
                if hostname and hostname not in allowed:
                    errors.append(
                        f"Domain '{hostname}' is not in allowed domains list"
                    )
        except Exception as e:
            errors.append(f"Invalid URL: {e}")

        return errors

    def is_available(self) -> bool:
        """
        Check if the URL is accessible.

        This only validates and (when private IPs are blocked) resolves the
        URL; no HTTP request is made.
        """
        try:
            resolve_and_validate_url(self._url, self._block_private_ips)
            return True
        except ValueError as e:
            logger.warning("URL not available: %s", e)
            return False
        except Exception as e:
            logger.warning("Error checking URL availability: %s", e)
            return False

    def _check_connected_peer(self, response) -> None:
        """
        Reject a response whose underlying socket is connected to a private IP.

        This guards against DNS rebinding between validation and connect.
        It is best-effort: if the socket cannot be inspected, the pre-connect
        validation remains the primary protection.

        Raises:
            ValueError: If the connected peer address is private
        """
        try:
            # Navigate to the underlying socket
            fp = getattr(response, 'fp', None)
            raw = getattr(fp, 'raw', None) if fp else None
            sock = getattr(raw, '_sock', None) if raw else None
            if sock is None:
                return
            peer = sock.getpeername()
            peer_ip = peer[0] if peer else None
        except (AttributeError, OSError):
            return

        if peer_ip is not None and is_private_ip(peer_ip):
            raise ValueError(
                f"Connection resolved to private IP {peer_ip}. "
                f"Possible DNS rebinding attack."
            )

    def _fetch_data(self) -> List[Dict[str, Any]]:
        """
        Fetch and parse data from the URL.

        Returns:
            Parsed items from the response body

        Raises:
            ValueError: If the URL fails validation, the response exceeds
                the size limit, or the content cannot be parsed
            RuntimeError: On HTTP or URL errors from the request
        """
        import urllib.request
        import urllib.error

        # Validate immediately before the fetch to minimize the TOCTOU window.
        resolve_and_validate_url(self._url, self._block_private_ips)

        # NOTE(review): the `allowed_domains` allowlist is only checked in
        # validate_config(), not here - confirm callers always validate first.

        # Build request with headers
        request = urllib.request.Request(self._url)
        for key, value in self._headers.items():
            request.add_header(key, value)

        # Add User-Agent if not specified. Header names are case-insensitive
        # and Request.add_header normalises them, so a case-sensitive check
        # would let the default overwrite a user-supplied "user-agent".
        if not any(str(key).lower() == 'user-agent' for key in self._headers):
            request.add_header('User-Agent', 'Potato-Annotation-Tool/1.0')

        # NOTE(review): urlopen follows HTTP redirects and the redirect
        # targets are not validated before they are requested; only the
        # final connection is checked below.
        try:
            with urllib.request.urlopen(request, timeout=self._timeout) as response:
                # Post-connection SSRF check: verify the connected IP is not private.
                if self._block_private_ips:
                    self._check_connected_peer(response)

                # Check the declared content length up front. A malformed
                # header is ignored: the streaming limit below still applies.
                content_length = response.headers.get('Content-Length')
                if content_length:
                    try:
                        size = int(content_length)
                    except (TypeError, ValueError):
                        size = None
                    if size is not None and size > self._max_size_bytes:
                        raise ValueError(
                            f"Response size {size} exceeds limit {self._max_size_bytes}"
                        )

                # Read with size limit
                chunks = []
                received = 0
                chunk_size = 8192
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    received += len(chunk)
                    if received > self._max_size_bytes:
                        raise ValueError(
                            f"Response exceeded size limit of "
                            f"{self._max_size_bytes / (1024*1024):.1f}MB"
                        )
                    chunks.append(chunk)
                data = b"".join(chunks)

                self._content_type = response.headers.get('Content-Type', '')

                # Parse based on content type
                return self._parse_content(data, self._content_type)

        except urllib.error.HTTPError as e:
            raise RuntimeError(f"HTTP error {e.code}: {e.reason}") from e
        except urllib.error.URLError as e:
            raise RuntimeError(f"URL error: {e.reason}") from e

    def _parse_content(
        self,
        data: bytes,
        content_type: str
    ) -> List[Dict[str, Any]]:
        """
        Parse content based on content type or URL extension.

        Args:
            data: Raw response body, expected to be UTF-8 encoded
            content_type: Value of the Content-Type header (may be empty)

        Returns:
            Parsed items

        Raises:
            ValueError: If the body is not valid UTF-8, or the format is
                unknown and the body cannot be parsed as JSON/JSONL
        """
        # utf-8-sig strips a leading BOM, which would otherwise make
        # json.loads fail and corrupt the first CSV column name.
        text = data.decode('utf-8-sig')

        # Media types are case-insensitive; match on a lowercased copy but
        # keep the original value for error messages.
        media_type = (content_type or '').lower()
        url_path = urlparse(self._url).path.lower()

        # Every supported JSON media type (json, ndjson, jsonlines)
        # contains the substring "json".
        if 'json' in media_type or url_path.endswith(('.json', '.jsonl')):
            return self._parse_json(text)
        if 'csv' in media_type or url_path.endswith('.csv'):
            return self._parse_csv(text, ',')
        if 'tab-separated' in media_type or url_path.endswith('.tsv'):
            return self._parse_csv(text, '\t')

        # Unknown format: try JSON first, fall back to JSONL, and fail
        # loudly if neither yields anything from a non-empty body.
        try:
            return self._parse_json(text, strict=True)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Could not parse content. "
                f"Content-Type: {content_type}, URL: {self._url}"
            ) from e

    def _parse_json(self, text: str, strict: bool = False) -> List[Dict[str, Any]]:
        """
        Parse JSON or JSONL content.

        The text is first parsed as a single JSON document; if that fails it
        is parsed line by line as JSONL, skipping (and logging) invalid lines.

        Args:
            text: The decoded document
            strict: When True, re-raise the JSON decode error if the text is
                non-blank and not a single line of it could be parsed.

        Returns:
            A list of items. A top-level JSON object is wrapped in a list.

        Raises:
            ValueError: If the document is valid JSON but neither a list
                nor an object
            json.JSONDecodeError: Only when ``strict`` is True, see above
        """
        # Try as a single JSON document first
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            items, parsed_lines = self._parse_jsonl(text)
            if strict and parsed_lines == 0 and text.strip():
                raise
            return items

        if isinstance(data, list):
            return data
        if isinstance(data, dict):
            return [data]
        raise ValueError(f"Unexpected JSON type: {type(data)}")

    def _parse_jsonl(self, text: str) -> tuple:
        """
        Parse JSONL content line by line.

        Returns:
            Tuple of (items, number_of_lines_parsed_successfully). Lines
            holding a JSON list contribute each of their elements.
        """
        items = []
        parsed_lines = 0
        # Split on "\n" only: str.splitlines() would also break on characters
        # such as U+2028 that may legally appear inside JSON strings.
        for line_no, line in enumerate(text.split('\n'), 1):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning("Invalid JSON at line %s: %s", line_no, e)
                continue
            parsed_lines += 1
            if isinstance(item, list):
                items.extend(item)
            else:
                items.append(item)

        return items, parsed_lines

    def _parse_csv(self, text: str, delimiter: str) -> List[Dict[str, Any]]:
        """
        Parse CSV/TSV content.

        The first row is used as the header; each following row becomes a
        dict keyed by the header names.
        """
        import csv
        from io import StringIO

        reader = csv.DictReader(StringIO(text), delimiter=delimiter)
        return [dict(row) for row in reader]

    def read_items(
        self,
        start: int = 0,
        count: Optional[int] = None
    ) -> Iterator[Dict[str, Any]]:
        """
        Read items from the URL.

        The data is fetched on first use and served from the in-memory
        cache afterwards.

        Args:
            start: Index of the first item to yield
            count: Maximum number of items to yield; None yields all
                remaining items
        """
        # Fetch data if not cached
        if self._cached_data is None:
            self._cached_data = self._fetch_data()

        # Apply start/count
        items = self._cached_data[start:]
        if count is not None:
            items = items[:count]

        yield from items

    def get_total_count(self) -> Optional[int]:
        """
        Get total number of items.

        Returns:
            The number of cached items, fetching them first if necessary,
            or None if the fetch fails.
        """
        if self._cached_data is None:
            try:
                self._cached_data = self._fetch_data()
            except Exception as e:
                logger.error("Error fetching data for count: %s", e)
                return None

        return len(self._cached_data)

    def supports_partial_reading(self) -> bool:
        """URL source supports partial reading after fetch."""
        return True

    def refresh(self) -> bool:
        """Refresh by clearing cached data."""
        self._cached_data = None
        return True

    def get_status(self) -> Dict[str, Any]:
        """Get source status, extended with URL, cache and content-type info."""
        status = super().get_status()
        status["url"] = self._url
        status["cached"] = self._cached_data is not None
        status["content_type"] = self._content_type
        return status