Spaces:
Sleeping
Sleeping
| from typing import Union | |
| import numpy as np | |
| import torch | |
class DiscreteSupport:
    """
    Overview:
        Evenly spaced grid of support points, built as ``np.arange(min, max + 1, delta)``, used to build
        categorical representations of scalars. With the default ``delta = 1`` and integer bounds this is
        every integer in ``[min, max]``.
        NOTE(review): for ``delta != 1`` the grid can run past ``max`` (e.g. ``min=-1, max=1, delta=0.5``
        yields ``[-1.0, -0.5, 0.0, 0.5, 1.0, 1.5]``) -- confirm this is intended.
    Attributes:
        - min, max, delta: The constructor arguments, stored unchanged.
        - range: Numpy array holding the support points.
        - size, set_size: Both equal to ``len(range)``.
    """

    def __init__(self, min: int, max: int, delta: float = 1.) -> None:
        # Kept as an ``assert`` (callers may rely on AssertionError); it is skipped under ``python -O``.
        assert min < max
        points = np.arange(min, max + 1, delta)
        n_points = len(points)
        self.min = min
        self.max = max
        self.range = points
        self.size = n_points
        self.set_size = n_points
        self.delta = delta
def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
    """
    Overview:
        Apply the value-rescaling function h(.) from https://arxiv.org/pdf/1805.11593.pdf:
        ``h(x) = sign(x) * (sqrt(|x / delta| + 1) - 1) + epsilon * x / delta``.
        With ``delta == 1`` the divisions are skipped.
    Arguments:
        - x: Tensor of unscaled values.
        - epsilon: Weight of the linear term that keeps h(.) invertible.
        - delta: Divisor applied to ``x`` before rescaling.
    Returns:
        - output: Tensor of rescaled values with the same shape as ``x``.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf (Page-11) Appendix A : Proposition A.2
    """
    if delta != 1:
        magnitude = torch.abs(x / delta)
        linear_term = epsilon * x / delta
    else:
        # Unit delta: skip the divisions to save time.
        magnitude = torch.abs(x)
        linear_term = epsilon * x
    return torch.sign(x) * (torch.sqrt(magnitude + 1) - 1) + linear_term
def inverse_scalar_transform(
        logits: torch.Tensor,
        support_size: int,
        epsilon: float = 0.001,
        categorical_distribution: bool = True
) -> torch.Tensor:
    """
    Overview:
        Map a scaled value, or its categorical representation, back to the original scale,
        i.e. the h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
    Arguments:
        - logits: If ``categorical_distribution`` is True, scores over the integer support
          ``-support_size, ..., support_size`` along dim 1. Otherwise, the scaled value itself.
        - support_size: Half-width of the integer support.
        - epsilon: The epsilon used by h(.).
        - categorical_distribution: Whether ``logits`` is a categorical representation.
    Returns:
        - output: The unscaled value. In the categorical case dim 1 is reduced to size 1.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """
    if not categorical_distribution:
        value = logits
    else:
        # Same precondition as building a DiscreteSupport(-support_size, support_size): min < max.
        assert -support_size < support_size
        probs = torch.softmax(logits, dim=1)
        support_points = torch.from_numpy(np.arange(-support_size, support_size + 1, 1)).unsqueeze(0)
        support_points = support_points.to(device=probs.device)
        # Expected value of the distribution over the support points.
        value = (support_points * probs).sum(1, keepdim=True)

    # h^(-1)(.) function
    root = torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon))
    # Thresholding near-zero outputs (|output| < epsilon -> 0) is intentionally skipped to save time.
    return torch.sign(value) * (((root - 1) / (2 * epsilon)) ** 2 - 1)
class InverseScalarTransform:
    """
    Overview:
        Transform the scaled value or its categorical representation to the original value,
        i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
        Same math as ``inverse_scalar_transform``, but the categorical support is built once in
        ``__init__`` and reused on every call.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """

    def __init__(
            self,
            support_size: int,
            device: Union[str, torch.device] = 'cpu',
            categorical_distribution: bool = True
    ) -> None:
        """
        Arguments:
            - support_size: The support consists of the points ``-support_size, ..., support_size``.
            - device: Device on which the support tensor is stored.
            - categorical_distribution: If True, ``__call__`` treats its input as logits over the support;
              otherwise the input is already the scaled scalar value.
        """
        # Same precondition the support constructor enforced before (``min < max``).
        assert -support_size < support_size
        # Row vector of shape (1, 2 * support_size + 1) so it broadcasts along dim 1 of the logits.
        self.value_support = torch.arange(-support_size, support_size + 1, device=device).unsqueeze(0)
        self.categorical_distribution = categorical_distribution

    def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
        """
        Arguments:
            - logits: If categorical, scores over the support along dim 1 (dim 1 must have
              ``2 * support_size + 1`` entries to broadcast with the support); otherwise the scaled value.
            - epsilon: The epsilon used by h(.).
        Returns:
            - output: The unscaled value. In the categorical case dim 1 has size 1 (``keepdim=True``).
        """
        if self.categorical_distribution:
            value_probs = torch.softmax(logits, dim=1)
            # Out-of-place multiply: softmax keeps its output for the backward pass, so the former
            # in-place ``mul_`` made ``backward()`` raise whenever ``logits`` required grad.
            value = (value_probs * self.value_support).sum(1, keepdim=True)
        else:
            value = logits
        tmp = (torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)
        # t * t is faster than t ** 2
        return torch.sign(value) * (tmp * tmp - 1)
