"""
Observation encoder for the car racing PPO agent.

Input
-----
img     : (B, 3, 64, 64)  float32, pixels normalised to 0..1
scalars : (B, 9)          float32, [angular_velocity, speed, ray×5, wp_sin, wp_cos]

Output
------
(B, 288)  flat feature vector → feed directly into actor / critic heads.

Architecture
------------
ImpalaCNN  (Espeholt et al., IMPALA 2018)
  3 blocks × (Conv → MaxPool → ResBlock → ResBlock)
  channels : 16 → 32 → 32
  64×64 input shrinks to 8×8 after 3 stride-2 MaxPools  →  32×8×8 = 2048 → FC(256)

  Key difference vs Nature CNN: each block adds two residual (skip) connections.
  Gradients flow straight back through the shortcuts, so early conv filters keep
  updating throughout training.  Empirically 3-5× more sample-efficient on
  visual RL tasks at identical inference cost.

Scalar MLP
  9 → 32 → 32  (angular_velocity, speed, ray×5, wp_sin, wp_cos)

Combined
  cat([img_features, scalar_features])  →  288-d vector
"""

import torch
import torch.nn as nn


# ── Building blocks ───────────────────────────────────────────────────────────

class _ResBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)          # skip connection


class _ImpalaBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.res1 = _ResBlock(out_ch)
        self.res2 = _ResBlock(out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.conv(x))
        x = self.res1(x)
        x = self.res2(x)
        return x


# ── Encoders ──────────────────────────────────────────────────────────────────

class ImpalaCNN(nn.Module):
    """
    Encodes a (B, 3, 64, 64) image to a (B, 256) feature vector.

    Block channels [16, 32, 32]:
        input  64×64
        block1 32×32  (16 ch)
        block2 16×16  (32 ch)
        block3  8×8   (32 ch)  →  flatten 2048  →  FC 256
    """

    CHANNELS = [16, 32, 32]

    def __init__(self, in_channels: int = 3, out_features: int = 256):
        super().__init__()
        blocks, ch = [], in_channels
        for out_ch in self.CHANNELS:
            blocks.append(_ImpalaBlock(ch, out_ch))
            ch = out_ch
        self.cnn = nn.Sequential(*blocks, nn.ReLU())
        self.fc  = nn.Sequential(
            nn.Flatten(),
            nn.Linear(ch * 8 * 8, out_features),   # 8×8 map assumes a 64×64 input
            nn.ReLU(),
        )
        self.out_features = out_features

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.fc(self.cnn(img))


class RaceEncoder(nn.Module):
    """
    Full encoder: ImpalaCNN for image + small MLP for scalars, outputs
    concatenated feature vector for actor / critic heads.

    out_features = img_features (256) + scalar_features (32) = 288
    """

    def __init__(self, img_features: int = 256, scalar_features: int = 32):
        super().__init__()
        self.cnn = ImpalaCNN(out_features=img_features)
        self.scalar_mlp = nn.Sequential(
            # 9 scalars: angular_velocity, speed, ray×5, wp_sin, wp_cos
            nn.Linear(9, scalar_features),
            nn.ReLU(),
            nn.Linear(scalar_features, scalar_features),
            nn.ReLU(),
        )
        self.out_features = img_features + scalar_features

    def forward(self, img: torch.Tensor, scalars: torch.Tensor) -> torch.Tensor:
        """
        img     : (B, 3, 64, 64)  float32  pixels / 255
        scalars : (B, 9)          float32  obs.scalars
                  [angular_velocity, speed,
                   ray_left, ray_front_left, ray_front, ray_front_right, ray_right,
                   wp_sin, wp_cos]
        returns : (B, out_features)
        """
        return torch.cat([self.cnn(img), self.scalar_mlp(scalars)], dim=-1)
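

# ── Shape arithmetic (illustrative sketch) ────────────────────────────────────
# The docstrings above state that three stride-2 max-pools shrink a 64×64 image
# to 8×8, giving 32 * 8 * 8 = 2048 inputs to the first Linear layer.  The two
# helpers below are hypothetical (not used by the encoder): a minimal sketch
# that reproduces that arithmetic with the usual floor-mode pooling formula,
# assuming MaxPool2d(3, stride=2, padding=1) and 3×3 convs with padding=1 that
# leave the spatial size unchanged.

def _pooled_size(size: int, n_pools: int = 3, kernel: int = 3,
                 stride: int = 2, padding: int = 1) -> int:
    """Spatial side length after `n_pools` max-pool layers (floor mode)."""
    for _ in range(n_pools):
        size = (size + 2 * padding - kernel) // stride + 1
    return size


def _flat_features(size: int = 64, channels: int = 32) -> int:
    """Flattened feature count after the last block for a size×size input."""
    side = _pooled_size(size)
    return channels * side * side

# For other input resolutions, _flat_features(size) gives the value that would
# have to replace the hard-coded `ch * 8 * 8` in ImpalaCNN.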