plm igr training
Browse files- .gitattributes +2 -0
- plm_igr/best_model.pt +3 -0
- plm_igr/config.yaml +145 -0
- plm_igr/predictions_test.feather +3 -0
- plm_igr/predictions_validation.feather +3 -0
- plm_igr/target_tokenizer.json +1 -0
.gitattributes
CHANGED
|
@@ -39,3 +39,5 @@ plm_icd/predictions_test.feather filter=lfs diff=lfs merge=lfs -text
|
|
| 39 |
plm_icd/predictions_validation.feather filter=lfs diff=lfs merge=lfs -text
|
| 40 |
suppervised_attention_2/predictions_test.feather filter=lfs diff=lfs merge=lfs -text
|
| 41 |
suppervised_attention_2/predictions_validation.feather filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 39 |
plm_icd/predictions_validation.feather filter=lfs diff=lfs merge=lfs -text
|
| 40 |
suppervised_attention_2/predictions_test.feather filter=lfs diff=lfs merge=lfs -text
|
| 41 |
suppervised_attention_2/predictions_validation.feather filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
plm_igr/predictions_test.feather filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
plm_igr/predictions_validation.feather filter=lfs diff=lfs merge=lfs -text
|
plm_igr/best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:785460de24dfb44bc024543dfeae4117f870af37cadbc7640b6e278b5a3d457e
|
| 3 |
+
size 1509227526
|
plm_igr/config.yaml
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 1337
|
| 2 |
+
deterministic: false
|
| 3 |
+
gpu: 0
|
| 4 |
+
name: null
|
| 5 |
+
debug: false
|
| 6 |
+
load_model: null
|
| 7 |
+
distillation: false
|
| 8 |
+
data:
|
| 9 |
+
dataset_path: explainable_medical_coding/datasets/mdace_inpatient_icd9.py
|
| 10 |
+
target_columns:
|
| 11 |
+
- diagnosis_codes
|
| 12 |
+
- procedure_codes
|
| 13 |
+
max_length: 6000
|
| 14 |
+
dataloader:
|
| 15 |
+
max_batch_size: 4
|
| 16 |
+
batch_size: 4
|
| 17 |
+
num_workers: 0
|
| 18 |
+
drop_last: false
|
| 19 |
+
pin_memory: false
|
| 20 |
+
batch_sampler:
|
| 21 |
+
name: BySequenceLengthSampler
|
| 22 |
+
configs:
|
| 23 |
+
bucket_boundaries:
|
| 24 |
+
- 400
|
| 25 |
+
- 600
|
| 26 |
+
- 800
|
| 27 |
+
- 1000
|
| 28 |
+
- 1200
|
| 29 |
+
- 1400
|
| 30 |
+
- 1600
|
| 31 |
+
- 1800
|
| 32 |
+
- 2000
|
| 33 |
+
- 2200
|
| 34 |
+
- 2600
|
| 35 |
+
- 3000
|
| 36 |
+
- 3400
|
| 37 |
+
- 4000
|
| 38 |
+
- 5000
|
| 39 |
+
model:
|
| 40 |
+
name: PLMICD
|
| 41 |
+
autoregressive: false
|
| 42 |
+
configs:
|
| 43 |
+
model_path: models/roberta-base-pm-m3-voc-hf
|
| 44 |
+
chunk_size: 128
|
| 45 |
+
cross_attention: true
|
| 46 |
+
loss: binary_cross_entropy
|
| 47 |
+
lambda_1: 0.0
|
| 48 |
+
scale: 1
|
| 49 |
+
mask_input: false
|
| 50 |
+
trainer:
|
| 51 |
+
name: Trainer
|
| 52 |
+
epochs: 20
|
| 53 |
+
validate_on_training_data: true
|
| 54 |
+
print_metrics: false
|
| 55 |
+
use_amp: true
|
| 56 |
+
threshold_tuning: true
|
| 57 |
+
clip_grad_norm: 1
|
| 58 |
+
clip_value: 10
|
| 59 |
+
optimizer:
|
| 60 |
+
name: AdamW
|
| 61 |
+
configs:
|
| 62 |
+
lr: 5.0e-05
|
| 63 |
+
weight_decay: 0
|
| 64 |
+
lr_scheduler:
|
| 65 |
+
name: linear
|
| 66 |
+
configs:
|
| 67 |
+
warmup: 0.1
|
| 68 |
+
metrics:
|
| 69 |
+
- name: F1Score
|
| 70 |
+
configs:
|
| 71 |
+
average: micro
|
| 72 |
+
- name: F1Score
|
| 73 |
+
configs:
|
| 74 |
+
average: macro
|
| 75 |
+
- name: Recall
|
| 76 |
+
configs:
|
| 77 |
+
average: micro
|
| 78 |
+
- name: Recall
|
| 79 |
+
configs:
|
| 80 |
+
average: macro
|
| 81 |
+
- name: Precision
|
| 82 |
+
configs:
|
| 83 |
+
average: micro
|
| 84 |
+
- name: Precision
|
| 85 |
+
configs:
|
| 86 |
+
average: macro
|
| 87 |
+
- name: FPR
|
| 88 |
+
configs:
|
| 89 |
+
average: micro
|
| 90 |
+
- name: FPR
|
| 91 |
+
configs:
|
| 92 |
+
average: macro
|
| 93 |
+
- name: ExactMatchRatio
|
| 94 |
+
configs: {}
|
| 95 |
+
- name: Precision_K
|
| 96 |
+
configs:
|
| 97 |
+
k: 5
|
| 98 |
+
- name: Precision_K
|
| 99 |
+
configs:
|
| 100 |
+
k: 8
|
| 101 |
+
- name: Precision_K
|
| 102 |
+
configs:
|
| 103 |
+
k: 15
|
| 104 |
+
- name: Recall_K
|
| 105 |
+
configs:
|
| 106 |
+
k: 5
|
| 107 |
+
- name: Recall_K
|
| 108 |
+
configs:
|
| 109 |
+
k: 10
|
| 110 |
+
- name: Recall_K
|
| 111 |
+
configs:
|
| 112 |
+
k: 15
|
| 113 |
+
- name: MeanAveragePrecision
|
| 114 |
+
configs: {}
|
| 115 |
+
- name: PrecisionAtRecall
|
| 116 |
+
configs: {}
|
| 117 |
+
- name: AUC
|
| 118 |
+
configs:
|
| 119 |
+
average: micro
|
| 120 |
+
- name: AUC
|
| 121 |
+
configs:
|
| 122 |
+
average: macro
|
| 123 |
+
- name: LossMetric
|
| 124 |
+
configs: {}
|
| 125 |
+
callbacks:
|
| 126 |
+
- name: WandbCallback
|
| 127 |
+
configs:
|
| 128 |
+
project: explainable-medical-coding
|
| 129 |
+
entity: null
|
| 130 |
+
- name: SaveBestModelCallback
|
| 131 |
+
configs:
|
| 132 |
+
split: validation
|
| 133 |
+
target: all
|
| 134 |
+
metric: map
|
| 135 |
+
- name: EarlyStoppingCallback
|
| 136 |
+
configs:
|
| 137 |
+
split: validation
|
| 138 |
+
target: all
|
| 139 |
+
metric: map
|
| 140 |
+
patience: 3
|
| 141 |
+
loss:
|
| 142 |
+
name: double_backpropagation_loss
|
| 143 |
+
configs:
|
| 144 |
+
lambda_1: 1.0e-05
|
| 145 |
+
p: 1
|
plm_igr/predictions_test.feather
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:16e15126025e5d6c04ef6a389e115060a84b3e8ccd31ca8dc795cf0e358f6ef9
|
| 3 |
+
size 840034
|
plm_igr/predictions_validation.feather
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4866bedcb164d405769bad50bbd9d09c9207250edd7e8ee04b7fac2943be827c
|
| 3 |
+
size 817722
|
plm_igr/target_tokenizer.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
["112.0", "V58.61", "458.9", "E879.8", "36.14", "415.11", "569.83", "053.19", "286.9", "682.3", "790.29", "246.9", "294.21", "459.2", "198.7", "453.40", "342.92", "722.93", "410.71", "820.21", "348.1", "429.71", "620.2", "733.22", "041.85", "807.01", "428.0", "724.5", "427.9", "070.54", "324.0", "294.10", "008.45", "285.9", "250.40", "590.10", "820.20", "V12.54", "39.76", "410.31", "276.1", "458.21", "496", "520.6", "557.0", "452", "754.81", "998.32", "568.0", "571.5", "599.0", "401.9", "440.1", "455.0", "567.29", "288.8", "780.60", "553.29", "535.01", "996.09", "396.3", "480.9", "V12.01", "453.41", "736.9", "323.9", "733.99", "810.00", "991.6", "174.9", "441.2", "50.11", "E878.1", "V10.02", "V10.06", "823.30", "730.16", "823.31", "729.1", "285.1", "428.23", "45.23", "428.30", "555.1", "V10.03", "35.22", "250.50", "724.3", "304.01", "790.7", "969.4", "430", "599.71", "533.90", "999.31", "863.89", "51.88", "96.04", "529.6", "V42.7", "345.40", "369.8", "V43.64", "346.00", "V64.42", "853.06", "853.01", "553.1", "998.12", "259.4", "303.00", "99.04", "284.19", "785.59", "799.4", "357.2", "041.11", "575.8", "578.9", "89.49", "V15.52", "198.4", "45.31", "695.10", "459.81", "277.88", "V12.04", "404.93", "414.01", "860.0", "782.4", "707.11", "038.8", "737.39", "403.91", "715.15", "V10.3", "150.5", "V15.3", "288.00", "557.1", "438.21", "427.89", "348.4", "788.29", "244.9", "E812.0", "276.3", "591", "552.21", "574.81", "425.4", "V45.88", "97.29", "710.2", "45.30", "802.8", "303.91", "379.23", "786.1", "265.1", "808.2", "747.69", "241.0", "39.95", "280.0", "E901.0", "331.0", "V49.72", "285.8", "491.22", "V09.80", "576.8", "725", "873.49", "510.0", "575.10", "365.9", "593.9", "303.93", "535.60", "197.7", "518.53", "344.09", "V85.42", "278.00", "707.10", "780.09", "780.97", "799.1", "E001.0", "39.89", "36.11", "437.3", "041.04", "518.81", "34.91", "39.71", "V10.05", "333.94", "410.41", "482.9", "33.21", "459.89", "952.00", "52.22", "823.21", "51.01", "282.49", "514", "294.8", "807.03", "253.2", "99.25", "E919.4", "357.0", "88.41", "535.10", "54.91", "E849.0", "288.50", "578.0", "211.1", "783.41", "715.14", "453.81", "719.41", "696.0", "198.3", "239.6", "416.8", "531.40", "E878.2", "V49.75", "153.6", "410.01", "532.20", "578.1", "607.84", "733.40", "88.72", "780.52", "780.61", "272.6", "807.4", "44.43", "426.53", "585.6", "V10.47", "45.22", "571.1", "250.00", "250.02", "345.90", "V45.81", "303.90", "45.16", "54.24", "427.41", "414.2", "V10.43", "38.69", "572.4", "83.65", "211.3", "276.69", "V10.46", "277.4", "295.90", "294.11", "V46.3", "188.9", "35.23", "680.6", "278.02", "802.6", "438.22", "341.9", "825.25", "516.9", "403.90", "V45.02", "511.89", "790.4", "96.72", "441.3", "791.9", "737.10", "781.94", "02.21", "250.11", "532.40", "245.2", "237.3", "427.69", "317", "242.90", "428.33", "344.00", "43.11", "V12.71", "V45.79", "493.20", "01.39", "585.2", "35.21", "410.90", "51.22", "93.90", "724.2", "996.02", "V42.2", "421.0", "E812.2", "112.2", "366.8", "263.0", "289.89", "491.21", "784.3", "428.42", "721.2", "442.84", "414.8", "041.6", "300.3", "320.2", "342.91", "572.8", "780.57", "792.1", "03.53", "39.79", "038.49", "96.6", "821.01", "V10.82", "205.00", "574.91", "37.23", "426.4", "790.92", "923.03", "197.0", "070.44", "276.2", "997.2", "45.13", "51.36", "787.01", "564.00", "896.0", "518.0", "995.29", "865.03", "348.2", "518.82", "511.1", "305.01", "53.80", "787.91", "286.6", "873.44", "V44.1", "787.22", "204.00", "427.32", "574.10", "88.51", "337.20", "785.0", "277.39", "332.0", "965.1", "995.91", "V64.2", "213.0", "357.9", "571.8", "852.01", "562.11", "368.47", "730.13", "V58.63", "584.5", "309.81", "285.29", "519.19", "88.48", "574.31", "E823.0", "253.6", "593.2", "785.51", "682.2", "368.2", "434.91", "733.90", "276.0", "785.50", "560.0", "313.89", "03.31", "722.4", "576.2", "171.9", "238.75", "291.9", "V42.0", "368.46", "346.90", "427.1", "577.1", "733.01", "197.6", "443.0", "288.60", "281.1", "312.30", "802.4", "276.7", "349.82", "852.21", "806.26", "V49.83", "482.41", "E888.9", "536.3", "358.01", "952.9", "456.1", "531.90", "311", "530.85", "585.3", "730.10", "279.00", "276.51", "36.07", "998.11", "996.47", "532.90", "189.0", "296.80", "426.0", "041.86", "E884.9", "295.62", "579.0", "E882", "V45.4", "861.21", "995.92", "V15.82", "431", "787.3", "805.05", "053.9", "203.00", "38.93", "584.9", "696.1", "814.00", "428.22", "272.4", "456.8", "493.90", "070.30", "196.2", "300.00", "411.89", "V10.52", "34.09", "070.70", "250.13", "788.20", "V12.51", "433.11", "719.40", "372.73", "807.2", "404.91", "V45.71", "424.2", "213.6", "807.08", "V58.66", "860.4", "323.41", "287.31", "565.1", "511.81", "E819.2", "041.49", "E860", "707.14", "151.0", "560.81", "V10.21", "338.18", "038.11", "426.11", "780.2", "816.00", "V87.41", "951.3", "721.3", "958.92", "V13.02", "314.01", "294.20", "692.9", "440.20", "518.84", "873.0", "787.20", "252.00", "782.3", "493.92", "963.0", "204.02", "853.05", "707.03", "396.2", "600.01", "793.99", "707.19", "01.6", "410.72", "351.0", "50.22", "562.12", "447.8", "530.0", "110.8", "805.6", "275.42", "E878.6", "E885.9", "V10.51", "E861", "E950.0", "682.6", "458.29", "99.08", "788.30", "707.21", "038.40", "V44.6", "998.59", "389.9", "041.7", "209.79", "162.8", "238.71", "81.62", "E883.0", "117.9", "959.9", "E819.9", "078.5", "354.0", "34.04", "157.8", "348.30", "593.89", "996.81", "E884.2", "V45.89", "427.5", "V49.86", "37.22", "560.2", "293.0", "535.51", "860.2", "530.3", "785.52", "813.22", "88.49", "441.4", "997.31", "714.0", "337.3", "864.05", "424.1", "227.3", "427.61", "572.2", "588.1", "158.0", "88.39", "336.0", "971.1", "E884.6", "783.7", "162.5", "537.83", "V58.65", "225.4", "434.01", "801.26", "291.81", "805.02", "424.90", "441.02", "V54.16", "346.80", "396.8", "424.0", "50.3", "820.8", "383.9", "854.06", "41.31", "136.3", "729.92", "852.26", "E880.9", "162.9", "293.1", "473.0", "250.80", "572.3", "338.29", "863.84", "02.92", "355.1", "V45.82", "272.0", "780.4", "V09.91", "99.71", "432.9", "784.0", "397.9", "997.1", "V44.2", "99.05", "571.2", "813.43", "E849.7", "156.1", "250.51", "558.9", "99.07", "V10.11", "250.60", "453.82", "789.59", "807.09", "808.42", "157.0", "562.10", "596.54", "745.5", "209.30", "37.61", "412", "442.83", "512.1", "707.07", "038.42", "710.0", "214.0", "800.70", "01.14", "780.39", "366.9", "E855.6", "715.90", "402.91", "482.0", "793.11", "V44.0", "355.9", "250.92", "444.81", "742.9", "284.11", "274.9", "150.3", "555.9", "456.20", "486", "707.05", "427.0", "415.19", "807.04", "533.40", "39.75", "E858.8", "530.81", "042", "516.0", "041.09", "96.71", "487.0", "E950.4", "965.4", "736.79", "263.9", "V17.3", "070.32", "278.01", "348.9", "920", "428.43", "702.19", "401.1", "595.0", "453.42", "288.66", "E881.0", "732.1", "202.10", "305.62", "36.12", "158.8", "733.00", "31.1", "535.50", "781.8", "934.8", "518.83", "251.2", "79.35", "276.8", "V45.86", "331.4", "491.20", "301.83", "99.10", "268.9", "438.13", "39.72", "410.91", "492.8", "535.30", "569.85", "V15.51", "99.60", "369.66", "198.89", "99.14", "135", "355.8", "428.31", "427.31", "481", "719.7", "784.2", "799.02", "250.61", "89.45", "599.70", "V45.11", "426.3", "17.35", "967.8", "560.1", "255.41", "99.62", "V15.81", "570", "455.8", "87.54", "567.22", "36.15", "00.17", "276.50", "715.16", "154.0", "296.7", "V54.17", "285.21", "99.15", "V10.44", "35.96", "738.4", "V43.65", "428.9", "198.5", "425.7", "443.81", "E878.8", "V10.01", "V10.83", "429.4", "403.11", "995.1", "290.10", "289.50", "444.21", "585.4", "344.04", "038.12", "728.88", "E816.0", "305.1", "997.39", "553.21", "V54.12", "372.30", "441.7", "280.9", "282.40", "567.23", "511.9", "794.5", "516.31", "726.0", "252.08", "348.5", "410.51", "443.9", "E850.1", "507.0", "823.01", "477.8", "478.6", "577.0", "153.2", "200.00", "202.80", "786.2", "36.06", "952.09", "493.22", "307.47", "362.01", "600.00", "458.8", "V42.82", "V45.73", "440.23", "36.13", "112.1", "356.8", "171.5", "593.4", "805.2", "276.52", "540.1", "560.9", "577.2", "813.80", "824.8", "V49.84", "394.2", "782.1", "576.1", "155.0", "456.21", "346.81", "V88.01", "824.2", "815.02", "296.60", "724.02", "V45.01", "433.10", "524.61", "038.9", "724.00", "327.23", "51.98", "V12.55", "250.12", "86.04", "289.59", "536.8", "585.9", "403.00", "296.90", "602.3", "E853.2", "537.89", "365.04", "V58.67", "289.84", "789.1", "V66.7", "112.9", "427.81", "054.8", "83.89", "553.3", "273.8", "V45.72", "340", "784.7", "428.32", "785.9", "185", "441.03", "903.01", "995.27", "574.50", "009.1", "998.81", "693.0", "39.74", "572.0", "428.21", "446.5", "041.3", "338.4", "566", "362.50", "414.00", "86.59", "65.63", "112.84", "V46.2", "432.1", "530.84", "886.0", "287.5", "305.00"]
|