"""
Amazon S3 data source.

This module provides data loading from Amazon S3 buckets,
supporting various authentication methods.
"""

import csv
import ipaddress
import json
import logging
import socket
from io import StringIO
from typing import Any, Dict, Iterator, List, Optional
from urllib.parse import urlparse

from potato.data_sources.base import DataSource, SourceConfig

logger = logging.getLogger(__name__)


class S3Source(DataSource):
    """
    Data source for Amazon S3 buckets.

    Supports loading data from S3 with multiple authentication options:
    - AWS credentials file (~/.aws/credentials)
    - Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
    - Explicit credentials in config
    - S3-compatible storage (MinIO, etc.)

    Configuration:
        type: s3
        bucket: "my-annotation-data"    # Required
        key: "datasets/items.jsonl"     # Required
        region: "us-east-1"             # Optional, default us-east-1

        # Optional: explicit credentials (prefer env vars)
        access_key_id: "${AWS_ACCESS_KEY_ID}"
        secret_access_key: "${AWS_SECRET_ACCESS_KEY}"

        # Optional: for S3-compatible storage
        endpoint_url: "https://minio.example.com"

    Supported formats: JSON, JSONL, CSV, TSV

    The whole object is downloaded and parsed on first access and then kept
    in memory until ``refresh()`` or ``close()`` is called.
    """

    # Tri-state cache for the boto3 import probe:
    # None = not checked yet, True/False = result of the first check.
    _HAS_BOTO3: Optional[bool] = None

    @classmethod
    def _check_dependencies(cls) -> bool:
        """Return True if boto3 can be imported (result cached on the class)."""
        if cls._HAS_BOTO3 is None:
            try:
                import boto3  # noqa: F401  (imported only to probe availability)
            except ImportError:
                cls._HAS_BOTO3 = False
            else:
                cls._HAS_BOTO3 = True
        return cls._HAS_BOTO3

    def __init__(self, config: SourceConfig):
        """
        Initialize the S3 source.

        Args:
            config: Source configuration; the S3 settings are read from
                ``config.config`` (a mapping), see the class docstring.
        """
        super().__init__(config)

        settings = config.config
        self._bucket = settings.get("bucket", "")
        self._key = settings.get("key", "")
        self._region = settings.get("region", "us-east-1")
        self._access_key_id = settings.get("access_key_id")
        self._secret_access_key = settings.get("secret_access_key")
        self._endpoint_url = settings.get("endpoint_url")

        # Parsed items of the object; None until the first successful fetch.
        self._cached_data: Optional[List[Dict]] = None
        # Lazily created boto3 client (see _get_client).
        self._client = None

    def get_source_id(self) -> str:
        """Get unique identifier."""
        # NOTE(review): _source_id is not assigned in this class; presumably
        # it is set by DataSource.__init__ -- confirm against the base class.
        return self._source_id

    def validate_config(self) -> List[str]:
        """
        Validate source configuration.

        Returns:
            A list of human-readable error messages; empty when the
            configuration is valid. This method does not raise.
        """
        errors: List[str] = []

        if not self._bucket:
            errors.append("'bucket' is required for S3 source")

        if not self._key:
            errors.append("'key' is required for S3 source")

        # Check that both access key and secret are provided together
        if self._access_key_id and not self._secret_access_key:
            errors.append(
                "'secret_access_key' is required when 'access_key_id' is provided"
            )
        if self._secret_access_key and not self._access_key_id:
            errors.append(
                "'access_key_id' is required when 'secret_access_key' is provided"
            )

        if self._endpoint_url:
            errors.extend(self._validate_endpoint_url())

        return errors

    def _validate_endpoint_url(self) -> List[str]:
        """
        Check the custom endpoint URL (SSRF protection).

        Only the http/https schemes are accepted, and the host must not
        resolve to a loopback or link-local address. Other private ranges
        are not blocked by this check. The check resolves DNS at validation
        time only, so it cannot protect against a host whose resolution
        changes before the actual connection is made.

        Returns:
            A list of error messages (empty when the URL is acceptable).
        """
        errors: List[str] = []
        try:
            parsed = urlparse(self._endpoint_url)
            scheme_ok = parsed.scheme in ('http', 'https')
            if not scheme_ok:
                errors.append(
                    f"Invalid endpoint_url scheme '{parsed.scheme}'. "
                    f"Only http/https allowed."
                )

            hostname = parsed.hostname
            if hostname:
                for ip_str in self._resolve_host_ips(hostname):
                    if self._is_blocked_ip(ip_str):
                        errors.append(
                            f"endpoint_url host '{hostname}' resolves "
                            f"to blocked IP {ip_str}. Loopback and "
                            f"link-local addresses are not allowed."
                        )
            elif scheme_ok:
                # e.g. "https://" -- nothing to connect to.
                errors.append("Invalid endpoint_url: no host specified")
        except Exception as e:
            # Validation must report problems, never raise: malformed URLs
            # (ValueError), non-string values, or unexpected resolver errors
            # all end up here.
            errors.append(f"Invalid endpoint_url: {e}")
        return errors

    @staticmethod
    def _resolve_host_ips(hostname: str) -> List[str]:
        """
        Resolve *hostname* to its distinct IP address strings.

        ``getaddrinfo`` returns one entry per socket type, so the same IP
        normally appears several times; duplicates are removed while the
        resolver's order is preserved. An unresolvable host yields an empty
        list -- the connection attempt will fail later instead.
        """
        try:
            addr_info = socket.getaddrinfo(hostname, None)
        except socket.gaierror:
            # Can't resolve at validation time -- will fail at connect
            return []
        return list(dict.fromkeys(info[4][0] for info in addr_info))

    @staticmethod
    def _is_blocked_ip(ip_str: str) -> bool:
        """Return True if *ip_str* is a loopback or link-local address."""
        try:
            # Drop an IPv6 zone suffix ("fe80::1%eth0"); older Python
            # versions reject it, which would let the address through.
            ip = ipaddress.ip_address(str(ip_str).split('%', 1)[0])
        except ValueError:
            return False

        # "::ffff:127.0.0.1" is IPv4 loopback wrapped in IPv6; check the
        # embedded IPv4 address, since not every Python version does so.
        mapped = getattr(ip, 'ipv4_mapped', None)
        if mapped is not None:
            ip = mapped

        return ip.is_loopback or ip.is_link_local

    def is_available(self) -> bool:
        """Check if the source is available (i.e. boto3 is installed)."""
        if not self._check_dependencies():
            logger.warning(
                "boto3 not installed. Install with: pip install boto3"
            )
            return False
        return True

    def _get_client(self):
        """
        Get or create the S3 client.

        The client is created on first use and reused afterwards. Explicit
        credentials are passed only when both halves are configured;
        otherwise boto3's own credential lookup applies.
        """
        if self._client is not None:
            return self._client

        import boto3

        # Build client configuration
        client_kwargs = {
            'region_name': self._region,
        }

        if self._endpoint_url:
            client_kwargs['endpoint_url'] = self._endpoint_url

        if self._access_key_id and self._secret_access_key:
            client_kwargs['aws_access_key_id'] = self._access_key_id
            client_kwargs['aws_secret_access_key'] = self._secret_access_key

        self._client = boto3.client('s3', **client_kwargs)
        return self._client

    def _fetch_data(self) -> List[Dict[str, Any]]:
        """
        Fetch and parse data from S3.

        Returns:
            The parsed items of the configured object.

        Raises:
            ValueError: If the object key or the bucket does not exist.
            RuntimeError: For any other failure inside the fetch, including
                decoding and parsing errors (the original exception is
                attached as ``__cause__``).
        """
        client = self._get_client()

        try:
            response = client.get_object(Bucket=self._bucket, Key=self._key)
            content = response['Body'].read()
            content_type = response.get('ContentType') or ''

            logger.debug(
                "Downloaded s3://%s/%s (%d bytes, %s)",
                self._bucket, self._key, len(content), content_type,
            )

            # Decode and parse. 'utf-8-sig' is plain UTF-8 that also drops a
            # leading byte-order mark, which would otherwise break
            # json.loads and corrupt the first CSV column name.
            text = content.decode('utf-8-sig')
            return self._parse_content(text, content_type)

        except client.exceptions.NoSuchKey as e:
            raise ValueError(
                f"Object not found: s3://{self._bucket}/{self._key}"
            ) from e
        except client.exceptions.NoSuchBucket as e:
            raise ValueError(f"Bucket not found: {self._bucket}") from e
        except Exception as e:
            raise RuntimeError(f"S3 error: {e}") from e

    def _parse_content(
        self,
        text: str,
        content_type: str = ""
    ) -> List[Dict[str, Any]]:
        """
        Parse file content based on content type or key extension.

        Order of attempts:
          1. If the content type or key suggests JSON/JSONL: a single JSON
             document, then line-delimited JSON.
          2. If it suggests CSV/TSV: parse as delimited text and return the
             rows (possibly empty).
          3. With no JSON hint: auto-detect JSON, then JSONL.
          4. Last resort: CSV, accepted only if it yields at least one row.

        Args:
            text: Decoded object content.
            content_type: Content type reported for the object; may be empty.

        Returns:
            The parsed items.

        Raises:
            ValueError: If no parser produced any items.
        """
        key_lower = self._key.lower()
        content_type = (content_type or "").lower()

        # Determine format
        is_json = 'json' in content_type or key_lower.endswith('.json')
        is_jsonl = 'ndjson' in content_type or key_lower.endswith('.jsonl')
        is_csv = 'csv' in content_type or key_lower.endswith('.csv')
        is_tsv = 'tab' in content_type or key_lower.endswith('.tsv')
        json_hint = is_json or is_jsonl

        if json_hint:
            items = self._parse_json_family(text)
            if items is not None:
                return items

        if is_csv or is_tsv:
            return self._parse_delimited(text, '\t' if is_tsv else ',')

        # Auto-detect. When a JSON hint was present the JSON parsers have
        # already run on this exact text, so they are not repeated.
        if not json_hint:
            items = self._parse_json_family(text)
            if items is not None:
                return items

        # Try CSV as last resort
        try:
            items = self._parse_delimited(text, ',')
        except csv.Error:
            items = []
        if items:
            return items

        raise ValueError(
            f"Could not parse content from s3://{self._bucket}/{self._key}"
        )

    def _parse_json_family(self, text: str) -> Optional[List[Any]]:
        """Try a single JSON document, then JSONL; None if neither applies."""
        document = self._parse_json_document(text)
        if document is not None:
            return document
        # An empty JSONL result means "not JSONL", so report no match.
        return self._parse_json_lines(text) or None

    @staticmethod
    def _parse_json_document(text: str) -> Optional[List[Any]]:
        """
        Parse *text* as one JSON document.

        Returns the list as-is for a top-level array (even if empty), a
        one-element list for a top-level object, and None when the text is
        not valid JSON or its top-level value is a scalar.
        """
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return None
        if isinstance(data, list):
            return data
        if isinstance(data, dict):
            return [data]
        return None

    def _parse_json_lines(self, text: str) -> List[Any]:
        """
        Parse *text* as line-delimited JSON.

        Blank lines are ignored and a line holding a JSON array is flattened
        into the result. Lines that are not valid JSON are skipped; when the
        rest of the content did parse, the number of skipped lines is logged
        so the loss is not silent.
        """
        items: List[Any] = []
        skipped = 0

        # Split on '\n' only: str.splitlines() would also break on
        # characters that may legally appear inside JSON strings.
        for raw_line in text.strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if isinstance(item, list):
                items.extend(item)
            else:
                items.append(item)

        if items and skipped:
            logger.warning(
                "Skipped %d unparseable line(s) in s3://%s/%s",
                skipped, self._bucket, self._key,
            )
        return items

    @staticmethod
    def _parse_delimited(text: str, delimiter: str) -> List[Dict[str, Any]]:
        """Parse *text* as CSV/TSV with a header row into a list of dicts."""
        reader = csv.DictReader(StringIO(text), delimiter=delimiter)
        return [dict(row) for row in reader]

    def read_items(
        self,
        start: int = 0,
        count: Optional[int] = None
    ) -> Iterator[Dict[str, Any]]:
        """
        Read items from S3.

        The object is fetched on the first iteration and cached afterwards.

        Args:
            start: Index of the first item to yield.
            count: Maximum number of items to yield; None for all remaining.

        Yields:
            The selected items in order.
        """
        if self._cached_data is None:
            self._cached_data = self._fetch_data()

        items = self._cached_data[start:]
        if count is not None:
            items = items[:count]

        yield from items

    def get_total_count(self) -> Optional[int]:
        """
        Get total number of items.

        Returns:
            The number of items, or None if the data could not be fetched
            (the error is logged).
        """
        if self._cached_data is None:
            try:
                self._cached_data = self._fetch_data()
            except Exception as e:
                logger.error("Error fetching data: %s", e)
                return None

        return len(self._cached_data)

    def supports_partial_reading(self) -> bool:
        """Partial reading is supported after initial fetch."""
        return True

    def refresh(self) -> bool:
        """Refresh by clearing cached data; the next read fetches again."""
        self._cached_data = None
        return True

    def get_status(self) -> Dict[str, Any]:
        """Get source status (base status plus the S3 settings; no credentials)."""
        status = super().get_status()
        status["bucket"] = self._bucket
        status["key"] = self._key
        status["region"] = self._region
        status["endpoint_url"] = self._endpoint_url
        status["cached"] = self._cached_data is not None
        return status

    def close(self) -> None:
        """Close the source: release the client and drop cached data."""
        client, self._client = self._client, None
        self._cached_data = None

        # Release pooled connections when the client offers close();
        # older botocore clients do not have it.
        close_client = getattr(client, "close", None)
        if callable(close_client):
            try:
                close_client()
            except Exception as e:
                # Closing is best-effort; never fail shutdown because of it.
                logger.debug("Error closing S3 client: %s", e)