File size: 1,306 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Tests for action decoder and policy components."""
import pytest
import numpy as np
from env.action_space import ActionDecoder, MetaAction, DelegationMode


def test_action_decoder_stop():
    """Writing 1.0 into slot 0 of the action vector decodes to a terminal STOP."""
    specialist_ids = ["a", "b", "c"]
    decoder = ActionDecoder(specialist_ids, max_specialists=3)

    raw = np.zeros(decoder.get_action_dim(), dtype=np.float32)
    raw[0] = 1.0  # slot 0 drives the meta-action; 1.0 is used here for STOP

    result = decoder.decode(raw)

    assert result.meta_action == MetaAction.STOP
    assert result.is_terminal()


def test_action_decoder_call_specialist():
    """Slot 0 at 0.0 plus slot 1 at 1.0 decodes to CALL_SPECIALIST on frontend_react."""
    decoder = ActionDecoder(
        ["frontend_react", "backend_api", "database_architect"],
        max_specialists=3,
    )

    raw = np.zeros(decoder.get_action_dim(), dtype=np.float32)
    raw[0] = 0.0  # meta-action slot: CALL_SPECIALIST (explicit, though zeros already)
    raw[1] = 1.0  # first specialist slot -> "frontend_react"

    result = decoder.decode(raw)

    assert result.meta_action == MetaAction.CALL_SPECIALIST
    chosen = result.specialist_ids
    assert "frontend_react" in chosen


def test_specialist_mask():
    """The mask for ["b"] is 1.0 at b's position and 0.0 at a's and c's."""
    decoder = ActionDecoder(["a", "b", "c"], max_specialists=3)
    mask = decoder.build_specialist_mask(["b"])

    # One expected value per specialist id, in constructor order (a, b, c).
    expected = (0.0, 1.0, 0.0)
    for position, value in enumerate(expected):
        assert mask[position] == value


def test_action_dim():
    """With two specialists the decoder reports an action dimension of 2 + 6."""
    decoder = ActionDecoder(["a", "b"], max_specialists=2)
    # NOTE(review): presumably 2 slots for the specialists plus 6 other slots
    # (meta-action / delegation?); the layout is not visible here, confirm
    # against ActionDecoder.get_action_dim.
    expected_dim = 2 + 6
    assert decoder.get_action_dim() == expected_dim