burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
8c48bb1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Shared serialization and deserialization utilities for OpenEnv HTTP servers.
This module provides common utilities for converting between JSON dictionaries
and Pydantic models (Action/Observation) to eliminate code duplication across
HTTP server and web interface implementations.
"""
from typing import Any, Dict, Type
from .types import Action, Observation
def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
    """
    Build an Action instance from a JSON-decoded dictionary.

    The dictionary is handed to Pydantic's ``model_validate()`` on the given
    Action subclass unchanged. This is sufficient for most environments. When
    fields need coercion first (e.g. tensor fields or string-to-int
    conversions), use deserialize_action_with_preprocessing() instead.

    Args:
        action_data: Dictionary containing the action's fields.
        action_cls: The Action subclass to instantiate.

    Returns:
        The validated Action instance.

    Raises:
        ValidationError: If action_data does not satisfy action_cls's schema.
    """
    validated_action = action_cls.model_validate(action_data)
    return validated_action
def deserialize_action_with_preprocessing(
    action_data: Dict[str, Any], action_cls: Type[Action]
) -> Action:
    """
    Convert JSON dict to Action instance with preprocessing for special types.

    This version handles common type conversions needed for web interfaces:
    - Converting lists/strings to tensors for the 'tokens' field
      (falling back to a plain list when torch is not installed)
    - Converting a string 'action_id' to int (left unchanged if not numeric)
    All other keys are passed through untouched.

    Args:
        action_data: Dictionary containing action data
        action_cls: The Action subclass to instantiate

    Returns:
        Action instance

    Raises:
        ValidationError: If the preprocessed data is invalid for the action class
    """
    processed_data: Dict[str, Any] = {}
    for key, value in action_data.items():
        if key == "tokens" and isinstance(value, (list, str)):
            processed_data[key] = _preprocess_tokens(value)
        elif key == "action_id" and isinstance(value, str):
            processed_data[key] = _preprocess_action_id(value)
        else:
            processed_data[key] = value
    return action_cls.model_validate(processed_data)


def _preprocess_tokens(value: Any) -> Any:
    """Coerce a 'tokens' value (list or JSON string) into a long tensor when possible."""
    import json

    if isinstance(value, str):
        # Web forms submit token lists as text such as "[1, 2, 3]".
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            # Best-effort: unparseable text is treated as an empty token list.
            value = []
    if not isinstance(value, list):
        # A JSON string may decode to a non-list (e.g. "5"); pass it through as-is.
        return value
    try:
        import torch  # type: ignore
    except ImportError:
        # torch is optional; without it the tokens stay a plain list.
        return value
    # Errors from tensor construction (e.g. non-numeric entries) are not caught
    # here and propagate to the caller.
    return torch.tensor(value, dtype=torch.long)


def _preprocess_action_id(value: str) -> Any:
    """Convert a string action_id to int, keeping the original string if it is not an integer."""
    try:
        return int(value)
    except ValueError:
        return value
def serialize_observation(observation: Observation) -> Dict[str, Any]:
    """
    Turn an Observation into the response payload expected by EnvClient.

    The observation's own fields are dumped with Pydantic's ``model_dump()``.
    ``reward`` and ``done`` are excluded from that dump and are instead
    promoted to the top level of the payload. ``metadata`` is excluded and not
    emitted at all.

    Args:
        observation: Observation instance

    Returns:
        Dictionary compatible with EnvClient._parse_result(), shaped as:
        {
            "observation": {...},  # Observation fields
            "reward": float | None,
            "done": bool,
        }
    """
    # These fields are either lifted to the top level or dropped entirely.
    excluded_fields = {"reward", "done", "metadata"}
    payload: Dict[str, Any] = {
        "observation": observation.model_dump(exclude=excluded_fields),
        "reward": observation.reward,
        "done": observation.done,
    }
    return payload