| { | |
| "name": "alien_iris_world_model", | |
| "env": "AlienNoFrameskip-v4", | |
| "model_type": "iris", | |
| "metadata": { | |
| "latent_dim": [1, 16], | |
| "num_tokens": 340 | |
| }, | |
| "util_folders":{ | |
| "models": "../../src/models" | |
| }, | |
| "requirements":{ | |
| "-r": "requirements.txt" | |
| }, | |
| "models": [ | |
| { | |
| "name": "world_model", | |
| "framework": null, | |
| "format": "state_dict", | |
| "source": { | |
| "weights_path": "world_model.pt", | |
| "class_path": "../../src/world_model.py", | |
| "class_name": "WorldModel", | |
| "class_args": [ | |
| { | |
| "vocab_size": 512, | |
| "act_vocab_size": 18, | |
| "tokens_per_block": 17, | |
| "max_blocks": 20, | |
| "attention": "causal", | |
| "num_layers": 10, | |
| "num_heads": 4, | |
| "embed_dim": 256, | |
| "embed_pdrop": 0.1, | |
| "resid_pdrop": 0.1, | |
| "attn_pdrop": 0.1 | |
| }] | |
| }, | |
| "signature": { | |
| "inputs": ["tokens", "past_keys_values"], | |
| "call_mode": "positional" | |
| }, | |
| "sub_models": | |
| [ | |
| { | |
| "name": "transformer", | |
| "sub_model_name": "transformer", | |
| "signature": | |
| { | |
| "inputs": ["sequences", "past_keys_values"], | |
| "call_mode": "positional" | |
| } | |
| } | |
| ], | |
| "methods": | |
| [ | |
| { | |
| "name": "generate_empty_keys_values", | |
| "method_name": "generate_empty_keys_values", | |
| "signature": | |
| { | |
| "inputs": ["n"] | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "name": "tokenizer", | |
| "framework": null, | |
| "format": "state_dict", | |
| "source": { | |
| "weights_path": "tokenizer.pt", | |
| "class_path": "../../src/tokenizer.py", | |
| "class_name": "Tokenizer", | |
| "class_args": [{ | |
| "vocab_size": 512, | |
| "embed_dim": 512, | |
| "encoder": { | |
| "resolution": 64, | |
| "in_channels": 3, | |
| "z_channels": 512, | |
| "ch": 64, | |
| "ch_mult": [1, 1, 1, 1, 1], | |
| "num_res_blocks": 2, | |
| "attn_resolutions": [8, 16], | |
| "out_ch": 3, | |
| "dropout": 0.0 | |
| }, | |
| "decoder": { | |
| "resolution": 64, | |
| "in_channels": 3, | |
| "z_channels": 512, | |
| "ch": 64, | |
| "ch_mult": [1, 1, 1, 1, 1], | |
| "num_res_blocks": 2, | |
| "attn_resolutions": [8, 16], | |
| "out_ch": 3, | |
| "dropout": 0.0 | |
| } | |
| }] | |
| }, | |
| "signature": { | |
| "inputs": ["x", "should_preprocess", "should_postprocess"], | |
| "call_mode": "positional" | |
| }, | |
| "sub_models": | |
| [ | |
| { | |
| "name": "embedding", | |
| "sub_model_name": "embedding", | |
| "signature": | |
| { | |
| "call_mode": "auto" | |
| } | |
| } | |
| ], | |
| "methods": | |
| [ | |
| { | |
| "name": "decode", | |
| "method_name": "decode", | |
| "signature": | |
| { | |
| "inputs": ["z", "should_postprocess"] | |
| } | |
| }, | |
| { | |
| "name": "decode_obs_tokens", | |
| "method_name": "decode_obs_tokens", | |
| "signature": | |
| { | |
| "inputs": ["obs_tokens", "num_observations_tokens"] | |
| } | |
| }, | |
| { | |
| "name": "encode", | |
| "method_name": "encode", | |
| "signature": | |
| { | |
| "inputs": ["observations", "should_preprocess"] | |
| } | |
| } | |
| ] | |
| } | |
| ] | |
| } |