cesarali commited on
Commit
bdec968
·
verified ·
1 Parent(s): e52e111

best val_rmse 0.0169

Browse files
Files changed (2) hide show
  1. config.json +277 -0
  2. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_val_loss": 0.016884254291653633,
3
+ "comet_ai_key": null,
4
+ "context_observations": {
5
+ "add_rem": true,
6
+ "empirical_number_of_obs": 2,
7
+ "max_num_obs": 15,
8
+ "max_past": 5,
9
+ "min_past": 3,
10
+ "obs_dataset": "/home/ojedamarin/Projects/Pharma/generative_pk/data/preprocessed/lenuzza/Lenuzza2016.csv",
11
+ "past_time_ratio": 0.1,
12
+ "split_past_future": false,
13
+ "type": "pk_peak_half_life"
14
+ },
15
+ "debug_test": true,
16
+ "dosing": {
17
+ "logdose_mean_range": [
18
+ -2.0,
19
+ 2.0
20
+ ],
21
+ "logdose_std_range": [
22
+ 0.1,
23
+ 0.5
24
+ ],
25
+ "num_individuals": 10,
26
+ "route_options": [
27
+ "oral",
28
+ "iv"
29
+ ],
30
+ "route_weights": [
31
+ 0.8,
32
+ 0.2
33
+ ],
34
+ "same_route": true,
35
+ "time": 0.0
36
+ },
37
+ "experiment_dir": "/home/cesarali/Pharma/sim_priors_pk/results/comet/aistats/8d1028b3bd834d328600646d98ec4596",
38
+ "experiment_indentifier": null,
39
+ "experiment_name": "aistats",
40
+ "hf_model_card_path": [
41
+ "hf_model_cards",
42
+ "NODE-PK_Readme.md"
43
+ ],
44
+ "hf_model_name": "NodePK_cluster",
45
+ "hugging_face_token": null,
46
+ "meta_study": {
47
+ "V_tmag_range": [
48
+ 0.001,
49
+ 0.001
50
+ ],
51
+ "V_tscl_range": [
52
+ 1,
53
+ 5
54
+ ],
55
+ "drug_id_options": [
56
+ "Drug_A",
57
+ "Drug_B",
58
+ "Drug_C"
59
+ ],
60
+ "k_1p_tmag_range": [
61
+ 0.01,
62
+ 0.02
63
+ ],
64
+ "k_1p_tscl_range": [
65
+ 1,
66
+ 5
67
+ ],
68
+ "k_a_tmag_range": [
69
+ 0.01,
70
+ 0.02
71
+ ],
72
+ "k_a_tscl_range": [
73
+ 1,
74
+ 5
75
+ ],
76
+ "k_e_tmag_range": [
77
+ 0.01,
78
+ 0.02
79
+ ],
80
+ "k_e_tscl_range": [
81
+ 1,
82
+ 5
83
+ ],
84
+ "k_p1_tmag_range": [
85
+ 0.01,
86
+ 0.02
87
+ ],
88
+ "k_p1_tscl_range": [
89
+ 1,
90
+ 5
91
+ ],
92
+ "log_V_mean_range": [
93
+ 2,
94
+ 8
95
+ ],
96
+ "log_V_std_range": [
97
+ 0.2,
98
+ 0.6
99
+ ],
100
+ "log_k_1p_mean_range": [
101
+ -4,
102
+ 0
103
+ ],
104
+ "log_k_1p_std_range": [
105
+ 0.2,
106
+ 0.6
107
+ ],
108
+ "log_k_a_mean_range": [
109
+ -1,
110
+ 2
111
+ ],
112
+ "log_k_a_std_range": [
113
+ 0.2,
114
+ 0.6
115
+ ],
116
+ "log_k_e_mean_range": [
117
+ -5,
118
+ 0
119
+ ],
120
+ "log_k_e_std_range": [
121
+ 0.2,
122
+ 0.6
123
+ ],
124
+ "log_k_p1_mean_range": [
125
+ -4,
126
+ -1
127
+ ],
128
+ "log_k_p1_std_range": [
129
+ 0.2,
130
+ 0.6
131
+ ],
132
+ "num_individuals_range": [
133
+ 5,
134
+ 10
135
+ ],
136
+ "num_peripherals_range": [
137
+ 1,
138
+ 3
139
+ ],
140
+ "rel_ruv_range": [
141
+ 0.001,
142
+ 0.01
143
+ ],
144
+ "solver_method": "rk4",
145
+ "time_num_steps": 100,
146
+ "time_start": 0.0,
147
+ "time_stop": 16.0
148
+ },
149
+ "mix_data": {
150
+ "evaluate_prediction_steps_past": 5,
151
+ "keep_tempfile": false,
152
+ "log_transform": false,
153
+ "n_of_databatches": null,
154
+ "n_of_permutations": 3,
155
+ "n_of_target_individuals": 1,
156
+ "normalize_by_max": true,
157
+ "normalize_time": true,
158
+ "pretraining_epochs": 800,
159
+ "pretraining_protocol": "none",
160
+ "recreate_tempfile": false,
161
+ "split_seed": 42,
162
+ "split_strategy": "study",
163
+ "store_in_tempfile": false,
164
+ "tempfile_path": [
165
+ "preprocessed",
166
+ "simulated_ou_as_rates"
167
+ ],
168
+ "test_empirical_datasets": [
169
+ "cesarali/lenuzza-2016",
170
+ "cesarali/Indometacin",
171
+ "cesarali/Theophylline"
172
+ ],
173
+ "test_protocol": "simulated",
174
+ "test_size": 64,
175
+ "tqdm_progress": false,
176
+ "train_size": 320,
177
+ "val_protocol": "simulated",
178
+ "val_size": 64,
179
+ "z_score_normalization": false
180
+ },
181
+ "model_type": "node_pk",
182
+ "my_results_path": null,
183
+ "name_str": "PredictionPK",
184
+ "network": {
185
+ "activation": "ReLU",
186
+ "aggregator_num_heads": 8,
187
+ "aggregator_type": "attention",
188
+ "combine_latent_mode": "mlp",
189
+ "cov_proj_dim": 16,
190
+ "decoder_attention_layers": 2,
191
+ "decoder_hidden_dim": 512,
192
+ "decoder_name": "RNNDecoder",
193
+ "decoder_num_layers": 4,
194
+ "decoder_rnn_hidden_dim": 256,
195
+ "drift_activation": "Tanh",
196
+ "drift_num_layers": 2,
197
+ "dropout": 0.1,
198
+ "encoder_rnn_hidden_dim": 256,
199
+ "exclusive_node_step": false,
200
+ "ignore_logvar": true,
201
+ "individual_encoder_name": "RNNContextEncoderDosing",
202
+ "individual_encoder_number_of_heads": 4,
203
+ "init_hidden_num_layers": 4,
204
+ "input_encoding_hidden_dim": 128,
205
+ "kl_weight": 1.0,
206
+ "loss_name": "log_nll",
207
+ "node_step": true,
208
+ "norm": "layer",
209
+ "output_head_num_layers": 3,
210
+ "prediction_latent_deterministic": false,
211
+ "prediction_only": false,
212
+ "reconstruction_only": false,
213
+ "rnn_decoder_number_of_layers": 4,
214
+ "rnn_individual_encoder_number_of_layers": 4,
215
+ "study_latent_deterministic": false,
216
+ "time_obs_encoder_hidden_dim": 256,
217
+ "time_obs_encoder_output_dim": 256,
218
+ "use_attention": true,
219
+ "use_invariance_loss": true,
220
+ "use_kl_i": true,
221
+ "use_kl_i_np": true,
222
+ "use_kl_init": true,
223
+ "use_kl_s": true,
224
+ "use_self_attention": true,
225
+ "use_time_deltas": true,
226
+ "zi_latent_dim": 256
227
+ },
228
+ "run_index": 0,
229
+ "tags": [
230
+ "AISTATS-2026",
231
+ "NODE-PK"
232
+ ],
233
+ "target_observations": {
234
+ "add_rem": true,
235
+ "empirical_number_of_obs": 2,
236
+ "max_num_obs": 15,
237
+ "max_past": 5,
238
+ "min_past": 3,
239
+ "obs_dataset": "/home/ojedamarin/Projects/Pharma/generative_pk/data/preprocessed/lenuzza/Lenuzza2016.csv",
240
+ "past_time_ratio": 0.1,
241
+ "split_past_future": true,
242
+ "type": "pk_peak_half_life"
243
+ },
244
+ "train": {
245
+ "amsgrad": false,
246
+ "batch_size": 32,
247
+ "betas": [
248
+ 0.9,
249
+ 0.999
250
+ ],
251
+ "epochs": 10,
252
+ "eps": 1e-08,
253
+ "gradient_clip_val": 1.0,
254
+ "learning_rate": 0.0001,
255
+ "log_empirical_evaluation_pct": 0.5,
256
+ "log_image_every_epoch_pct": 0.5,
257
+ "log_interval": 1,
258
+ "log_prediction_in_val": true,
259
+ "log_reconstruction_in_val": true,
260
+ "log_vcp": false,
261
+ "num_batch_plot": 1,
262
+ "num_workers": 8,
263
+ "optimizer_name": "AdamW",
264
+ "persistent_workers": true,
265
+ "scheduler_name": "CosineAnnealingLR",
266
+ "scheduler_params": {
267
+ "T_max": 1000,
268
+ "eta_min": 5e-05,
269
+ "last_epoch": -1
270
+ },
271
+ "shuffle_val": true,
272
+ "weight_decay": 0.0001
273
+ },
274
+ "transformers_version": "4.52.4",
275
+ "upload_to_hf_hub": false,
276
+ "verbose": false
277
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a54495279aabc709aa3509b78dfe59d46a146679804503167b3e5db5c518c3b
3
+ size 31136387