File size: 1,118 Bytes
16f4534 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | import torch
def transform_targets(targets):
    """
    Map raw returns into the signed-log space used during training.

    Computes, element-wise:
        y_trans = sign(y) * log(1 + |y|)

    Args:
        targets: torch.Tensor, or any value NumPy accepts (float, ndarray,
            list), holding raw returns (e.g. 1.5 for 150%).
    Returns:
        A torch.Tensor when `targets` is a tensor. Otherwise the NumPy
        result: a NumPy scalar for scalar input, an ndarray for array-like
        input.
    """
    if not isinstance(targets, torch.Tensor):
        # Non-tensor input (Python float, NumPy array, list, ...) goes through NumPy.
        import numpy as np
        magnitude = np.log1p(np.abs(targets))
        return np.sign(targets) * magnitude

    magnitude = torch.log1p(torch.abs(targets))
    return torch.sign(targets) * magnitude
def inverse_transform_targets(transformed_targets):
    """
    Invert the signed-log transform to get back raw returns.

    Computes, element-wise:
        y = sign(y_trans) * (exp(|y_trans|) - 1)

    The `exp(x) - 1` term is evaluated with `expm1`, the exact counterpart of
    the `log1p` used by `transform_targets`. The naive `exp(x) - 1` loses most
    of its significant digits as x approaches 0 (catastrophic cancellation).

    Args:
        transformed_targets: torch.Tensor (e.g. model predictions in the
            transformed space), or any value NumPy accepts (float, ndarray,
            list).
    Returns:
        Raw returns (e.g. 1.5 for 150%): a torch.Tensor for tensor input,
        otherwise the NumPy result (NumPy scalar or ndarray).
    """
    if isinstance(transformed_targets, torch.Tensor):
        return torch.sign(transformed_targets) * torch.expm1(torch.abs(transformed_targets))

    # Non-tensor input (Python float, NumPy array, list, ...) goes through NumPy.
    import numpy as np
    return np.sign(transformed_targets) * np.expm1(np.abs(transformed_targets))
|