def visit_count_temperature(
        manual_temperature_decay: bool, fixed_temperature_value: float,
        threshold_training_steps_for_final_lr_temperature: int, trained_steps: int
) -> float:
    """
    Overview:
        Return the temperature applied to visit counts.
        Without manual decay this is ``fixed_temperature_value``. With manual decay it is a step schedule
        over training progress relative to ``threshold_training_steps_for_final_lr_temperature``:
        1.0 below 50% of the threshold, 0.5 below 75%, and 0.25 afterwards.
    """
    if not manual_temperature_decay:
        return fixed_temperature_value
    threshold = threshold_training_steps_for_final_lr_temperature
    if trained_steps < 0.5 * threshold:
        return 1.0
    if trained_steps < 0.75 * threshold:
        return 0.5
    return 0.25
def phi_transform(discrete_support: 'DiscreteSupport', x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        We then apply a transformation ``phi`` to the scalar in order to obtain equivalent categorical representations.
        After this transformation, each scalar is represented as the linear combination of its two adjacent supports,
        i.e. a "two-hot" vector whose two non-zero weights sum to 1.
    Arguments:
        - discrete_support: Object exposing ``min``, ``max``, ``set_size`` and ``delta`` (e.g. ``DiscreteSupport``).
          Support point ``k`` is taken to have the value ``min + k * delta``.
        - x: 2-D tensor of scalars (the result is sized by ``x.shape[0]`` and ``x.shape[1]`` and scattered along
          dim 2). Values are clamped to ``[min, max]`` first. ``x`` itself is not modified.
    Returns:
        - target: Tensor of shape ``(x.shape[0], x.shape[1], set_size)`` on ``x.device`` holding the weights.
    Reference:
        - MuZero paper Appendix F: Network Architecture.
    """
    support_min = discrete_support.min
    support_max = discrete_support.max
    set_size = discrete_support.set_size
    delta = discrete_support.delta

    # Out-of-place clamp: the former ``x.clamp_`` silently overwrote the caller's tensor.
    x = x.clamp(support_min, support_max)

    if delta == 1 and float(support_min).is_integer():
        # Integer grid: the support points are the integers in [min, max], so floor/ceil of x are
        # themselves support points and only need shifting by ``min`` to become indices.
        x_low = x.floor()
        x_high = x.ceil()
        p_high = x - x_low
        low_idx = x_low - support_min
        high_idx = x_high - support_min
    else:
        # General grid: measure x in units of ``delta`` from ``min`` so integer positions are indices.
        # (The former ``x_high - min / delta`` mixed value and index units and picked wrong bins here.)
        # Clamping to the last index keeps ceil() inside the target when the grid stops short of ``max``.
        position = ((x - support_min) / delta).clamp(0, set_size - 1)
        low_idx = position.floor()
        high_idx = position.ceil()
        p_high = position - low_idx
    p_low = 1 - p_high

    # Allocate directly on the right device, with the weights' dtype: ``scatter_`` requires the
    # source and destination dtypes to match.
    target = torch.zeros(x.shape[0], x.shape[1], set_size, device=x.device, dtype=p_high.dtype)
    target.scatter_(2, high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1))
    # Written second on purpose: when x lies exactly on a support point, low_idx == high_idx and p_low == 1 must win.
    target.scatter_(2, low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1))
    return target
def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Soft-label cross entropy along dim 1: ``-sum_k target[:, k] * log_softmax(prediction, dim=1)[:, k]``.
    Arguments:
        - prediction: Unnormalized scores; the softmax is taken over dim 1.
        - target: Weights broadcastable to ``prediction`` (e.g. the output of ``phi_transform``).
    Returns:
        - loss: The per-row loss, with dim 1 summed out.
    """
    log_probs = torch.log_softmax(prediction, dim=1)
    weighted = log_probs * target
    return -weighted.sum(1)