# Pytest fixtures for golden-value tests of the Continuous Thought Machine
# (parity, QAMNIST, and RL/CartPole tasks).
| import pytest | |
| import torch | |
| from models.ctm import ContinuousThoughtMachine | |
| from models.ctm_qamnist import ContinuousThoughtMachineQAMNIST | |
| from utils.samplers import QAMNISTSampler | |
| from tasks.qamnist.utils import get_dataset | |
| from tests.test_data import * | |
| from utils.housekeeping import set_seed | |
| from tasks.rl.train import Agent | |
| from types import SimpleNamespace | |
| # --- Housekeeping --- | |
| def device(): | |
| return torch.device("cpu") | |
| def seed(): | |
| return 42 | |
| def auto_set_seed(seed): | |
| set_seed(seed) | |
| # --- Golden Test Fixtures --- | |
| # ------ Parity ------ | |
@pytest.fixture
def golden_test_params_parity(parity_params):
    """Reduced parity config for the golden test: sequence length 4.

    Copies `parity_params` (so the base fixture is not mutated) and resizes
    the output head: two logits per parity bit.
    """
    parity_length = 4
    parity_small_params = parity_params.copy()
    parity_small_params["out_dims"] = 2 * parity_length
    parity_small_params["prediction_reshaper"] = [parity_length, 2]
    return parity_small_params
@pytest.fixture
def golden_test_model_parity(golden_test_params_parity, device):
    """CTM built from the reduced parity config, moved to the test device."""
    return ContinuousThoughtMachine(**golden_test_params_parity).to(device)


@pytest.fixture
def golden_test_input_parity(device):
    """Fixed input tensor recorded in tests.test_data."""
    return GOLDEN_TEST_INPUT_PARITY.to(device)


@pytest.fixture
def golden_test_expected_predictions_parity(device):
    return GOLDEN_TEST_EXPECTED_PREDICTIONS_PARITY.to(device)


@pytest.fixture
def golden_test_expected_certainties_parity(device):
    return GOLDEN_TEST_EXPECTED_CERTAINTIES_PARITY.to(device)


# The tracking fixtures below return recorded values as-is (no device move).
@pytest.fixture
def golden_test_expected_synchronization_out_tracking_parity():
    return GOLDEN_TEST_EXPECTED_SYNCH_OUT_TRACKING_PARITY


@pytest.fixture
def golden_test_expected_synchronization_action_tracking_parity():
    return GOLDEN_TEST_EXPECTED_SYNCH_ACTION_TRACKING_PARITY


@pytest.fixture
def golden_test_expected_pre_activations_tracking_parity():
    return GOLDEN_TEST_EXPECTED_PRE_ACTIVATIONS_TRACKING_PARITY


@pytest.fixture
def golden_test_expected_post_activations_tracking_parity():
    return GOLDEN_TEST_EXPECTED_POST_ACTIVATIONS_TRACKING_PARITY


@pytest.fixture
def golden_test_expected_attentions_tracking_parity():
    return GOLDEN_TEST_EXPECTED_ATTENTIONS_TRACKING_PARITY
| # ------ QAMNIST ------ | |
@pytest.fixture
def golden_test_params_qamnist(base_params):
    """QAMNIST golden-test config derived from `base_params`.

    Drops the keys the QAMNIST model does not accept and shrinks every
    iteration budget to 1 so the golden run is fast.
    Raises KeyError if `base_params` is missing the popped keys.
    """
    params = base_params.copy()
    params.pop("backbone_type")
    params.pop("positional_embedding_type")
    params["iterations_per_digit"] = 1
    params["iterations_per_question_part"] = 1
    params["iterations_for_answering"] = 1
    return params
@pytest.fixture
def golden_test_model_qamnist(golden_test_params_qamnist, device):
    """QAMNIST CTM built from the reduced golden config."""
    return ContinuousThoughtMachineQAMNIST(**golden_test_params_qamnist).to(device)


