Spaces:
Running
Running
| import pytest | |
| import torch | |
| from ding.torch_utils import is_differentiable | |
| from lzero.model.common import RepresentationNetwork | |
| class TestCommon: | |
| def output_check(self, model, outputs): | |
| if isinstance(outputs, torch.Tensor): | |
| loss = outputs.sum() | |
| elif isinstance(outputs, list): | |
| loss = sum([t.sum() for t in outputs]) | |
| elif isinstance(outputs, dict): | |
| loss = sum([v.sum() for v in outputs.values()]) | |
| is_differentiable(loss, model) | |
| def test_representation_network(self, batch_size): | |
| batch = batch_size | |
| obs = torch.rand(batch, 1, 3, 3) | |
| representation_network = RepresentationNetwork( | |
| observation_shape=[1, 3, 3], num_res_blocks=1, num_channels=16, downsample=False | |
| ) | |
| state = representation_network(obs) | |
| assert state.shape == torch.Size([10, 16, 3, 3]) | |