"""
State encoder for the policy network.
MLP-based; replaces the GNN from the v3 design, which was too complex for the hackathon timeline.
Note: in production, a GNN would be used for the delegation graph component.
"""
from __future__ import annotations
import torch
import torch.nn as nn
class StateEncoder(nn.Module):
    """
    Encodes the flat state vector into a compressed representation.
    The SB3 policy will use this as its feature extractor.

    Architecture:
    - Input: flat state vector (~1376 + N*768 dims)
    - Hidden: 512 → 256
    - Output: feature vector of size output_dim (128 by default)

    Note: The MLP operates on the full flat vector including:
    - Task embedding (384)
    - Roster + called specialist embeddings (padded)
    - Graph adjacency vector (100)
    - Scratchpad summary (384)
    - Scalar features (8)
    This is the "MLP adjacency" approach that replaces the GNN.
    """

    def __init__(self, input_dim: int, output_dim: int = 128):
        super().__init__()
        self.network = nn.Sequential(
            # Block 1: project the flat state down to 512 dims.
            nn.Linear(input_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            # Block 2: 512 -> 256, no dropout.
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            # Output projection; the final ReLU keeps features non-negative.
            nn.Linear(256, output_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)
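
# A minimal sketch of how input_dim might be derived from the size estimate in
# the StateEncoder docstring (~1376 + N*768). The helper name flat_state_dim and
# both constants are assumptions for illustration: the docstring marks 1376 as
# approximate and does not define N, which is treated here as a count of
# variable-length embedding slots. Check the real state builder before relying
# on these numbers.

```python
# Hypothetical constants taken from the docstring estimate "~1376 + N*768".
BASE_DIM = 1376        # approximate fixed part of the flat state vector
PER_SLOT_DIM = 768     # assumed size contributed by each of the N slots


def flat_state_dim(n_slots: int) -> int:
    """Return the estimated flat state size for n_slots variable slots."""
    if n_slots < 0:
        raise ValueError("n_slots must be non-negative")
    return BASE_DIM + n_slots * PER_SLOT_DIM


if __name__ == "__main__":
    for n in (0, 1, 4):
        print(f"N={n}: input_dim={flat_state_dim(n)}")
```

# The result would then be passed as StateEncoder(input_dim=flat_state_dim(n)).
# Because nn.Linear fixes its input width at construction time, the value has
# to match the padded state vector exactly, not just approximately.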