@pytest.fixture
def golden_test_input_qamnist(device):
    """One deterministic QAMNIST batch (images and question tokens).

    Builds the smallest possible task (2 images, 2 operations, no deltas,
    batch of 1) and pulls a single batch from the sampler-driven loader.
    Returns (inputs, z) where z stacks the per-step question tokens along
    dim 1.
    """
    q_num_images = 2
    q_num_images_delta = 0
    q_num_repeats_per_input = 1
    q_num_operations = 2
    q_num_operations_delta = 0
    batch_size = 1
    train_data, _, _, _, _ = get_dataset(
        q_num_images=q_num_images,
        q_num_images_delta=q_num_images_delta,
        q_num_repeats_per_input=q_num_repeats_per_input,
        q_num_operations=q_num_operations,
        q_num_operations_delta=q_num_operations_delta,
    )
    sampler = QAMNISTSampler(train_data, batch_size=batch_size)
    loader = torch.utils.data.DataLoader(train_data, batch_sampler=sampler, num_workers=0)
    inputs, z, _, _ = next(iter(loader))
    inputs = inputs.to(device)
    z = torch.stack(z, 1).to(device)
    return inputs, z
@pytest.fixture
def golden_test_expected_predictions_qamnist(device):
    return GOLDEN_TEST_EXPECTED_PREDICTIONS_QAMNIST.to(device)


@pytest.fixture
def golden_test_expected_certainties_qamnist(device):
    return GOLDEN_TEST_EXPECTED_CERTAINTIES_QAMNIST.to(device)


@pytest.fixture
def golden_test_expected_synchronization_out_tracking_qamnist(device):
    return GOLDEN_TEST_EXPECTED_SYNCH_OUT_TRACKING_QAMNIST.to(device)


# The remaining tracking fixtures return the recorded values as-is.
@pytest.fixture
def golden_test_expected_synchronization_action_tracking_qamnist():
    return GOLDEN_TEST_EXPECTED_SYNCH_ACTION_TRACKING_QAMNIST


@pytest.fixture
def golden_test_expected_pre_activations_tracking_qamnist():
    return GOLDEN_TEST_EXPECTED_PRE_ACTIVATIONS_TRACKING_QAMNIST


@pytest.fixture
def golden_test_expected_post_activations_tracking_qamnist():
    return GOLDEN_TEST_EXPECTED_POST_ACTIVATIONS_TRACKING_QAMNIST


@pytest.fixture
def golden_test_expected_attentions_tracking_qamnist():
    return GOLDEN_TEST_EXPECTED_ATTENTIONS_TRACKING_QAMNIST


@pytest.fixture
def golden_test_expected_embeddings_tracking_qamnist():
    return GOLDEN_TEST_EXPECTED_EMBEDDINGS_TRACKING_QAMNIST
| # ------ RL (CartPole) ------ | |
@pytest.fixture
def golden_test_params_rl(base_params):
    """RL golden-test config: `base_params` minus keys the RL agent rejects.

    Copies first so the base fixture is untouched. Raises KeyError if any
    of the popped keys is absent from `base_params`.
    """
    params = base_params.copy()
    params.pop("heads")
    params.pop("n_synch_action")
    params.pop("out_dims")
    params.pop("n_random_pairing_self")
    params.pop("positional_embedding_type")
    return params
@pytest.fixture
def golden_test_model_rl(golden_test_params_rl, device):
    """CTM-based RL Agent for CartPole built from the golden RL config.

    `Agent` expects an argparse-like namespace, so the dict config is
    repackaged into a SimpleNamespace with the Agent's field names
    (e.g. deep_nlms -> deep_memory, do_layernorm_nlm -> do_normalisation).
    """
    args = SimpleNamespace(
        env_id="CartPole-v1",
        model_type="ctm",
        continuous_state_trace=True,
        iterations=golden_test_params_rl['iterations'],
        d_model=golden_test_params_rl['d_model'],
        d_input=golden_test_params_rl['d_input'],
        n_synch_out=golden_test_params_rl['n_synch_out'],
        synapse_depth=golden_test_params_rl['synapse_depth'],
        memory_length=golden_test_params_rl['memory_length'],
        deep_memory=golden_test_params_rl['deep_nlms'],
        memory_hidden_dims=golden_test_params_rl['memory_hidden_dims'],
        do_normalisation=golden_test_params_rl['do_layernorm_nlm'],
        dropout=golden_test_params_rl.get('dropout', 0),
        neuron_select_type=golden_test_params_rl.get('neuron_select_type', 'first-last'),
    )
    # CartPole-v1 has two discrete actions (push left / push right).
    size_action_space = 2
    model = Agent(size_action_space, args, device).to(device)
    return model
