File size: 922 Bytes
2d8433b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Base transform implementations for composing environment-specific transforms."""

from collections.abc import Iterable

from .interfaces import Transform
from .types import Observation


class CompositeTransform(Transform):
    """Combines multiple transforms into a single transform.

    The wrapped transforms are applied in order: the output of each one is
    passed as the input of the next, and the output of the last one is
    returned. With no transforms, the observation is returned unchanged.
    """

    def __init__(self, transforms: Iterable[Transform]):
        """Create a composite from an ordered collection of transforms.

        Args:
            transforms: Transforms to apply, in application order. Any
                iterable is accepted. It is copied into a new list so that
                one-shot iterables (e.g. generators) are not exhausted after
                the first call. The copy also means that later mutation of
                the caller's container does not change this composite.
        """
        self.transforms: list[Transform] = list(transforms)

    def __call__(self, observation: Observation) -> Observation:
        """Apply every transform in sequence and return the final observation."""
        for transform in self.transforms:
            observation = transform(observation)
        return observation

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.transforms!r})"


class NullTransform(Transform):
    """Identity transform: hands back the observation it was given, untouched.

    Useful as a default wherever a transform is required but no processing
    is wanted.
    """

    def __call__(self, observation: Observation) -> Observation:
        # Deliberately a no-op: the very same object is returned, not a copy.
        unchanged = observation
        return unchanged