CatkinChen committed on
Commit
7cd7c4d
·
verified ·
1 Parent(s): eda1065

Add model config

Browse files
Files changed (1) hide show
  1. config.json +28 -14
config.json CHANGED
@@ -7,7 +7,7 @@
7
  "lowrank_dim": 0,
8
  "bInclude_glyph_bag": true,
9
  "bInclude_hero": true,
10
- "dropout_rate": 0.2,
11
  "enable_dropout_on_latent": true,
12
  "enable_dropout_on_decoder": true,
13
  "architecture": "Multi-modal Variational Autoencoder for NetHack game states",
@@ -27,9 +27,11 @@
27
  ],
28
  "training_config": {
29
  "epochs": 15,
30
- "batch_size": 32,
31
- "learning_rate": 0.0005,
32
  "sequence_size": 32,
 
 
33
  "adaptive_weighting": {
34
  "initial_weight_emb": 1.5,
35
  "final_weight_emb": 0.0,
@@ -37,22 +39,34 @@
37
  "initial_weight_raw": 0.4,
38
  "final_weight_raw": 1.0,
39
  "weight_raw_shape": "linear",
40
- "initial_kl_beta": 0.0001,
41
- "final_kl_beta": 0.6,
42
- "kl_beta_shape": "cosine",
43
- "warmup_epoch_ratio": 0.4
 
 
 
 
 
 
44
  },
45
- "total_correlation_beta_multiplier": 10.0,
46
  "free_bits": 0.15,
47
  "focal_loss_alpha": 0.75,
48
  "focal_loss_gamma": 2.0,
49
- "dropout_rate": 0.2,
50
  "enable_dropout_on_latent": true,
51
- "enable_dropout_on_decoder": true
 
 
 
 
 
 
 
52
  },
53
- "final_train_loss": 549.4707604980468,
54
- "final_test_loss": 8138.669360351562,
55
- "best_train_loss": 522.560433959961,
56
- "best_test_loss": 3383.009033203125,
57
  "total_epochs": 15
58
  }
 
7
  "lowrank_dim": 0,
8
  "bInclude_glyph_bag": true,
9
  "bInclude_hero": true,
10
+ "dropout_rate": 0.1,
11
  "enable_dropout_on_latent": true,
12
  "enable_dropout_on_decoder": true,
13
  "architecture": "Multi-modal Variational Autoencoder for NetHack game states",
 
27
  ],
28
  "training_config": {
29
  "epochs": 15,
30
+ "batch_size": 1024,
31
+ "max_learning_rate": 0.001,
32
  "sequence_size": 32,
33
+ "shuffle_batches": true,
34
+ "shuffle_within_batch": true,
35
  "adaptive_weighting": {
36
  "initial_weight_emb": 1.5,
37
  "final_weight_emb": 0.0,
 
39
  "initial_weight_raw": 0.4,
40
  "final_weight_raw": 1.0,
41
  "weight_raw_shape": "linear",
42
+ "initial_mi_beta": 0.0,
43
+ "final_mi_beta": 0.0,
44
+ "mi_beta_shape": "constant",
45
+ "initial_tc_beta": 5.0,
46
+ "final_tc_beta": 5.0,
47
+ "tc_beta_shape": "constant",
48
+ "initial_dw_beta": 0.02,
49
+ "final_dw_beta": 0.3,
50
+ "dw_beta_shape": "custom",
51
+ "warmup_epoch_ratio": 0.2
52
  },
 
53
  "free_bits": 0.15,
54
  "focal_loss_alpha": 0.75,
55
  "focal_loss_gamma": 2.0,
56
+ "dropout_rate": 0.1,
57
  "enable_dropout_on_latent": true,
58
+ "enable_dropout_on_decoder": true,
59
+ "early_stopping": {
60
+ "enabled": true,
61
+ "patience": 3,
62
+ "min_delta": 0.01,
63
+ "triggered": true,
64
+ "best_epoch": 3
65
+ }
66
  },
67
+ "final_train_loss": 2710.9047668457033,
68
+ "final_test_loss": 2245.631726074219,
69
+ "best_train_loss": 2710.9047668457033,
70
+ "best_test_loss": 2245.631726074219,
71
  "total_epochs": 15
72
  }