# oracle/inference_utils.py
# Author: zirobtc — "Upload folder using huggingface_hub" (commit 16f4534)
import torch
def transform_targets(targets):
    """Map raw returns into the signed-log space used during training.

    The transform is ``y_trans = sign(y) * log(1 + |y|)``. It compresses large
    magnitudes while keeping the sign, and it is exactly zero at zero.

    Args:
        targets: a ``torch.Tensor``, or anything NumPy accepts (Python float,
            ndarray, list) holding raw returns, e.g. 1.5 for +150%.

    Returns:
        The transformed values. A tensor input yields a tensor; any other
        input yields the NumPy result (a scalar or an ndarray).
    """
    if not isinstance(targets, torch.Tensor):
        # Non-tensor inputs are handled by NumPy, which is imported lazily.
        import numpy as np
        magnitude = np.log1p(np.abs(targets))
        return np.sign(targets) * magnitude

    # log1p keeps precision for |y| close to 0.
    magnitude = torch.log1p(targets.abs())
    return targets.sign() * magnitude
def inverse_transform_targets(transformed_targets):
    """Invert the signed-log transform to recover raw returns.

    The inverse is ``y = sign(y_trans) * (exp(|y_trans|) - 1)``. It is
    evaluated with ``expm1``, the exact counterpart of the ``log1p`` used in
    the forward transform. ``exp(x) - 1`` cancels catastrophically for small
    ``x`` (in float32, ``exp(1e-10) - 1 == 0``); ``expm1`` avoids that.

    Args:
        transformed_targets: values in the transformed space (for example
            model outputs), given as a ``torch.Tensor`` or anything NumPy
            accepts (Python float, ndarray, list).

    Returns:
        Raw returns, e.g. 1.5 for +150%. A tensor input yields a tensor; any
        other input yields the NumPy result. Very large magnitudes can
        overflow to ``inf``, as with any exponential.
    """
    if isinstance(transformed_targets, torch.Tensor):
        return torch.sign(transformed_targets) * torch.expm1(torch.abs(transformed_targets))
    # Non-tensor inputs are handled by NumPy, which is imported lazily.
    import numpy as np
    return np.sign(transformed_targets) * np.expm1(np.abs(transformed_targets))