CatkinChen commited on
Commit
52156bb
·
verified ·
1 Parent(s): 0500b69

Add training data

Browse files
Files changed (1) hide show
  1. training_data.json +9 -31
training_data.json CHANGED
@@ -1,14 +1,14 @@
1
  {
2
  "train_losses": [
3
- 3866.4689331054688
4
  ],
5
  "test_losses": [
6
- 2702.8141845703126
7
  ],
8
  "config": {
9
  "hmm_only": false,
10
  "vae_only_with_hmm": false,
11
- "em_rounds": 3,
12
  "m_epochs_per_round": 1,
13
  "hmm_params": {
14
  "alpha": 5.0,
@@ -26,14 +26,10 @@
26
  "set_Psi0_with_global_cov": false
27
  },
28
  "hmm_paths": [
29
- "checkpoints_hmm/hmm_round1.pt",
30
- "checkpoints_hmm/hmm_round2.pt",
31
- "checkpoints_hmm/hmm_round3.pt"
32
  ],
33
  "vae_hmm_paths": [
34
- "checkpoints_hmm/vae_with_hmm_round1.pt",
35
- "checkpoints_hmm/vae_with_hmm_round2.pt",
36
- "checkpoints_hmm/vae_with_hmm_round3.pt"
37
  ],
38
  "hf_repos": {
39
  "hmm": "CatkinChen/nethack-hmm",
@@ -48,30 +44,12 @@
48
  "skill_raster": "hmm_analysis/round_01/round01_skill_raster.png",
49
  "dwell_pmfs": "hmm_analysis/round_01/round01_dwell_pmfs.png",
50
  "diags_json": "hmm_analysis/round_01/round01_diags.json"
51
- },
52
- {
53
- "dir": "hmm_analysis/round_02",
54
- "pi_bar": "hmm_analysis/round_02/round02_pi_bar.png",
55
- "A_heatmap": "hmm_analysis/round_02/round02_A_heatmap.png",
56
- "mu_pca": "hmm_analysis/round_02/round02_mu_t-sne.png",
57
- "skill_raster": "hmm_analysis/round_02/round02_skill_raster.png",
58
- "dwell_pmfs": "hmm_analysis/round_02/round02_dwell_pmfs.png",
59
- "diags_json": "hmm_analysis/round_02/round02_diags.json"
60
- },
61
- {
62
- "dir": "hmm_analysis/round_03",
63
- "pi_bar": "hmm_analysis/round_03/round03_pi_bar.png",
64
- "A_heatmap": "hmm_analysis/round_03/round03_A_heatmap.png",
65
- "mu_pca": "hmm_analysis/round_03/round03_mu_t-sne.png",
66
- "skill_raster": "hmm_analysis/round_03/round03_skill_raster.png",
67
- "dwell_pmfs": "hmm_analysis/round_03/round03_dwell_pmfs.png",
68
- "diags_json": "hmm_analysis/round_03/round03_diags.json"
69
  }
70
  ]
71
  },
72
- "final_train_loss": 3866.4689331054688,
73
- "final_test_loss": 2702.8141845703126,
74
  "total_epochs": 1,
75
- "best_train_loss": 3866.4689331054688,
76
- "best_test_loss": 2702.8141845703126
77
  }
 
1
  {
2
  "train_losses": [
3
+ 3848.3857275390624
4
  ],
5
  "test_losses": [
6
+ 2673.374951171875
7
  ],
8
  "config": {
9
  "hmm_only": false,
10
  "vae_only_with_hmm": false,
11
+ "em_rounds": 4,
12
  "m_epochs_per_round": 1,
13
  "hmm_params": {
14
  "alpha": 5.0,
 
26
  "set_Psi0_with_global_cov": false
27
  },
28
  "hmm_paths": [
29
+ "checkpoints_hmm/hmm_round1.pt"
 
 
30
  ],
31
  "vae_hmm_paths": [
32
+ "checkpoints_hmm/vae_with_hmm_round1.pt"
 
 
33
  ],
34
  "hf_repos": {
35
  "hmm": "CatkinChen/nethack-hmm",
 
44
  "skill_raster": "hmm_analysis/round_01/round01_skill_raster.png",
45
  "dwell_pmfs": "hmm_analysis/round_01/round01_dwell_pmfs.png",
46
  "diags_json": "hmm_analysis/round_01/round01_diags.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  }
48
  ]
49
  },
50
+ "final_train_loss": 3848.3857275390624,
51
+ "final_test_loss": 2673.374951171875,
52
  "total_epochs": 1,
53
+ "best_train_loss": 3848.3857275390624,
54
+ "best_test_loss": 2673.374951171875
55
  }