File size: 1,407 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
State encoder for the policy network.
MLP-based (replaces GNN from v3 design — too complex for hackathon timeline).
Document: GNN would be used in production for the delegation graph component.
"""

from __future__ import annotations
import torch
import torch.nn as nn


class StateEncoder(nn.Module):
    """
    Compress the flat observation vector into a fixed-size feature vector.

    Serves as the feature extractor for the SB3 policy network.

    Architecture:
      - Input: flat state vector (~1376 + N*768 dims)
      - Hidden: 512 → 256 → 128
      - Output: 128-dim feature vector

    Note: The MLP operates on the full flat vector including:
      - Task embedding (384)
      - Roster + called specialist embeddings (padded)
      - Graph adjacency vector (100)
      - Scratchpad summary (384)
      - Scalar features (8)
    This is the "MLP adjacency" approach that replaces the GNN.
    """

    def __init__(self, input_dim: int, output_dim: int = 128):
        super().__init__()
        # Build the stack stage by stage, then wrap it in a single Sequential.
        # Layer order is significant (it fixes state_dict keys and the
        # parameter-init RNG sequence): Linear -> LayerNorm -> ReLU [-> Dropout].
        stages: list[nn.Module] = [
            nn.Linear(input_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        stages += [
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
        ]
        stages += [
            nn.Linear(256, output_dim),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map flat state vector(s) of shape (..., input_dim) to (..., output_dim) features."""
        return self.network(x)