# Source: jaxaht-benchmark/agents/initialize_agents.py
# Commit 5146e76 (lainwired): "Initial jaxaht-benchmark deployment"
import jax
from agents.mlp_actor_critic_agent import MLPActorCriticPolicy, ActorWithDoubleCriticPolicy, \
ActorWithConditionalCriticPolicy, PseudoActorWithDoubleCriticPolicy, \
PseudoActorWithConditionalCriticPolicy
from agents.rnn_actor_critic_agent import RNNActorCriticPolicy
from agents.s5_actor_critic_agent import S5ActorCriticPolicy
from agents.liam_agent import LIAMPolicy, initialize_liam_encoder_decoder
from agents.meliba_agent import MeLIBAPolicy, initialize_meliba_encoder_decoder
def initialize_s5_agent(config, env, rng):
    """Build an S5 actor-critic policy together with its initial parameters.

    Args:
        config: dict of hyperparameters. Every key read here is optional;
            missing keys fall back to the defaults in the table below.
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``. The spaces of the first agent are
            used to size the policy.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: S5ActorCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    first_agent = env.agents[0]
    n_actions = env.action_space(first_agent).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    fallback_obs_dim = env.observation_space(first_agent).shape[0]

    # (constructor keyword, config key, default value)
    # NOTE: a smaller setup (d_model=16, ssm_size=16, fc_hidden_dim=64,
    # fc_n_layers=2) was previously kept here as commented-out defaults.
    hyperparameter_table = (
        ("d_model", "S5_D_MODEL", 128),
        ("ssm_size", "S5_SSM_SIZE", 128),
        ("ssm_n_layers", "S5_N_LAYERS", 2),
        ("blocks", "S5_BLOCKS", 1),
        ("fc_hidden_dim", "S5_ACTOR_CRITIC_HIDDEN_DIM", 1024),
        ("fc_n_layers", "FC_N_LAYERS", 3),
        ("s5_activation", "S5_ACTIVATION", "full_glu"),
        ("s5_do_norm", "S5_DO_NORM", True),
        ("s5_prenorm", "S5_PRENORM", True),
        ("s5_do_gtrxl_norm", "S5_DO_GTRXL_NORM", True),
    )
    hyperparameters = {
        kwarg: config.get(config_key, default)
        for kwarg, config_key, default in hyperparameter_table
    }

    policy = S5ActorCriticPolicy(
        action_dim=n_actions,
        obs_dim=config.get("POLICY_INPUT_DIM", fallback_obs_dim),
        **hyperparameters,
    )

    # Only the second half of the split is consumed for initialization.
    _, init_key = jax.random.split(rng)
    return policy, policy.init_params(init_key)
def initialize_rnn_agent(config, env, rng):
    """Build an RNN actor-critic policy together with its initial parameters.

    Args:
        config: dict of hyperparameters. Optional keys: ``POLICY_INPUT_DIM``
            (defaults to the first agent's observation size), ``ACTIVATION``
            (default ``"tanh"``), ``FC_HIDDEN_DIM`` (default 64) and
            ``GRU_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: RNNActorCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    first_agent = env.agents[0]
    n_actions = env.action_space(first_agent).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    fallback_obs_dim = env.observation_space(first_agent).shape[0]

    policy_kwargs = {
        "action_dim": n_actions,
        "obs_dim": config.get("POLICY_INPUT_DIM", fallback_obs_dim),
        "activation": config.get("ACTIVATION", "tanh"),
        "fc_hidden_dim": config.get("FC_HIDDEN_DIM", 64),
        "gru_hidden_dim": config.get("GRU_HIDDEN_DIM", 64),
    }
    policy = RNNActorCriticPolicy(**policy_kwargs)

    # Only the second half of the split is consumed for initialization.
    _, init_key = jax.random.split(rng)
    return policy, policy.init_params(init_key)
def initialize_mlp_agent(config, env, rng):
    """Build an MLP actor-critic policy together with its initial parameters.

    Args:
        config: dict of hyperparameters. Optional keys: ``POLICY_INPUT_DIM``
            (defaults to the first agent's observation size), ``ACTIVATION``
            (default ``"tanh"``) and ``FC_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: MLPActorCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    agent_id = env.agents[0]
    n_actions = env.action_space(agent_id).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    obs_dim_fallback = env.observation_space(agent_id).shape[0]

    policy = MLPActorCriticPolicy(
        action_dim=n_actions,
        obs_dim=config.get("POLICY_INPUT_DIM", obs_dim_fallback),
        activation=config.get("ACTIVATION", "tanh"),
        fc_hidden_dim=config.get("FC_HIDDEN_DIM", 64),
    )

    # Only the second half of the split is consumed for initialization.
    _, init_key = jax.random.split(rng)
    return policy, policy.init_params(init_key)
def initialize_actor_with_double_critic(config, env, rng):
    """Build an ActorWithDoubleCriticPolicy and its initial parameters.

    Args:
        config: dict of hyperparameters. Optional keys: ``POLICY_INPUT_DIM``
            (defaults to the first agent's observation size), ``ACTIVATION``
            (default ``"tanh"``) and ``FC_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: ActorWithDoubleCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    agent_id = env.agents[0]
    n_actions = env.action_space(agent_id).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    obs_dim_fallback = env.observation_space(agent_id).shape[0]

    policy_kwargs = dict(
        action_dim=n_actions,
        obs_dim=config.get("POLICY_INPUT_DIM", obs_dim_fallback),
        activation=config.get("ACTIVATION", "tanh"),
        fc_hidden_dim=config.get("FC_HIDDEN_DIM", 64),
    )
    policy = ActorWithDoubleCriticPolicy(**policy_kwargs)

    # Only the second half of the split is consumed for initialization.
    _, init_key = jax.random.split(rng)
    return policy, policy.init_params(init_key)
def initialize_pseudo_actor_with_double_critic(config, env, rng):
    """Build a PseudoActorWithDoubleCriticPolicy and its initial parameters.

    Args:
        config: dict of hyperparameters. Optional keys: ``POLICY_INPUT_DIM``
            (defaults to the first agent's observation size), ``ACTIVATION``
            (default ``"tanh"``) and ``FC_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: PseudoActorWithDoubleCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    reference_agent = env.agents[0]
    num_actions = env.action_space(reference_agent).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    env_obs_dim = env.observation_space(reference_agent).shape[0]
    input_dim = config.get("POLICY_INPUT_DIM", env_obs_dim)

    policy = PseudoActorWithDoubleCriticPolicy(
        action_dim=num_actions,
        obs_dim=input_dim,
        activation=config.get("ACTIVATION", "tanh"),
        fc_hidden_dim=config.get("FC_HIDDEN_DIM", 64),
    )

    # Only the second half of the split is consumed for initialization.
    split_keys = jax.random.split(rng)
    initial_params = policy.init_params(split_keys[1])
    return policy, initial_params
def initialize_actor_with_conditional_critic(config, env, rng):
    """Build an ActorWithConditionalCriticPolicy and its initial parameters.

    Args:
        config: dict of hyperparameters. ``POP_SIZE`` is required (indexed
            directly, so a plain dict raises KeyError when it is missing).
            Optional keys: ``POLICY_INPUT_DIM`` (defaults to the first
            agent's observation size), ``ACTIVATION`` (default ``"tanh"``)
            and ``FC_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: ActorWithConditionalCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    agent_id = env.agents[0]
    n_actions = env.action_space(agent_id).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    obs_dim_fallback = env.observation_space(agent_id).shape[0]

    policy_kwargs = dict(
        action_dim=n_actions,
        obs_dim=config.get("POLICY_INPUT_DIM", obs_dim_fallback),
        pop_size=config["POP_SIZE"],
        activation=config.get("ACTIVATION", "tanh"),
        fc_hidden_dim=config.get("FC_HIDDEN_DIM", 64),
    )
    policy = ActorWithConditionalCriticPolicy(**policy_kwargs)

    # Only the second half of the split is consumed for initialization.
    _, init_key = jax.random.split(rng)
    return policy, policy.init_params(init_key)
def initialize_pseudo_actor_with_conditional_critic(config, env, rng):
    """Build a PseudoActorWithConditionalCriticPolicy and its initial parameters.

    Args:
        config: dict of hyperparameters. ``POP_SIZE`` is required (indexed
            directly, so a plain dict raises KeyError when it is missing).
            Optional keys: ``POLICY_INPUT_DIM`` (defaults to the first
            agent's observation size), ``ACTIVATION`` (default ``"tanh"``)
            and ``FC_HIDDEN_DIM`` (default 64).
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``observation_space(agent)``.
        rng: jax.random.PRNGKey from which the initialization key is derived.

    Returns:
        policy: PseudoActorWithConditionalCriticPolicy, the policy object
        params: the value returned by ``policy.init_params``
    """
    reference_agent = env.agents[0]
    num_actions = env.action_space(reference_agent).n
    # Computed unconditionally; only used when POLICY_INPUT_DIM is absent.
    env_obs_dim = env.observation_space(reference_agent).shape[0]
    input_dim = config.get("POLICY_INPUT_DIM", env_obs_dim)
    population_size = config["POP_SIZE"]

    policy = PseudoActorWithConditionalCriticPolicy(
        action_dim=num_actions,
        obs_dim=input_dim,
        pop_size=population_size,
        activation=config.get("ACTIVATION", "tanh"),
        fc_hidden_dim=config.get("FC_HIDDEN_DIM", 64),
    )

    # Only the second half of the split is consumed for initialization.
    split_keys = jax.random.split(rng)
    initial_params = policy.init_params(split_keys[1])
    return policy, initial_params
def initialize_liam_agent(config, env, rng):
    """Initialize the LIAM ego agent with the given config.

    Args:
        config: dict, config for the agent. ``EGO_ACTOR_TYPE`` is required
            and must be one of ``"s5"``, ``"mlp"`` or ``"rnn"``; the whole
            config is also forwarded to the selected policy initializer and
            to ``initialize_liam_encoder_decoder``.
        env: environment, forwarded to the policy and encoder/decoder
            initializers
        rng: jax.random.PRNGKey, random key for initialization

    Returns:
        liam: LIAMPolicy, the policy object
        params: dict with keys ``'encoder'``, ``'decoder'`` and ``'policy'``
            holding the respective initial parameters

    Raises:
        ValueError: if ``config["EGO_ACTOR_TYPE"]`` is not a supported type.
    """
    actor_type = config["EGO_ACTOR_TYPE"]
    # Fail fast on an unknown actor type. Without this check the branches
    # below would all be skipped and the function would crash later with an
    # UnboundLocalError on ``ego_policy``.
    if actor_type not in ("s5", "mlp", "rnn"):
        raise ValueError(
            f"Unsupported EGO_ACTOR_TYPE {actor_type!r}; "
            "expected one of 's5', 'mlp', 'rnn'."
        )

    # The first sub-key is intentionally unused; the other two seed the
    # encoder/decoder and the ego policy independently.
    _, init_encoder_decoder_rng, init_policy_rng = jax.random.split(rng, 3)

    # Initialize the policy based on the specified type
    if actor_type == "s5":
        ego_policy, init_ego_params = initialize_s5_agent(config, env, init_policy_rng)
    elif actor_type == "mlp":
        ego_policy, init_ego_params = initialize_mlp_agent(config, env, init_policy_rng)
    else:  # actor_type == "rnn"
        ego_policy, init_ego_params = initialize_rnn_agent(config, env, init_policy_rng)

    # Initialize the encoder and decoder for LIAM
    encoder, decoder, init_encoder_decoder_params = initialize_liam_encoder_decoder(
        config, env, init_encoder_decoder_rng
    )

    liam = LIAMPolicy(
        policy=ego_policy,
        encoder=encoder,
        decoder=decoder,
    )

    params = {
        'encoder': init_encoder_decoder_params['encoder'],
        'decoder': init_encoder_decoder_params['decoder'],
        'policy': init_ego_params,
    }
    return liam, params
def initialize_meliba_agent(config, env, rng):
    """Initialize the MeLIBA ego agent with the given config.

    Args:
        config: dict, config for the agent. ``EGO_ACTOR_TYPE`` is required
            and must be one of ``"s5"``, ``"mlp"`` or ``"rnn"``; the whole
            config is also forwarded to the selected policy initializer and
            to ``initialize_meliba_encoder_decoder``.
        env: environment, forwarded to the policy and encoder/decoder
            initializers
        rng: jax.random.PRNGKey, random key for initialization

    Returns:
        meliba: MeLIBAPolicy, the policy object
        params: dict with keys ``'encoder'``, ``'decoder'`` and ``'policy'``
            holding the respective initial parameters

    Raises:
        ValueError: if ``config["EGO_ACTOR_TYPE"]`` is not a supported type.
    """
    actor_type = config["EGO_ACTOR_TYPE"]
    # Fail fast on an unknown actor type. Without this check the branches
    # below would all be skipped and the function would crash later with an
    # UnboundLocalError on ``ego_policy``.
    if actor_type not in ("s5", "mlp", "rnn"):
        raise ValueError(
            f"Unsupported EGO_ACTOR_TYPE {actor_type!r}; "
            "expected one of 's5', 'mlp', 'rnn'."
        )

    # The first sub-key is intentionally unused; the other two seed the
    # encoder/decoder and the ego policy independently.
    _, init_encoder_decoder_rng, init_policy_rng = jax.random.split(rng, 3)

    # Initialize the policy based on the specified type
    if actor_type == "s5":
        ego_policy, init_ego_params = initialize_s5_agent(config, env, init_policy_rng)
    elif actor_type == "mlp":
        ego_policy, init_ego_params = initialize_mlp_agent(config, env, init_policy_rng)
    else:  # actor_type == "rnn"
        ego_policy, init_ego_params = initialize_rnn_agent(config, env, init_policy_rng)

    # Initialize the encoder and decoder for MeLIBA
    encoder, decoder, init_encoder_decoder_params = initialize_meliba_encoder_decoder(
        config, env, init_encoder_decoder_rng
    )

    meliba = MeLIBAPolicy(
        policy=ego_policy,
        encoder=encoder,
        decoder=decoder,
    )

    params = {
        'encoder': init_encoder_decoder_params['encoder'],
        'decoder': init_encoder_decoder_params['decoder'],
        'policy': init_ego_params,
    }
    return meliba, params