#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Safe serialization module for remote executor communication.

Provides JSON-based serialization with optional pickle fallback for types that
cannot be safely serialized.

**Security Note:** Pickle deserialization can execute arbitrary code. This module
defaults to safe JSON-only serialization. Only enable pickle fallback
(allow_insecure_serializer=True; allow_pickle=True in this module's API) if you
fully trust the execution environment.
"""

import base64
import inspect
import json
import pickle
import textwrap
import warnings
from io import BytesIO
from typing import Any


__all__ = ["SerializationError", "SafeSerializer"]


class SerializationError(Exception):
    """Raised when a type cannot be safely serialized."""

    pass


# Code placed before the extracted to_json_safe/from_json_safe sources in the standalone
# serializer produced by SafeSerializer.get_safe_serializer_code(). It supplies every
# module-level name those sources reference (imports, SerializationError, _get_optional_type).
_SANDBOX_PRELUDE = '''\
import base64
import json
import pickle
from io import BytesIO
from typing import Any


class SerializationError(Exception):
    """Raised when a type cannot be safely serialized."""

    pass


_optional_types_cache = {}


def _get_optional_type(module, attr):
    """Return module.attr, or None when it cannot be imported (results are cached)."""
    key = f"{module}.{attr}"
    if key not in _optional_types_cache:
        try:
            mod = __import__(module, fromlist=[attr])
            _optional_types_cache[key] = getattr(mod, attr)
        except (ImportError, AttributeError):
            _optional_types_cache[key] = None
    return _optional_types_cache[key]'''

# Class placed after the extracted sources by SafeSerializer.get_safe_serializer_code(). It
# delegates to the module-level to_json_safe/from_json_safe functions defined before it.
_SANDBOX_SERIALIZER_CLASS = '''\
class SafeSerializer:
    """Safe JSON-based serializer for sandbox use."""

    SAFE_PREFIX = "safe:"
    PICKLE_PREFIX = "pickle:"

    # Expose the module-level codec functions under their usual names.
    to_json_safe = staticmethod(to_json_safe)
    from_json_safe = staticmethod(from_json_safe)

    @staticmethod
    def dumps(obj, allow_pickle=False):
        try:
            json_safe = to_json_safe(obj)  # Raises SerializationError if fails
            try:
                payload = json.dumps(json_safe)
            except (TypeError, ValueError) as e:
                raise SerializationError(f"Cannot safely serialize object: {e}") from e
        except SerializationError:
            if not allow_pickle:
                # Safe ONLY - no pickle fallback
                raise
        else:
            return SafeSerializer.SAFE_PREFIX + payload
        try:
            return SafeSerializer.PICKLE_PREFIX + base64.b64encode(pickle.dumps(obj)).decode()
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            raise SerializationError(f"Cannot serialize object: {e}") from e

    @staticmethod
    def loads(data, allow_pickle=False):
        if data.startswith(SafeSerializer.SAFE_PREFIX):
            json_data = json.loads(data[len(SafeSerializer.SAFE_PREFIX):])
            return from_json_safe(json_data)
        if not allow_pickle:
            raise SerializationError("Pickle data rejected: allow_pickle=False")
        if data.startswith(SafeSerializer.PICKLE_PREFIX):
            return pickle.loads(base64.b64decode(data[len(SafeSerializer.PICKLE_PREFIX):]))
        # Legacy format (no prefix) - assume pickle
        return pickle.loads(base64.b64decode(data))'''


class SafeSerializer:
    """JSON-based serializer with type markers for safe serialization.

    Supports:
    - Basic: str, int, float, bool, None, list, dict
    - Extended: tuple, set, frozenset, bytes, complex, datetime/date/time/timedelta
    - Optional: numpy.ndarray and numpy integer/floating/bool scalars, PIL.Image,
      dataclasses, Decimal, Path

    The serializer uses a prefix system to distinguish between formats:
    - "safe:" prefix for JSON-serialized data
    - "pickle:" prefix for pickle-serialized data (when allowed)
    """

    SAFE_PREFIX = "safe:"
    PICKLE_PREFIX = "pickle:"

    # Cache for optional type classes (avoids repeated import attempts)
    _optional_types_cache: dict = {}

    @classmethod
    def _get_optional_type(cls, module: str, attr: str):
        """Return ``module.attr``, or None if it cannot be imported.

        Results (including failures) are cached on the class, so each optional
        dependency is imported at most once.
        """
        key = f"{module}.{attr}"
        if key not in cls._optional_types_cache:
            try:
                mod = __import__(module, fromlist=[attr])
                cls._optional_types_cache[key] = getattr(mod, attr)
            except (ImportError, AttributeError):
                cls._optional_types_cache[key] = None
        return cls._optional_types_cache[key]

    # NOTE: to_json_safe and from_json_safe are copied verbatim (via inspect.getsource) into
    # the code generated by get_safe_serializer_code/get_deserializer_code. Inside them, refer
    # to the class only as SafeSerializer.to_json_safe / SafeSerializer.from_json_safe /
    # SafeSerializer._get_optional_type, and keep imports of other modules local.

    @staticmethod
    def to_json_safe(obj: Any) -> Any:
        """Convert Python objects to JSON-serializable format with type markers.

        Args:
            obj: Object to convert.

        Returns:
            JSON-serializable representation. Values without a native JSON form are
            encoded as dicts carrying a "__type__" marker, which from_json_safe() decodes.

        Raises:
            SerializationError: If the object cannot be safely serialized.
        """
        # Fast path: use exact type check for primitives (most common case).
        # Exact checks mean subclasses (e.g. IntEnum, OrderedDict, namedtuple) are not accepted.
        obj_type = type(obj)
        if obj_type is str or obj_type is int or obj_type is float or obj_type is bool or obj is None:
            return obj

        # Fast path: list (very common for return values)
        if obj_type is list:
            return [SafeSerializer.to_json_safe(item) for item in obj]

        # Fast path: tuple (common for multiple return values)
        if obj_type is tuple:
            return {"__type__": "tuple", "data": [SafeSerializer.to_json_safe(item) for item in obj]}

        # Fast path: dict. It can be emitted as a plain JSON object only when every key is a str
        # AND the marker key "__type__" is absent; otherwise from_json_safe would mistake user
        # data such as {"__type__": "tuple", "data": [...]} for a type marker.
        if obj_type is dict:
            if "__type__" not in obj and all(type(k) is str for k in obj):
                return {k: SafeSerializer.to_json_safe(v) for k, v in obj.items()}
            return {
                "__type__": "dict_with_complex_keys",
                "data": [[SafeSerializer.to_json_safe(k), SafeSerializer.to_json_safe(v)] for k, v in obj.items()],
            }

        # Other builtin types - exact type checks
        if obj_type is set:
            return {"__type__": "set", "data": [SafeSerializer.to_json_safe(item) for item in obj]}
        if obj_type is frozenset:
            return {"__type__": "frozenset", "data": [SafeSerializer.to_json_safe(item) for item in obj]}
        if obj_type is bytes:
            return {"__type__": "bytes", "data": base64.b64encode(obj).decode()}
        if obj_type is complex:
            return {"__type__": "complex", "real": obj.real, "imag": obj.imag}

        # Use type module/name for lazy-loaded types (avoids import until needed)
        type_module = getattr(obj_type, "__module__", "")
        type_name = obj_type.__name__

        # datetime module types (check module first to skip unrelated types quickly)
        if type_module == "datetime":
            if type_name == "datetime":
                return {"__type__": "datetime", "data": obj.isoformat()}
            if type_name == "date":
                return {"__type__": "date", "data": obj.isoformat()}
            if type_name == "time":
                return {"__type__": "time", "data": obj.isoformat()}
            if type_name == "timedelta":
                return {"__type__": "timedelta", "total_seconds": obj.total_seconds()}

        # decimal.Decimal
        if type_module == "decimal" and type_name == "Decimal":
            return {"__type__": "Decimal", "data": str(obj)}

        # pathlib.Path (and subclasses like PosixPath, WindowsPath)
        if type_module.startswith("pathlib") and "Path" in type_name:
            return {"__type__": "Path", "data": str(obj)}

        # PIL.Image - use cached import
        pil_image_cls = SafeSerializer._get_optional_type("PIL.Image", "Image")
        if pil_image_cls is not None and isinstance(obj, pil_image_cls):
            buffer = BytesIO()
            obj.save(buffer, format="PNG")
            return {"__type__": "PIL.Image", "data": base64.b64encode(buffer.getvalue()).decode()}

        # numpy types - use cached import
        if type_module == "numpy" or type_module.startswith("numpy."):
            np_ndarray = SafeSerializer._get_optional_type("numpy", "ndarray")
            if np_ndarray is not None and obj_type is np_ndarray:
                # tolist() elements are not checked here; dumps() reports any that JSON rejects.
                return {"__type__": "ndarray", "data": obj.tolist(), "dtype": str(obj.dtype)}
            # numpy scalars (np.int64, np.float32, np.bool_, ...) become plain Python scalars.
            for scalar_name in ("integer", "floating", "bool_"):
                scalar_cls = SafeSerializer._get_optional_type("numpy", scalar_name)
                if scalar_cls is not None and isinstance(obj, scalar_cls):
                    return obj.item()

        # dataclass - check last as is_dataclass() has overhead
        import dataclasses

        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return {
                "__type__": "dataclass",
                "class_name": type_name,
                "module": type_module,
                "data": {f.name: SafeSerializer.to_json_safe(getattr(obj, f.name)) for f in dataclasses.fields(obj)},
            }

        raise SerializationError(f"Cannot safely serialize object of type {type_name}")

    @staticmethod
    def from_json_safe(obj: Any) -> Any:
        """
        Convert JSON-safe format back to Python objects.

        Args:
            obj: JSON-safe representation (the output of to_json_safe after a JSON round trip).

        Returns:
            Original Python object. Dataclasses come back as plain dicts tagged with
            "__dataclass__"/"__module__"; ndarray and PIL.Image payloads are returned
            undecoded when numpy / PIL cannot be imported.
        """
        if isinstance(obj, dict):
            if "__type__" in obj:
                obj_type = obj["__type__"]
                if obj_type == "bytes":
                    return base64.b64decode(obj["data"])
                elif obj_type == "PIL.Image":
                    try:
                        import PIL.Image

                        img_bytes = base64.b64decode(obj["data"])
                        return PIL.Image.open(BytesIO(img_bytes))
                    except ImportError:
                        return {"__type__": "PIL.Image", "data": obj["data"]}
                elif obj_type == "set":
                    return set(SafeSerializer.from_json_safe(item) for item in obj["data"])
                elif obj_type == "tuple":
                    return tuple(SafeSerializer.from_json_safe(item) for item in obj["data"])
                elif obj_type == "complex":
                    return complex(obj["real"], obj["imag"])
                elif obj_type == "frozenset":
                    return frozenset(SafeSerializer.from_json_safe(item) for item in obj["data"])
                elif obj_type == "dict_with_complex_keys":
                    return {
                        SafeSerializer.from_json_safe(k): SafeSerializer.from_json_safe(v) for k, v in obj["data"]
                    }
                elif obj_type == "datetime":
                    from datetime import datetime

                    return datetime.fromisoformat(obj["data"])
                elif obj_type == "date":
                    from datetime import date

                    return date.fromisoformat(obj["data"])
                elif obj_type == "time":
                    from datetime import time

                    return time.fromisoformat(obj["data"])
                elif obj_type == "timedelta":
                    from datetime import timedelta

                    return timedelta(seconds=obj["total_seconds"])
                elif obj_type == "Decimal":
                    from decimal import Decimal

                    return Decimal(obj["data"])
                elif obj_type == "Path":
                    from pathlib import Path

                    return Path(obj["data"])
                elif obj_type == "ndarray":
                    try:
                        import numpy as np

                        return np.array(obj["data"], dtype=obj["dtype"])
                    except ImportError:
                        return obj["data"]  # Return as list if numpy not available
                elif obj_type == "dataclass":
                    # For dataclasses, we return a dict representation
                    # since we can't reconstruct the actual class without access to it
                    return {
                        "__dataclass__": obj["class_name"],
                        "__module__": obj["module"],
                        **{k: SafeSerializer.from_json_safe(v) for k, v in obj["data"].items()},
                    }
            # Plain object, or an unrecognized "__type__" marker: decode values recursively.
            return {k: SafeSerializer.from_json_safe(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [SafeSerializer.from_json_safe(item) for item in obj]
        return obj

    @staticmethod
    def dumps(obj: Any, allow_pickle: bool = False) -> str:
        """
        Serialize object to string.

        Args:
            obj: Object to serialize
            allow_pickle: If False (default), use ONLY safe JSON serialization (error if fails).
                If True, try safe first, fallback to pickle with warning.

        Returns:
            str: Serialized string ("safe:..." for JSON, "pickle:..." for pickle)

        Raises:
            SerializationError: If allow_pickle=False and the object cannot be safely serialized
                (including values json.dumps rejects), or if allow_pickle=True and pickling
                fails as well.
        """
        try:
            json_safe = SafeSerializer.to_json_safe(obj)  # Raises SerializationError if fails
            try:
                payload = json.dumps(json_safe)
            except (TypeError, ValueError) as e:
                # to_json_safe does not inspect every nested value (ndarray.tolist() output is
                # passed through as-is), so json.dumps can still reject the result. Report it
                # as SerializationError, as documented, so the pickle fallback can apply.
                raise SerializationError(f"Cannot safely serialize object: {e}") from e
        except SerializationError:
            if not allow_pickle:
                # Safe ONLY mode - no pickle fallback
                raise
        else:
            return SafeSerializer.SAFE_PREFIX + payload

        # Safe serialization failed and pickle is allowed: warn about insecure pickle usage
        warnings.warn(
            "Falling back to insecure pickle serialization. "
            "This is a security risk and will be removed in a future version. "
            "Consider using only safe serializable types (primitives, lists, dicts, "
            "numpy arrays, PIL images, datetime objects, dataclasses).",
            FutureWarning,
            stacklevel=2,
        )
        try:
            return SafeSerializer.PICKLE_PREFIX + base64.b64encode(pickle.dumps(obj)).decode()
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            raise SerializationError(f"Cannot serialize object: {e}") from e

    @staticmethod
    def loads(data: str, allow_pickle: bool = False) -> Any:
        """
        Deserialize string with format detection.

        Args:
            data: Serialized string (with "safe:" or "pickle:" prefix)
            allow_pickle: If False (default), reject pickle data (strict safe mode).
                If True, accept both safe and pickle formats.

        Returns:
            Deserialized object

        Raises:
            SerializationError: If pickle data received but allow_pickle=False
        """
        if data.startswith(SafeSerializer.SAFE_PREFIX):
            json_data = json.loads(data[len(SafeSerializer.SAFE_PREFIX) :])
            return SafeSerializer.from_json_safe(json_data)

        if data.startswith(SafeSerializer.PICKLE_PREFIX):
            payload = data[len(SafeSerializer.PICKLE_PREFIX) :]
            origin = "This data is pickle-serialized."
        else:
            # No prefix - legacy format, assume pickle
            payload = data
            origin = "This data appears to be pickle-serialized (legacy format)."

        if not allow_pickle:
            raise SerializationError(
                "Pickle data rejected: allow_pickle=False requires safe-only data. "
                f"{origin} To deserialize it, set "
                "allow_pickle=True (not recommended for untrusted data)."
            )

        # SECURITY: pickle.loads can execute arbitrary code; only reached when the caller
        # explicitly opted in with allow_pickle=True.
        warnings.warn(
            "Deserializing pickle data. This is a security risk if the data is untrusted.",
            FutureWarning,
            stacklevel=2,
        )
        return pickle.loads(base64.b64decode(payload))

    @staticmethod
    def _extract_method_body(method) -> str:
        """Extract method body without the def line and dedent it.

        Skips decorator and ``def`` lines, then a ``\"\"\"``-delimited docstring if it is the
        first thing after them; other docstring quote styles are not recognized.
        """
        source = inspect.getsource(method)
        lines = source.split("\n")

        # Skip the def line and docstring
        body_start = 0
        for i, line in enumerate(lines):
            if '"""' in line and i > 0:
                # Find end of docstring
                if line.count('"""') == 2:
                    body_start = i + 1
                    break
                for j in range(i + 1, len(lines)):
                    if '"""' in lines[j]:
                        body_start = j + 1
                        break
                break
            elif line.strip() and not line.strip().startswith("def ") and not line.strip().startswith("@"):
                body_start = i
                break

        body = "\n".join(lines[body_start:])
        return textwrap.dedent(body)

    @staticmethod
    def _standalone_function_source(method, replacements: dict) -> str:
        """Return a static method's source rewritten as a module-level function.

        The source is dedented, a leading ``@staticmethod`` line is dropped, and each
        ``old -> new`` pair in ``replacements`` is applied in insertion order (used to
        rename the function and its ``SafeSerializer.`` references). Trailing whitespace
        is stripped.
        """
        source = textwrap.dedent(inspect.getsource(method))
        if source.startswith("@staticmethod\n"):
            source = source[len("@staticmethod\n") :]
        for old, new in replacements.items():
            source = source.replace(old, new)
        return source.rstrip()

    @staticmethod
    def get_safe_serializer_code() -> str:
        """
        Returns the SafeSerializer class definition as string for injection into sandbox.

        The encoder/decoder are generated from the actual to_json_safe/from_json_safe
        implementations (to avoid duplication) as module-level functions, preceded by the
        imports and helpers they need. A SafeSerializer class with dumps/loads delegating
        to them is appended, along with SerializationError.

        Returns:
            Python source defining SerializationError, to_json_safe, from_json_safe and
            SafeSerializer (with SAFE_PREFIX, dumps and loads).
        """
        to_json_safe_source = SafeSerializer._standalone_function_source(
            SafeSerializer.to_json_safe,
            {
                "SafeSerializer.to_json_safe": "to_json_safe",
                "SafeSerializer._get_optional_type": "_get_optional_type",
            },
        )
        from_json_safe_source = SafeSerializer._standalone_function_source(
            SafeSerializer.from_json_safe,
            {"SafeSerializer.from_json_safe": "from_json_safe"},
        )
        parts = [_SANDBOX_PRELUDE, to_json_safe_source, from_json_safe_source, _SANDBOX_SERIALIZER_CLASS]
        return "\n\n\n".join(parts) + "\n"

    @staticmethod
    def get_deserializer_code(allow_pickle: bool) -> str:
        """
        Generate deserializer function for remote execution with setting baked in.

        This generates code from the actual implementation to avoid duplication.

        Args:
            allow_pickle: Whether to allow pickle deserialization

        Returns:
            Python code string with _deserialize function
        """
        # Build a standalone _from_json_safe function from the source of from_json_safe.
        from_json_safe_source = SafeSerializer._standalone_function_source(
            SafeSerializer.from_json_safe,
            {
                "def from_json_safe(": "def _from_json_safe(",
                "SafeSerializer.from_json_safe": "_from_json_safe",
            },
        )

        if allow_pickle:
            prefixed_pickle_branch = [
                "        import pickle",
                "        return pickle.loads(base64.b64decode(data[7:]))",
            ]
            legacy_pickle_branch = [
                "        import pickle",
                "        return pickle.loads(base64.b64decode(data))",
            ]
        else:
            prefixed_pickle_branch = [
                '        raise SerializationError("Pickle data rejected: allow_pickle=False")',
            ]
            legacy_pickle_branch = [
                '        raise SerializationError("Pickle data rejected: allow_pickle=False")',
            ]

        lines = [
            "import base64",
            "from io import BytesIO",
            "from typing import Any",
            "",
            "class SerializationError(Exception):",
            "    pass",
            "",
            from_json_safe_source,
            "",
            "def _deserialize(data):",
            "    import json",
            '    if isinstance(data, str) and data.startswith("safe:"):',
            "        json_data = json.loads(data[5:])",
            "        return _from_json_safe(json_data)",
            '    elif isinstance(data, str) and data.startswith("pickle:"):',
            *prefixed_pickle_branch,
            "    else:",
            "        # No safe prefix - legacy format, assume pickle",
            *legacy_pickle_branch,
            "",
        ]
        return "\n".join(lines)