@pytest.fixture
def environment_to_test():
    """Name of the only RL environment the golden test covers."""
    return "cartpole"


@pytest.fixture
def golden_test_inputs_rl(environment_to_test):
    """Fixed CartPole observation batch (shape (1, 4)) for the golden run."""
    if environment_to_test != "cartpole":
        raise NotImplementedError("RL test only tests cartpole.")
    observations = torch.tensor([[ 0.01, 0.02, 0.03, 0.04],], dtype=torch.float32)
    return observations
@pytest.fixture
def golden_test_expected_initial_state_trace_rl(device):
    return GOLDEN_TEST_EXPECTED_INITIAL_STATE_TRACE_RL.to(device)


@pytest.fixture
def golden_test_expected_initial_activated_state_trace_rl(device):
    return GOLDEN_TEST_EXPECTED_INITIAL_ACTIVATED_STATE_TRACE_RL.to(device)


@pytest.fixture
def golden_test_expected_action_rl(device):
    return GOLDEN_TEST_EXPECTED_ACTION_RL.to(device)


@pytest.fixture
def golden_test_expected_action_log_prob_rl(device):
    return GOLDEN_TEST_EXPECTED_ACTION_LOG_PROB_RL.to(device)


@pytest.fixture
def golden_test_expected_action_entropy_rl(device):
    return GOLDEN_TEST_EXPECTED_ENTROPY_RL.to(device)


@pytest.fixture
def golden_test_expected_value_rl(device):
    return GOLDEN_TEST_EXPECTED_VALUE_RL.to(device)


@pytest.fixture
def golden_test_expected_state_trace_rl(device):
    return GOLDEN_TEST_EXPECTED_STATE_TRACE_RL.to(device)


@pytest.fixture
def golden_test_expected_activated_state_trace_rl(device):
    return GOLDEN_TEST_EXPECTED_ACTIVATED_STATE_TRACE_RL.to(device)


@pytest.fixture
def golden_test_expected_action_logits_rl(device):
    return GOLDEN_TEST_EXPECTED_ACTION_LOGITS_RL.to(device)


@pytest.fixture
def golden_test_expected_action_probs_rl(device):
    return GOLDEN_TEST_EXPECTED_ACTION_PROBS_RL.to(device)


# Tracking fixtures return the recorded values as-is (no device move).
@pytest.fixture
def golden_test_expected_pre_activations_tracking_rl():
    return GOLDEN_TEST_EXPECTED_PRE_ACTIVATIONS_TRACKING_RL


@pytest.fixture
def golden_test_expected_post_activations_tracking_rl():
    return GOLDEN_TEST_EXPECTED_POST_ACTIVATIONS_TRACKING_RL


@pytest.fixture
def golden_test_expected_synch_out_tracking_rl():
    return GOLDEN_TEST_SYNCH_OUT_TRACKING_RL
| # --- Parity Fixtures --- | |
@pytest.fixture
def base_params():
    """Small shared CTM keyword-argument set used by the parity/QAMNIST/RL
    configs; sized for fast CPU tests."""
    return dict(
        iterations=10,
        d_model=32,
        d_input=4,
        heads=2,
        n_synch_out=3,
        n_synch_action=3,
        synapse_depth=1,
        memory_length=5,
        deep_nlms=True,
        memory_hidden_dims=2,
        do_layernorm_nlm=False,
        backbone_type="none",
        positional_embedding_type="none",
        out_dims=10,
        prediction_reshaper=[-1],
        dropout=0.0,
        neuron_select_type="first-last",
        n_random_pairing_self=0,
    )
