Spaces:
Sleeping
Sleeping
| """ | |
| CNN policy for MiniGrid partial-obs RGB observations (56x56x3). | |
| One strided conv (stride=2) instead of two, keeping 28x28 feature maps | |
| before the AdaptiveAvgPool — better suited for the smaller input. | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.distributions import Categorical | |
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize a layer's weight and fill its bias with a constant.

    Args:
        layer: a module exposing a ``weight`` parameter and, optionally, a
            ``bias`` parameter (e.g. ``nn.Linear`` or ``nn.Conv2d``).
        std: gain passed to ``nn.init.orthogonal_``; defaults to sqrt(2).
        bias_const: value written into every element of ``layer.bias``.

    Returns:
        The same ``layer`` object, initialized in place, so calls can be
        nested directly inside ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, std)
    # Layers constructed with ``bias=False`` carry ``bias = None``; skip the
    # bias fill for them instead of failing inside ``nn.init.constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, bias_const)
    return layer
class CNNPolicy(nn.Module):
    """Actor-critic network over channels-last image observations.

    A stack of three 3x3 convolutions (only the first one strides by 2) is
    followed by ``AdaptiveAvgPool2d((8, 8))``, so the flattened trunk output
    is 64 * 8 * 8 features regardless of input resolution. One hidden linear
    layer is shared by a policy head (action logits) and a value head.

    Args:
        obs_shape: observation shape, unpacked as ``(H, W, C)``.
        n_actions: number of discrete actions (width of the policy head).
        hidden_dim: width of the shared fully-connected layer.
    """

    def __init__(self, obs_shape, n_actions: int, hidden_dim: int = 256):
        super().__init__()
        height, width, channels = obs_shape

        # (in_channels, out_channels, stride) for each 3x3, padding-1 conv.
        # Layers are created and initialized in this order, yielding the
        # Sequential indices cnn.0 / cnn.2 / cnn.4 for the convs.
        conv_specs = ((channels, 32, 2), (32, 64, 1), (64, 64, 1))
        trunk = []
        for in_ch, out_ch, step in conv_specs:
            trunk.append(
                layer_init(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=step, padding=1)
                )
            )
            trunk.append(nn.ReLU())
        trunk.append(nn.AdaptiveAvgPool2d((8, 8)))
        trunk.append(nn.Flatten())
        self.cnn = nn.Sequential(*trunk)

        # Push one blank NCHW frame through the trunk to size the first
        # linear layer.
        with torch.no_grad():
            probe = self.cnn(torch.zeros(1, channels, height, width))
        feature_dim = probe.shape[1]

        self.fc = nn.Sequential(
            layer_init(nn.Linear(feature_dim, hidden_dim)),
            nn.ReLU(),
        )
        # Tiny gain on the policy head keeps the initial logits small.
        self.policy_head = layer_init(nn.Linear(hidden_dim, n_actions), std=0.01)
        self.value_head = layer_init(nn.Linear(hidden_dim, 1), std=1.0)

    def _preprocess(self, obs):
        """Turn a batch of ``(N, H, W, C)`` observations into NCHW input.

        ``uint8`` tensors are cast to float and divided by 255; any other
        dtype is passed through without scaling.
        NOTE(review): float inputs are assumed to be normalized by the
        caller already -- confirm, since 0-255 floats would not be rescaled.
        """
        is_raw_pixels = obs.dtype == torch.uint8
        scaled = obs.float() / 255.0 if is_raw_pixels else obs
        # Move channels from last to second position for nn.Conv2d.
        return scaled.permute(0, 3, 1, 2)

    def forward(self, obs):
        """Return ``(logits, value)`` for a batch of observations.

        ``logits`` has one column per action; ``value`` is the value head's
        output with its trailing singleton dimension squeezed away.
        """
        features = self.fc(self.cnn(self._preprocess(obs)))
        logits = self.policy_head(features)
        value = self.value_head(features).squeeze(-1)
        return logits, value

    def get_action_and_value(self, obs, action=None):
        """Sample or score an action under the current policy.

        If ``action`` is None, a fresh action is drawn from the categorical
        distribution defined by the logits; otherwise the supplied action is
        evaluated. Returns ``(action, log_prob, entropy, value)``.
        """
        logits, value = self.forward(obs)
        dist = Categorical(logits=logits)
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), value

    def get_value(self, obs):
        """Return only the critic's value estimate for ``obs``."""
        return self.forward(obs)[1]