@pytest.fixture
def parity_input(device):
    """Random +/-1 parity batch of shape (4, 64).

    randint yields {0, 1}; the *2 - 1 affine map sends it to {-1.0, +1.0}.
    """
    batch_size = 4
    parity_length = 64
    return torch.randint(0, 2, (batch_size, parity_length), dtype=torch.float32, device=device) * 2 - 1
@pytest.fixture
def ctm_factory():
    """Factory fixture: returns a builder that instantiates a CTM from a base
    config dict plus keyword overrides, leaving the base dict unmodified."""
    def _create_model(base_config, **overrides):
        config = base_config.copy()
        config.update(overrides)
        return ContinuousThoughtMachine(**config)
    return _create_model
@pytest.fixture
def parity_params(base_params):
    """Full-size parity config (sequence length 64) derived from `base_params`.

    Copies the base dict, switches on the parity backbone and 1-D rotational
    positional embedding, and resizes the head to two logits per bit.
    """
    parity_length = 64
    parity_params = base_params.copy()
    parity_params["backbone_type"] = "parity_backbone"
    parity_params["positional_embedding_type"] = "custom-rotational-1d"
    parity_params["out_dims"] = 2 * parity_length
    parity_params["prediction_reshaper"] = [parity_length, 2]
    return parity_params
@pytest.fixture
def parity_ctm_model(parity_params, device):
    """CTM built from the full-size parity config, on the test device."""
    return ContinuousThoughtMachine(**parity_params).to(device)
| # --- QAMNIST Fixtures --- | |
@pytest.fixture
def qamnist_params():
    """Full-size QAMNIST CTM keyword-argument set (d_model=1024, 10
    iterations per digit/question-part/answer phase)."""
    return dict(
        iterations=1,
        d_model=1024,
        d_input=64,
        heads=4,
        n_synch_out=32,
        n_synch_action=32,
        synapse_depth=1,
        memory_length=10,
        deep_nlms=True,
        memory_hidden_dims=16,
        do_layernorm_nlm=True,
        out_dims=10,
        prediction_reshaper=[-1],
        dropout=0.0,
        neuron_select_type="first-last",
        iterations_per_digit=10,
        iterations_per_question_part=10,
        iterations_for_answering=10,
        n_random_pairing_self=0,
    )
@pytest.fixture
def qamnist_input(device):
    """Full-size QAMNIST batch (batch of 4, up to 5 images / 5 operations).

    Returns (inputs, z) where z stacks the per-step question tokens along
    dim 1, mirroring golden_test_input_qamnist but at training scale.
    """
    q_num_images = 3
    q_num_images_delta = 2
    q_num_repeats_per_input = 10
    q_num_operations = 3
    q_num_operations_delta = 2
    batch_size = 4
    train_data, _, _, _, _ = get_dataset(
        q_num_images=q_num_images,
        q_num_images_delta=q_num_images_delta,
        q_num_repeats_per_input=q_num_repeats_per_input,
        q_num_operations=q_num_operations,
        q_num_operations_delta=q_num_operations_delta,
    )
    sampler = QAMNISTSampler(train_data, batch_size=batch_size)
    loader = torch.utils.data.DataLoader(train_data, batch_sampler=sampler, num_workers=0)
    inputs, z, _, _ = next(iter(loader))
    inputs = inputs.to(device)
    z = torch.stack(z, 1).to(device)
    return inputs, z
@pytest.fixture
def qamnist_model_factory(qamnist_params):
    """Factory fixture: builds a QAMNIST CTM with a chosen neuron selection
    strategy, overriding only `neuron_select_type` in `qamnist_params`."""
    def _create_model(neuron_select_type):
        return ContinuousThoughtMachineQAMNIST(
            **{**qamnist_params, "neuron_select_type": neuron_select_type}
        )
    return _create_model