Upload 21 files
Browse files- app.py +159 -0
- lib/.DS_Store +0 -0
- lib/__init__.py +24 -0
- lib/component/__init__.py +16 -0
- lib/component/criterion.py +332 -0
- lib/component/likelihood.py +107 -0
- lib/component/loss.py +248 -0
- lib/component/net.py +624 -0
- lib/component/optimizer.py +34 -0
- lib/dataloader.py +400 -0
- lib/framework.py +373 -0
- lib/logger.py +71 -0
- lib/metrics.py +623 -0
- lib/options.py +655 -0
- parameters.json +36 -0
- requirements.txt +5 -0
- sample/.DS_Store +0 -0
- sample/sample_AP_inverted.png +0 -0
- sample/sample_PA_right.png +0 -0
- sample/sample_lateral_upright.png +0 -0
- weight_epoch-011_best.pt +3 -0
app.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torchvision.transforms as T
|
| 8 |
+
|
| 9 |
+
from lib.framework import create_model
|
| 10 |
+
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
|
| 11 |
+
from lib.dataloader import ImageMixin
|
| 12 |
+
|
| 13 |
+
# ===========================================
# 1) Paths (edit here if the files are moved)
# ===========================================
test_weight = './weight_epoch-011_best.pt'
parameter = './parameters.json'

# ===========================================
# 2) Class-label definitions
#    - view direction (label_APorPA) and rotation (label_round)
#    Index order must match the training-time class encoding.
# ===========================================
LABEL_APorPA = [
    "AP",  # class 0
    "PA"   # class 1
]

LABEL_ROUND = [
    "0° Rotation",    # class 0
    "90° Rotation",   # class 1
    "180° Rotation",  # class 2
    "270° Rotation"   # class 3
]
|
| 34 |
+
|
| 35 |
+
# ===========================================
|
| 36 |
+
# 3) 前処理用の ImageHandlerクラス
|
| 37 |
+
# - 画像が既に256×256前提。Resizeはコメントアウトで残す
|
| 38 |
+
# ===========================================
|
| 39 |
+
class ImageHandler(ImageMixin):
    """
    Minimal preprocessing wrapper used at inference time.

    Input images are assumed to already be 256x256 grayscale, so the
    Resize step is deliberately omitted (kept commented for reference).
    """

    def __init__(self, params):
        self.params = params
        self.transform = T.Compose([
            # T.Resize((256, 256)),  # intentionally skipped: inputs are pre-sized
            T.ToTensor(),  # PIL -> float tensor in [0, 1], shape (C, H, W)
        ])

    def set_image(self, image):
        """Transform one PIL image and wrap it as a single-sample batch dict."""
        tensor = self.transform(image)
        return {'image': tensor.unsqueeze(0)}
|
| 54 |
+
|
| 55 |
+
# ===========================================
|
| 56 |
+
# 4) パラメータのロード
|
| 57 |
+
# ===========================================
|
| 58 |
+
def load_parameter(parameter):
    """
    Load a parameters JSON file and split it into model / dataloader arg sets.

    Args:
        parameter: path to the parameters JSON file

    Returns:
        tuple: (args_model, args_dataloader) as produced by _dispatch_by_group
    """
    merged = ParamSet()
    for key, value in _retrieve_parameter(parameter).items():
        setattr(merged, key, value)

    # Force inference-time settings; training-only options are disabled.
    merged.augmentation = 'no'
    merged.sampler = 'no'
    merged.pretrained = False
    merged.mlp = None
    merged.net = merged.model
    merged.device = torch.device('cpu')

    return (
        _dispatch_by_group(merged, 'model'),
        _dispatch_by_group(merged, 'dataloader'),
    )
|
| 75 |
+
|
| 76 |
+
# Split the saved parameters into model args and dataloader args once at import time.
args_model, args_dataloader = load_parameter(parameter)

# ===========================================
# 5) Build the model and load the trained weights
# ===========================================
model = create_model(args_model)
print(f"Load weight: {test_weight}")
model.load_weight(test_weight)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
|
| 85 |
+
|
| 86 |
+
# ===========================================
|
| 87 |
+
# 6) 推論関数
|
| 88 |
+
# ===========================================
|
| 89 |
+
def classify_APorPA_and_round(image):
    """
    Run the model on one image and decode both classification heads.

    Expected model outputs:
      outputs["label_APorPA"] -> shape [1, 2] (two classes: AP / PA)
      outputs["label_round"]  -> shape [1, 4] (four classes: 0°, 90°, 180°, 270°)

    Returns:
        tuple[str, str]: (predicted AP/PA label, predicted rotation label),
        or error strings when an expected head is missing.
    """
    handler = ImageHandler(args_dataloader)
    batch = handler.set_image(image)

    with torch.no_grad():
        outputs = model(batch)

    # --- AP/PA head ---
    if "label_APorPA" not in outputs:
        print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
        return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"

    ap_pa_idx = int(torch.argmax(outputs["label_APorPA"], dim=1))
    predicted_APorPA = LABEL_APorPA[ap_pa_idx]

    # --- rotation head ---
    if "label_round" not in outputs:
        print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
        return predicted_APorPA, "ERROR: Missing 'label_round'"

    round_idx = int(torch.argmax(outputs["label_round"], dim=1))
    predicted_round = LABEL_ROUND[round_idx]

    return predicted_APorPA, predicted_round
|
| 120 |
+
|
| 121 |
+
# ===========================================
|
| 122 |
+
# 7) Gradio UI
|
| 123 |
+
# ===========================================
|
| 124 |
+
# Static HTML description shown at the top of the UI (runtime string, verbatim).
html_content = """
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
<p>胸部レントゲン画像に対して、撮像方向(AP or PA)と回転方向(0°, 90°, 180°, 270°)を同時に推定します。</p>
</div>
"""

# ===========================================
# 7) Gradio UI: one image input, two label outputs, sample gallery.
# ===========================================
with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
    gr.HTML("<div style='text-align:center'><h2>Chest X-ray AP/PA & Rotation Classification</h2></div>")
    gr.HTML(html_content)

    with gr.Row():
        # image_mode="L": force single-channel grayscale input
        input_image = gr.Image(type="pil", image_mode="L")
        output_APorPA = gr.Label(label="Predicted AP or PA")
        output_round = gr.Label(label="Predicted Rotation")

    send_btn = gr.Button("Inference")
    send_btn.click(
        fn=classify_APorPA_and_round,
        inputs=input_image,
        outputs=[output_APorPA, output_round]
    )

    with gr.Row():
        # Replace the sample files with the actual paths if they differ.
        gr.Examples(
            examples=[
                './sample/sample_AP_inverted.png',
                './sample/sample_PA_right.png',
                './sample/sample_lateral_upright.png'
            ],
            inputs=input_image
        )

demo.launch(debug=True)
|
lib/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
lib/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from .options import (
|
| 5 |
+
ParamSet,
|
| 6 |
+
set_options,
|
| 7 |
+
save_parameter,
|
| 8 |
+
print_parameter
|
| 9 |
+
)
|
| 10 |
+
from .dataloader import create_dataloader
|
| 11 |
+
from .framework import create_model
|
| 12 |
+
from .metrics import set_eval
|
| 13 |
+
from .logger import BaseLogger
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'ParamSet',
|
| 17 |
+
'set_options',
|
| 18 |
+
'print_parameter',
|
| 19 |
+
'save_parameter',
|
| 20 |
+
'create_dataloader',
|
| 21 |
+
'create_model',
|
| 22 |
+
'set_eval',
|
| 23 |
+
'BaseLogger'
|
| 24 |
+
]
|
lib/component/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from .net import create_net
|
| 5 |
+
from .criterion import set_criterion
|
| 6 |
+
from .optimizer import set_optimizer
|
| 7 |
+
from .loss import set_loss_store
|
| 8 |
+
from .likelihood import set_likelihood
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
'create_net',
|
| 12 |
+
'set_criterion',
|
| 13 |
+
'set_optimizer',
|
| 14 |
+
'set_loss_store',
|
| 15 |
+
'set_likelihood'
|
| 16 |
+
]
|
lib/component/criterion.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Dict, Union
|
| 7 |
+
|
| 8 |
+
# Alias of typing
|
| 9 |
+
# eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}}
|
| 10 |
+
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RMSELoss(nn.Module):
    """
    Root-mean-square error loss.

    A small epsilon is added under the square root so the gradient is
    defined even when the MSE is exactly zero.
    """
    def __init__(self, eps: float = 1e-7) -> None:
        """
        Args:
            eps (float, optional): value added to MSE to avoid sqrt(0). Defaults to 1e-7.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate RMSE.

        Args:
            yhat (torch.Tensor): predicted values
            y (torch.Tensor): ground-truth values

        Returns:
            torch.Tensor: scalar RMSE tensor (sqrt(MSE + eps))
        """
        _loss = self.mse(yhat, y) + self.eps
        return torch.sqrt(_loss)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Regularization:
    """
    Callable computing a weighted p-norm penalty over a network's weights.

    Parameters whose name contains 'weight' contribute their p-norm;
    biases (and anything else) are skipped.
    """

    def __init__(self, order: int, weight_decay: float) -> None:
        """
        Args:
            order (int): norm order (e.g. 2 for an L2 penalty)
            weight_decay (float): multiplier applied to the accumulated norm
        """
        super().__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, network: nn.Module) -> torch.FloatTensor:
        """
        Calculate the regularization loss for a network.

        Args:
            network (nn.Module): network whose weights are penalized

        Returns:
            torch.FloatTensor: weight_decay * sum of p-norms of weight tensors
        """
        penalty = sum(
            torch.norm(param, p=self.order)
            for name, param in network.named_parameters()
            if 'weight' in name
        )
        return self.weight_decay * penalty
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class NegativeLogLikelihood(nn.Module):
    """
    Negative log partial likelihood for Cox proportional-hazards style
    survival models (DeepSurv), with an L2 penalty on network weights.
    """
    def __init__(self, device: torch.device) -> None:
        """
        Args:
            device (torch.device): device on which masks/losses are created
        """
        super().__init__()
        # Fixed L2 regularization strength applied to the network's weights.
        self.L2_reg = 0.05
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)
        self.device = device

    def forward(
        self,
        output: torch.FloatTensor,
        label: torch.IntTensor,
        periods: torch.FloatTensor,
        network: nn.Module
    ) -> torch.FloatTensor:
        """
        Calculate the negative log partial likelihood.

        Args:
            output (torch.FloatTensor): predicted risk scores, shape [N, 1]
            label (torch.IntTensor): event-occurrence indicator per sample
            periods (torch.FloatTensor): observation period per sample, shape [N, 1]
            network (nn.Module): network being trained (for the L2 penalty)

        Returns:
            torch.FloatTensor: scalar loss
        """
        # Risk-set mask: mask[i, j] = 1 when sample j is still at risk at
        # sample i's event time (period_j >= period_i), else 0.
        mask = torch.ones(periods.shape[0], periods.shape[0]).to(self.device)  # output and mask should be on the same device.
        mask[(periods.T - periods) > 0] = 0

        # Log of the mean exponentiated risk over each risk set.
        _loss = torch.exp(output) * mask
        # Note: torch.sum(_loss, dim=0) possibly returns nan, in particular MLP.
        _loss = torch.sum(_loss, dim=0) / torch.sum(mask, dim=0)
        _loss = torch.log(_loss).reshape(-1, 1)
        num_occurs = torch.sum(label)

        if num_occurs.item() == 0.0:
            # No events in this batch: the partial likelihood is undefined,
            # so return a tiny constant loss to avoid zero division.
            loss = torch.tensor([1e-7], requires_grad=True).to(self.device)  # To avoid zero division, set small value as loss
            return loss
        else:
            # Average negative log likelihood over event samples + L2 penalty.
            neg_log_loss = -torch.sum((output - _loss) * label) / num_occurs
            l2_loss = self.reg(network)
            loss = neg_log_loss + l2_loss
            return loss
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ClsCriterion:
    """
    Criterion wrapper for classification: cross-entropy per label.

    For each label name a CrossEntropyLoss is computed between the
    corresponding output logits and integer targets; 'total' holds the sum.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Args:
            device (torch.device): device on which the running total lives
        """
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Compute per-label cross-entropy losses and their sum.

        No reshape and no cast is done here:
          output: [B, C], torch.float32; label: [B], torch.int64
          (label.dtype must be int64 or CrossEntropyLoss raises).

        Args:
            outputs (Dict[str, torch.FloatTensor]): logits keyed by label name
            labels (Dict[str, LabelDict]): {'labels': {label_name: targets, ...}}

        Returns:
            Dict[str, torch.FloatTensor]: {'total': sum, label_name: loss, ...}
        """
        target_dict = labels['labels']

        losses = {'total': torch.tensor([0.0], requires_grad=True).to(self.device)}
        for name, target in target_dict.items():
            per_label = self.criterion(outputs[name], target)
            losses[name] = per_label
            losses['total'] = torch.add(losses['total'], per_label)
        return losses
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class RegCriterion:
    """
    Criterion wrapper for regression: MSE, RMSE, or MAE per label.
    """
    def __init__(self, criterion_name: str = None, device: torch.device = None) -> None:
        """
        Args:
            criterion_name (str): 'MSE', 'RMSE', or 'MAE'
            device (torch.device): device on which the running total lives

        Raises:
            ValueError: if criterion_name is not one of the supported names
        """
        self.device = device

        if criterion_name == 'MSE':
            self.criterion = nn.MSELoss()
        elif criterion_name == 'RMSE':
            self.criterion = RMSELoss()
        elif criterion_name == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Invalid criterion for regression: {criterion_name}.")

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Compute per-label regression losses and their sum.

        Reshape and cast:
          output: [B, 1] -> [B], torch.float32
          label:  [B], cast to torch.float32 (required for backward)

        Args:
            outputs (Dict[str, torch.FloatTensor]): predictions keyed by label name
            labels (Dict[str, LabelDict]): {'labels': {label_name: targets, ...}}

        Returns:
            Dict[str, torch.FloatTensor]: {'total': sum, label_name: loss, ...}
        """
        # squeeze(-1) drops only the trailing channel dim; a bare squeeze()
        # would also collapse the batch dim when batch size is 1.
        _outputs = {label_name: _output.squeeze(-1) for label_name, _output in outputs.items()}
        _labels = {label_name: _label.to(torch.float32) for label_name, _label in labels['labels'].items()}

        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name in labels['labels'].keys():
            _label_loss = self.criterion(_outputs[label_name], _labels[label_name])
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class DeepSurvCriterion:
    """
    Criterion wrapper for the DeepSurv task: negative log likelihood per label.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Args:
            device (torch.device, optional): device for loss computation
        """
        self.device = device
        self.criterion = NegativeLogLikelihood(self.device).to(self.device)

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Compute the negative log likelihood per label plus their sum.

        Reshape (no cast):
          output: [B, 1], torch.float32
          label:  [B] -> [B, 1], torch.int64
          period: [B] -> [B, 1], torch.float32

        Args:
            outputs (Dict[str, torch.FloatTensor]): risk predictions keyed by label name
            labels: dict with keys 'labels' (targets), 'periods', and 'network'

        Returns:
            Dict[str, torch.FloatTensor]: {'total': sum, label_name: loss, ...}
        """
        reshaped = {name: t.reshape(-1, 1) for name, t in labels['labels'].items()}
        periods = labels['periods'].reshape(-1, 1)
        network = labels['network']

        losses = {'total': torch.tensor([0.0], requires_grad=True).to(self.device)}
        for name in labels['labels']:
            per_label = self.criterion(outputs[name], reshaped[name], periods, network)
            losses[name] = per_label
            losses['total'] = torch.add(losses['total'], per_label)
        return losses
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def set_criterion(
    criterion_name: str,
    device: torch.device
) -> Union[ClsCriterion, RegCriterion, DeepSurvCriterion]:
    """
    Build the criterion matching the given name.

    Args:
        criterion_name (str): 'CEL', 'MSE', 'RMSE', 'MAE', or 'NLL'
        device (torch.device): device passed through to the criterion

    Returns:
        Union[ClsCriterion, RegCriterion, DeepSurvCriterion]: criterion instance

    Raises:
        ValueError: if criterion_name is not recognized
    """
    if criterion_name == 'CEL':
        return ClsCriterion(device=device)
    if criterion_name in ('MSE', 'RMSE', 'MAE'):
        return RegCriterion(criterion_name=criterion_name, device=device)
    if criterion_name == 'NLL':
        return DeepSurvCriterion(device=device)
    raise ValueError(f"Invalid criterion: {criterion_name}.")
|
lib/component/likelihood.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
from typing import List, Dict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Likelihood:
    """
    Builds per-batch likelihood DataFrames (metadata, labels, predictions).
    """
    def __init__(self, task: str, num_outputs_for_label: Dict[str, int]) -> None:
        """
        Args:
            task (str): task name ('classification', 'regression', 'deepsurv')
            num_outputs_for_label (Dict[str, int]): number of classes per label
        """
        self.task = task
        self.num_outputs_for_label = num_outputs_for_label
        self.base_column_list = self._set_base_columns(self.task)
        self.pred_column_list = self._make_pred_columns(self.task, self.num_outputs_for_label)

    def _set_base_columns(self, task: str) -> List[str]:
        """
        Return base (non-label, non-prediction) column names for the task.

        Args:
            task (str): task name

        Returns:
            List[str]: base columns ('deepsurv' additionally carries 'periods')

        Raises:
            ValueError: if task is not recognized
        """
        if task in ('classification', 'regression'):
            return ['uniqID', 'group', 'imgpath', 'split']
        if task == 'deepsurv':
            return ['uniqID', 'group', 'imgpath', 'split', 'periods']
        raise ValueError(f"Invalid task: {task}.")

    def _make_pred_columns(self, task: str, num_outputs_for_label: Dict[str, int]) -> Dict[str, List[str]]:
        """
        Build prediction column names per label.

        Classification gets one column per class
        ('pred_<label>_<i>'); regression/deepsurv get a single
        'pred_<label>' column.

        Args:
            task (str): task name
            num_outputs_for_label (Dict[str, int]): number of classes per label

        Returns:
            Dict[str, List[str]]: label name -> list of prediction column names

        Raises:
            ValueError: if task is not recognized
        """
        if task == 'classification':
            return {
                name: [f"pred_{name}_{i}" for i in range(n_classes)]
                for name, n_classes in num_outputs_for_label.items()
            }
        if task in ('regression', 'deepsurv'):
            return {name: [f"pred_{name}"] for name in num_outputs_for_label}
        raise ValueError(f"Invalid task: {task}.")

    def make_format(self, data: Dict, output: Dict[str, torch.Tensor]) -> pd.DataFrame:
        """
        Build a likelihood DataFrame for one batch.

        Columns: base metadata, then per label (when labels exist in the
        batch) the ground-truth column followed by its prediction columns.

        Args:
            data (Dict): batch dict from the dataloader
            output (Dict[str, torch.Tensor]): model output keyed by label name

        Returns:
            pd.DataFrame: one row per sample in the batch
        """
        frames = [pd.DataFrame({col: data[col] for col in self.base_column_list})]
        has_labels = any(data['labels'])

        for label_name, pred in output.items():
            if has_labels:
                frames.append(pd.DataFrame({label_name: data['labels'][label_name].tolist()}))
            pred_array = pred.to('cpu').detach().numpy().copy()
            frames.append(pd.DataFrame(pred_array, columns=self.pred_column_list[label_name]))

        return pd.concat(frames, axis=1)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def set_likelihood(task: str, num_outputs_for_label: Dict[str, int]) -> Likelihood:
    """
    Factory for a Likelihood instance.

    Args:
        task (str): task name
        num_outputs_for_label (Dict[str, int]): number of classes per label

    Returns:
        Likelihood: configured likelihood builder
    """
    return Likelihood(task, num_outputs_for_label)
|
lib/component/loss.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import torch
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from ..logger import BaseLogger
|
| 8 |
+
from typing import List, Dict, Union
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = BaseLogger.get_logger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LabelLoss:
    """
    Class to store the accumulated batch loss and the epoch losses of one label.
    """
    def __init__(self) -> None:
        # Accumulated batch_loss (= loss * batch_size) within the current epoch.
        self.train_batch_loss = 0.0
        self.val_batch_loss = 0.0

        # epoch_loss = batch_loss / dataset_size; one entry appended per epoch.
        self.train_epoch_loss = []  # List[float]
        self.val_epoch_loss = []    # List[float]

        self.best_val_loss = None        # float: best val epoch loss seen so far
        self.best_epoch = None           # int: epoch at which best_val_loss occurred
        self.is_val_loss_updated = None  # bool: whether the last check improved best_val_loss

    def get_loss(self, phase: str, target: str) -> Union[float, List[float]]:
        """
        Return loss depending on phase and target.

        Args:
            phase (str): 'train' or 'val'
            target (str): 'batch' or 'epoch'

        Returns:
            Union[float, List[float]]: batch_loss (float) or epoch_loss (List[float])
        """
        _target = phase + '_' + target + '_loss'
        return getattr(self, _target)

    def store_batch_loss(self, phase: str, new_batch_loss: torch.FloatTensor, batch_size: int) -> None:
        """
        Add new batch loss to the previous one for phase, multiplied by batch_size.

        Args:
            phase (str): 'train' or 'val'
            new_batch_loss (torch.FloatTensor): batch loss calculated by criterion
            batch_size (int): batch size
        """
        _new = new_batch_loss.item() * batch_size  # torch.FloatTensor -> float
        _prev = self.get_loss(phase, 'batch')
        _added = _prev + _new
        _target = phase + '_' + 'batch_loss'
        setattr(self, _target, _added)

    def append_epoch_loss(self, phase: str, new_epoch_loss: float) -> None:
        """
        Append an epoch loss to the loss history of phase.

        Args:
            phase (str): 'train' or 'val'
            new_epoch_loss (float): epoch loss
        """
        _target = phase + '_' + 'epoch_loss'
        getattr(self, _target).append(new_epoch_loss)

    def get_latest_epoch_loss(self, phase: str) -> float:
        """
        Return the latest epoch loss of phase.

        Args:
            phase (str): 'train' or 'val'

        Returns:
            float: the latest epoch loss
        """
        return self.get_loss(phase, 'epoch')[-1]

    def update_best_val_loss(self, at_epoch: int = None) -> None:
        """
        Update best_val_loss if the latest val epoch loss is the best so far.

        Args:
            at_epoch (int): epoch when checked
        """
        _latest_val_loss = self.get_latest_epoch_loss('val')

        # Fix: also treat a missing best_val_loss as the first check. Previously,
        # calling this for the first time with at_epoch != 1 (e.g. at_epoch=None or
        # when resuming) compared a float against None and raised TypeError.
        if (at_epoch == 1) or (self.best_val_loss is None):
            self.best_val_loss = _latest_val_loss
            self.best_epoch = at_epoch
            self.is_val_loss_updated = True
        elif _latest_val_loss < self.best_val_loss:
            self.best_val_loss = _latest_val_loss
            self.best_epoch = at_epoch
            self.is_val_loss_updated = True
        else:
            self.is_val_loss_updated = False
|
| 106 |
+
class LossStore:
    """
    Calculates epoch losses from accumulated batch losses and stores them.
    """
    def __init__(self, label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> None:
        """
        Args:
            label_list (List[str]): list of internal labels
            num_epochs (int): number of epochs
            dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'
        """
        self.label_list = label_list
        self.num_epochs = num_epochs
        self.dataset_info = dataset_info

        # The extra key 'total' keeps the aggregate of all labels' losses.
        self.label_losses = {name: LabelLoss() for name in self.label_list + ['total']}

    def store(self, phase: str, losses: Dict[str, torch.FloatTensor], batch_size: int = None) -> None:
        """
        Accumulate label-wise batch losses of phase.

        Args:
            phase (str): 'train' or 'val'
            losses (Dict[str, torch.FloatTensor]): loss for each label calculated by criterion
            batch_size (int): batch size

        Note:
            losses['total'] is already the sum over all labels (computed in
            criterion.py), so it only needs the same batch_size scaling.
        """
        for name in self.label_list + ['total']:
            self.label_losses[name].store_batch_loss(phase, losses[name], batch_size)

    def cal_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Turn the accumulated batch losses into epoch losses for both phases,
        refresh the best val loss, then reset the batch accumulators.

        Args:
            at_epoch (int): epoch number
        """
        for name in self.label_list + ['total']:
            label_loss = self.label_losses[name]
            # 'total' is a sum over all labels, so it is additionally averaged
            # by the number of labels.
            scale = len(self.label_list) if name == 'total' else 1
            for phase in ['train', 'val']:
                denominator = self.dataset_info[phase] * scale
                label_loss.append_epoch_loss(phase, label_loss.get_loss(phase, 'batch') / denominator)

            # Track the best val loss so far.
            label_loss.update_best_val_loss(at_epoch=at_epoch)

            # Reset accumulators for the next epoch.
            label_loss.train_batch_loss = 0.0
            label_loss.val_batch_loss = 0.0

    def is_val_loss_updated(self) -> bool:
        """
        Check whether the 'total' val loss improved at the last epoch.

        Returns:
            bool: updated or not
        """
        return self.label_losses['total'].is_val_loss_updated

    def get_best_epoch(self) -> int:
        """
        Return the epoch at which the 'total' val loss was best.

        Returns:
            int: best epoch
        """
        return self.label_losses['total'].best_epoch

    def print_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Log train_loss and val_loss of 'total' for the given epoch.

        Args:
            at_epoch (int): epoch number
        """
        total = self.label_losses['total']
        train_epoch_loss = total.get_latest_epoch_loss('train')
        val_epoch_loss = total.get_latest_epoch_loss('val')

        parts = [
            f"epoch [{at_epoch:>3}/{self.num_epochs:<3}]",
            f"train_loss: {train_epoch_loss:>8.4f}",
            f"val_loss: {val_epoch_loss:>8.4f}"
        ]
        message = ', '.join(parts)
        if (at_epoch > 1) and self.is_val_loss_updated():
            message = message + ' Updated best val_loss!'
        logger.info(message)

    def save_learning_curve(self, save_datetime_dir: str) -> None:
        """
        Write one learning-curve CSV per label (and 'total').

        Args:
            save_datetime_dir (str): directory under which 'learning_curve' is created
        """
        save_dir = Path(save_datetime_dir, 'learning_curve')
        save_dir.mkdir(parents=True, exist_ok=True)

        for name in self.label_list + ['total']:
            label_loss = self.label_losses[name]
            df_loss = pd.DataFrame({
                'train_loss': label_loss.get_loss('train', 'epoch'),
                'val_loss': label_loss.get_loss('val', 'epoch')
            })

            best_epoch = str(label_loss.best_epoch).zfill(3)
            best_val_loss = f"{label_loss.best_val_loss:.4f}"
            save_name = 'learning_curve_' + name + '_val-best-epoch-' + best_epoch + '_val-best-loss-' + best_val_loss + '.csv'
            df_loss.to_csv(Path(save_dir, save_name), index=False)
| 235 |
+
|
| 236 |
+
def set_loss_store(label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> LossStore:
    """
    Factory returning a configured LossStore.

    Args:
        label_list (List[str]): label list
        num_epochs (int): number of epochs
        dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'

    Returns:
        LossStore: loss store for the given labels
    """
    return LossStore(label_list, num_epochs, dataset_info)
|
lib/component/net.py
ADDED
|
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torchvision.ops import MLP
|
| 8 |
+
import torchvision.models as models
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseNet:
    """
    Class to construct network.

    Holds class-level lookup tables from project net names to torchvision
    constructors, and classmethods that build feature extractors and
    per-label classifiers from them.
    """
    # CNN constructors keyed by net name.
    # NOTE(review): 'ResNet' maps to resnet50 while 'ResNet18' maps to resnet18.
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large
    }

    # ViT constructors keyed by net name.
    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14
    }

    # All supported networks (CNN + ViT).
    net = {**cnn, **vit}

    # Attribute name of the final classifier for each architecture family.
    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads'
    }

    # Classifier attribute name per net name (expanded from _classifier).
    classifier = {
        'ResNet18': _classifier['ResNet'],
        'ResNet': _classifier['ResNet'],
        'DenseNet': _classifier['DenseNet'],
        'EfficientNetB0': _classifier['EfficientNet'],
        'EfficientNetB2': _classifier['EfficientNet'],
        'EfficientNetB4': _classifier['EfficientNet'],
        'EfficientNetB6': _classifier['EfficientNet'],
        'EfficientNetV2s': _classifier['EfficientNet'],
        'EfficientNetV2m': _classifier['EfficientNet'],
        'EfficientNetV2l': _classifier['EfficientNet'],
        'ConvNeXtTiny': _classifier['ConvNext'],
        'ConvNeXtSmall': _classifier['ConvNext'],
        'ConvNeXtBase': _classifier['ConvNext'],
        'ConvNeXtLarge': _classifier['ConvNext'],
        'ViTb16': _classifier['ViT'],
        'ViTb32': _classifier['ViT'],
        'ViTl16': _classifier['ViT'],
        'ViTl32': _classifier['ViT'],
        'ViTH14': _classifier['ViT']
    }

    # Hyperparameters of the MLP extractor.
    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2
    }

    # Pass-through module used to replace a network's classifier,
    # turning the network into a pure feature extractor.
    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): the number of inputs of MLP
            inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP
        """
        # NOTE(review): assert is stripped under `python -O`; a ValueError would be sturdier.
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
        return mlp

    @classmethod
    def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
        """
        Modify network to handle gray scale (1-channel) images.

        The 3-channel weights of the first conv are collapsed to one channel
        by summing over the input-channel dimension.

        Args:
            net_name (str): network name
            net (nn.Module): network itself

        Returns:
            nn.Module: network available for gray scale
        """
        if net_name.startswith('ResNet'):
            net.conv1.in_channels = 1
            net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('DenseNet'):
            net.features.conv0.in_channels = 1
            net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('Efficient'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ConvNeXt'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ViT'):
            net.conv_proj.in_channels = 1
            net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return net

    @classmethod
    def set_net(
        cls,
        net_name: str = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Construct and modify a network depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel (any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
                                            vit_image_size should be power of patch size.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # When ViT: always use pretrained weights (see set_vit).
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)

        # Collapse the first conv to one input channel for gray-scale images.
        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
        """
        Construct a pretrained ViT and adapt its weights to vit_image_size.

        Args:
            net_name (str): ViT name
            vit_image_size (int): image size which ViT handles if ViT is used.

        Returns:
            nn.Module: modified ViT
        """
        base_vit = cls.vit[net_name]
        # pretrained_vit = base_vit(weights=cls.vit_weight[net_name])
        pretrained_vit = base_vit(weights='DEFAULT')

        # Interpolate the positional embeddings to match the new image size.
        weight = pretrained_vit.state_dict()
        patch_size = int(net_name[-2:])  # 'ViTb16' -> 16
        aligned_weight = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=patch_size,
            model_state=weight
        )
        aligned_vit = base_vit(image_size=vit_image_size)  # Specify new image size.
        aligned_vit.load_state_dict(aligned_weight)        # Load weight which can handle the new image size.
        return aligned_vit

    @classmethod
    def construct_extractor(
        cls,
        net_name: str = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Construct the feature extractor of a network depending on net_name.

        Args:
            net_name (str): network name.
            mlp_num_inputs (int, optional): number of inputs of MLP. Defaults to None.
            in_channel (int, optional): image channel (any of 1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
            pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: extractor of network
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
            setattr(extractor, cls.classifier[net_name], cls.DUMMY)  # Replace classifier with DUMMY(=nn.Identity()).
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get the classifier module of a freshly constructed network.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of network
        """
        # NOTE(review): instantiates a full (untrained) network just to read
        # its classifier structure.
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
        """
        Construct one classifier head per label, matching the architecture's
        original classifier structure.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label
        """
        classifiers = dict()
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('EfficientNet'):
            # EfficientNet classifier = Sequential(Dropout, Linear).
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ConvNeXt'):
            # ConvNeXt classifier = Sequential(LayerNorm2d, Flatten, Linear).
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                # Shape is changed before nn.Linear.
                # NOTE(review): the same layer_norm/flatten module instances are
                # shared across all label heads (their parameters are shared).
                classifiers[label_name] = nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ViT'):
            # ViT classifier = heads.head (Linear).
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                    ])
                )

        else:
            raise ValueError(f"No specified net: {net_name}.")

        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return the in_features of the classifier of the network indicated by net_name.
        This classmethod is used in class MultiNetFusion() only.

        Args:
            net_name (str): net_name

        Returns:
            int: in_features

        Required:
            classifier.in_features
            classifier[1].in_features
            classifier[2].in_features
            classifier.head.in_features
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Construct module to align the shape of the feature from the extractor.
        Effectively only needed when net_name == 'ConvNeXt',
        because ConvNeXt aligns dimensions inside its classifier:

        (classifier):
        Sequential(
            (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
            (1): Flatten(start_dim=1, end_dim=-1)
            (2): Linear(in_features=768, out_features=1000, bias=True)
        )

        Args:
            net_name (str): net name

        Returns:
            nn.Module: layers that align the dimension of the extractor output like
                       the original ConvNeXt; nn.Identity() for all other networks.
        """
        aux_module = cls.DUMMY
        if net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            aux_module = nn.Sequential(
                layer_norm,
                flatten
            )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
        """
        Return the last extractor layer of the network, for Grad-CAM.
        net should be a model with loaded weights.

        Args:
            net (nn.Module): network itself
            mlp (str): 'MLP', otherwise None
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."

        _extractor = net.extractor_net

        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            # NOTE(review): hardcoded for densenet161 (denseblock4 has 24 layers);
            # other DenseNet variants would need a different layer index.
            last_extractor = _extractor.features.denseblock4.denselayer24
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor
| 409 |
+
|
| 410 |
+
class MultiMixin:
    """
    Mixin providing helpers for multi-label classifier heads.
    """
    def multi_forward(self, out_features: int) -> Dict[str, float]:
        """
        Apply each label's classifier head to the shared extracted features.

        Args:
            out_features (int): output from extractor

        Returns:
            Dict[str, float]: output of the classifier of each label
        """
        return {
            label_name: head(out_features)
            for label_name, head in self.multi_classifier.items()
        }
|
| 429 |
+
|
| 430 |
+
class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Convenience base class combining nn.Module, BaseNet and MultiMixin,
    so that concrete models inherit all three simultaneously.
    """
    pass
|
| 436 |
+
|
| 437 |
+
class MultiNet(MultiWidget):
    """
    Single-modality model: an MLP, CNN or ViT extractor followed by one
    classifier head per label.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): MLP, CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of MLP
            in_channel (int): number of image channels, ie. gray scale(=1) or color image(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True when using pretrained CNN or ViT, otherwise False
        """
        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Feature extractor: an MLP or a torchvision model whose classifier
        # has been replaced with nn.Identity().
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            mlp_num_inputs=self.mlp_num_inputs,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        # One classification head per label.
        self.multi_classifier = self.construct_multi_classifier(
            net_name=self.net_name,
            num_outputs_for_label=self.num_outputs_for_label
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the extractor, then every label head on the shared features.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: per-label outputs
        """
        features = self.extractor_net(x)
        return self.multi_forward(features)
|
| 492 |
+
|
| 493 |
+
class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.

    Tabular features (from an MLP extractor) and image features (from a
    CNN/ViT extractor) are concatenated, passed through an intermediate MLP,
    and then fanned out to one classifier per label.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of input of MLP. Defaults to None.
            in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
            vit_image_size (int): image size to be input to ViT.
            pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
        """
        # The fusion model always uses an MLP internally; net_name names the image branch.
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Extractor of MLP and Net
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        # Auxiliary module applied to the raw CNN/ViT extractor output before fusion.
        self.aux_module = self.construct_aux_module(self.net_name)

        # Intermediate MLP: input size is the concatenation of both feature vectors.
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(self.net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Multi classifier
        self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output
        """
        out_mlp = self.extractor_mlp(x_mlp)
        out_net = self.extractor_net(x_net)
        out_net = self.aux_module(out_net)

        # Fuse both feature vectors along the feature dimension (dim=1).
        out_features = torch.cat([out_mlp, out_net], dim=1)
        out_features = self.inter_mlp(out_features)
        output = self.multi_forward(out_features)
        return output
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def create_net(
    mlp: Optional[str] = None,
    net: Optional[str] = None,
    num_outputs_for_label: Dict[str, int] = None,
    mlp_num_inputs: int = None,
    in_channel: int = None,
    vit_image_size: int = None,
    pretrained: bool = None
) -> nn.Module:
    """
    Create network.

    Dispatches on which of mlp/net is given:
    * mlp only -> MultiNet with an MLP extractor (pretraining not applicable)
    * net only -> MultiNet with a CNN/ViT extractor
    * both     -> MultiNetFusion combining MLP and CNN/ViT

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of input of MLP.
        in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
        vit_image_size (int): image size to be input to ViT.
        pretrained (bool): True when use pretrained CNN or ViT, otherwise False.

    Returns:
        nn.Module: network

    Raises:
        ValueError: when both mlp and net are None.
    """
    # Keyword arguments identical across all three branches; previously this
    # call site was triplicated verbatim.
    common_kwargs = dict(
        num_outputs_for_label=num_outputs_for_label,
        mlp_num_inputs=mlp_num_inputs,
        in_channel=in_channel,
        vit_image_size=vit_image_size,
    )

    if (mlp is not None) and (net is None):
        # No need of pretrained for MLP
        return MultiNet(net_name='MLP', pretrained=False, **common_kwargs)

    if (mlp is None) and (net is not None):
        return MultiNet(net_name=net, pretrained=pretrained, **common_kwargs)

    if (mlp is not None) and (net is not None):
        return MultiNetFusion(net_name=net, pretrained=pretrained, **common_kwargs)

    raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")
|
lib/component/optimizer.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def set_optimizer(optimizer_name: str, network: nn.Module, lr: float) -> optim.Optimizer:
    """
    Set optimizer.

    Args:
        optimizer_name (str): optimizer name; one of 'SGD', 'Adadelta', 'Adam', 'RMSprop', or 'RAdam'
        network (torch.nn.Module): network whose parameters are optimized
        lr (float): learning rate. When None, the optimizer's built-in default is used.

    Returns:
        torch.optim.Optimizer: optimizer

    Note:
        The docstring previously described optimizer_name as "criterion name"
        (copy-paste error) and the return annotation was the module ``optim``
        rather than the ``Optimizer`` type.
    """
    optimizers = {
        'SGD': optim.SGD,
        'Adadelta': optim.Adadelta,
        'Adam': optim.Adam,
        'RMSprop': optim.RMSprop,
        'RAdam': optim.RAdam
    }

    assert (optimizer_name in optimizers), f"No specified optimizer: {optimizer_name}."

    _optim = optimizers[optimizer_name]

    if lr is None:
        # Fall back to the optimizer's default learning rate.
        optimizer = _optim(network.parameters())
    else:
        optimizer = _optim(network.parameters(), lr=lr)
    return optimizer
|
lib/dataloader.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
from torch.utils.data.dataset import Dataset
|
| 8 |
+
from torch.utils.data.dataloader import DataLoader
|
| 9 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 12 |
+
import pickle
|
| 13 |
+
from .logger import BaseLogger
|
| 14 |
+
from typing import List, Dict, Union
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = BaseLogger.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PrivateAugment(torch.nn.Module):
    """
    Augmentation defined privately.
    Variety of augmentation can be written in this class if necessary.
    """
    # For X-ray photo.
    # Mild geometric jitter plus sharpness/contrast perturbation; selected by
    # ImageMixin._make_augmentations when params.augmentation == 'xrayaug'.
    xray_augs_list = [
        transforms.RandomAffine(degrees=(-3, 3), translate=(0.02, 0.02)),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.RandomAutocontrast()
    ]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class InputDataMixin:
    """
    Class to normalize tabular input data.
    """
    def _make_scaler(self) -> MinMaxScaler:
        """
        Make scaler to normalize input data by min-max normalization with train data.

        Returns:
            MinMaxScaler: scaler fitted on the training split only
        """
        scaler = MinMaxScaler()
        # Normalize with min and max of the training data only, to avoid
        # leaking statistics from the validation/test splits.
        _df_train = self.df_source[self.df_source['split'] == 'train']
        _ = scaler.fit(_df_train[self.input_list])  # fit only
        return scaler

    def save_scaler(self, save_path: str) -> None:
        """
        Save scaler.

        Args:
            save_path (str): path for saving scaler.
        """
        with open(save_path, 'wb') as f:
            pickle.dump(self.scaler, f)

    def load_scaler(self, scaler_path: str) -> MinMaxScaler:
        """
        Load scaler.

        Args:
            scaler_path (str): path to scaler

        Returns:
            MinMaxScaler: the scaler fitted at training time

        Note:
            This method was annotated ``-> None`` although it returns the scaler.
            pickle.load executes code from the file; only load scaler files
            produced by this project.
        """
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
        return scaler

    def _normalize_inputs(self, df_inputs: pd.DataFrame) -> torch.FloatTensor:
        """
        Normalize inputs.

        Args:
            df_inputs (pd.DataFrame): DataFrame of inputs

        Returns:
            torch.FloatTensor: normalized inputs

        Note:
            After iloc[[idx], index_input_list], pd.DataFrame is obtained.
            DataFrame fits the input type of self.scaler.transform.
            However, after normalizing, the shape of inputs_value is (1, N), where N is the number of input values.
            Since the shape (1, N) is not acceptable when forwarding, convert (1, N) -> (N,) is needed.
        """
        inputs_value = self.scaler.transform(df_inputs).reshape(-1)  # np.float64
        inputs_value = np.array(inputs_value, dtype=np.float32)      # -> np.float32
        inputs_value = torch.from_numpy(inputs_value).clone()        # -> torch.float32
        return inputs_value

    def _load_input_value_if_mlp(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Load input values after converting them into tensor if MLP is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.FloatTensor, str]: tensor of input values, or empty
            string when no MLP branch is configured
        """
        inputs_value = ''

        if self.params.mlp is None:
            return inputs_value

        index_input_list = [self.col_index_dict[input] for input in self.input_list]
        _df_inputs = self.df_split.iloc[[idx], index_input_list]
        inputs_value = self._normalize_inputs(_df_inputs)
        return inputs_value
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class ImageMixin:
    """
    Class to normalize and transform image.
    """
    def _make_augmentations(self) -> transforms.Compose:
        """
        Define which augmentation is applied.

        When training, augmentation is needed for train data only.
        When test, no need of augmentation.

        Returns:
            transforms.Compose: composed augmentations (empty when not training)
        """
        _augmentation = []
        if (self.params.isTrain) and (self.split == 'train'):
            if self.params.augmentation == 'xrayaug':
                _augmentation = PrivateAugment.xray_augs_list
            elif self.params.augmentation == 'trivialaugwide':
                _augmentation.append(transforms.TrivialAugmentWide())
            elif self.params.augmentation == 'randaug':
                _augmentation.append(transforms.RandAugment())
            else:
                # ie. self.params.augmentation == 'no':
                pass

        _augmentation = transforms.Compose(_augmentation)
        return _augmentation

    def _make_transforms(self) -> transforms.Compose:
        """
        Make list of transforms.

        Returns:
            transforms.Compose: image normalization
        """
        _transforms = []
        _transforms.append(transforms.ToTensor())

        if self.params.normalize_image == 'yes':
            # transforms.Normalize accepts only Tensor.
            if self.params.in_channel == 1:
                _transforms.append(transforms.Normalize(mean=(0.5, ), std=(0.5, )))
            else:
                # ie. self.params.in_channel == 3
                # ImageNet channel statistics — presumably to match pretrained
                # CNN/ViT preprocessing; confirm against the chosen backbone.
                _transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

        _transforms = transforms.Compose(_transforms)
        return _transforms

    def _open_image_in_channel(self, imgpath: str, in_channel: int) -> Image.Image:
        """
        Open image in channel.

        Args:
            imgpath (str): path to image
            in_channel (int): channel, or 1 or 3

        Returns:
            Image.Image: PIL image
        """
        if in_channel == 1:
            image = Image.open(imgpath).convert('L')  # eg. np.array(image).shape = (64, 64)
            return image
        else:
            # ie. self.params.in_channel == 3
            image = Image.open(imgpath).convert('RGB')  # eg. np.array(image).shape = (64, 64, 3)
            return image

    def _load_image_if_cnn(self, idx: int) -> Union[torch.Tensor, str]:
        """
        Load image and convert it to tensor if any of CNN or ViT is used.

        Args:
            idx (int): index

        Returns:
            Union[torch.Tensor, str]: tensor converted from image, or empty
            string when no CNN/ViT branch is configured
        """
        image = ''

        if self.params.net is None:
            return image

        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        image = self._open_image_in_channel(imgpath, self.params.in_channel)
        # Augment first (PIL-space ops), then convert to tensor and normalize.
        image = self.augmentation(image)
        image = self.transform(image)
        return image
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class DeepSurvMixin:
    """
    Mixin providing the extra data required by the deepsurv task.
    """
    def _load_periods_if_deepsurv(self, idx: int) -> Union[torch.FloatTensor, str]:
        """
        Return the follow-up period for a sample when the task is deepsurv.

        Args:
            idx (int): index into the split DataFrame

        Returns:
            Union[torch.FloatTensor, str]: period as a float32 tensor, or an
            empty string for any task other than deepsurv
        """
        if self.params.task != 'deepsurv':
            # Non-deepsurv tasks carry no period; '' keeps batch dicts uniform.
            return ''

        assert (self.params.task == 'deepsurv') and (len(self.label_list) == 1), 'Deepsurv cannot work in multi-label.'
        raw_period = self.df_split.iat[idx, self.col_index_dict[self.period_name]]  # int64
        # int64 -> np.float32 -> torch.float32
        return torch.from_numpy(np.array(raw_period, dtype=np.float32)).clone()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class DataSetWidget(InputDataMixin, ImageMixin, DeepSurvMixin):
    """
    Class for a widget to inherit multiple classes simultaneously.
    """
    # Intentionally empty: only aggregates the three data-handling mixins.
    pass
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class LoadDataSet(Dataset, DataSetWidget):
    """
    Dataset for split.
    """
    def __init__(
        self,
        params,
        split: str
    ) -> None:
        """
        Args:
            params (ParamSet): parameter for model
            split (str): split
        """
        self.params = params
        self.df_source = self.params.df_source
        self.split = split

        self.input_list = self.params.input_list
        self.label_list = self.params.label_list

        if self.params.task == 'deepsurv':
            self.period_name = self.params.period_name

        # Restrict the source DataFrame to this split; cache column positions
        # so __getitem__ can use fast positional (iat) access.
        self.df_split = self.df_source[self.df_source['split'] == self.split]
        self.col_index_dict = {col_name: self.df_split.columns.get_loc(col_name) for col_name in self.df_split.columns}

        # For input data
        if self.params.mlp is not None:
            assert (self.input_list != []), f"input list is empty."
            if params.isTrain:
                self.scaler = self._make_scaler()
            else:
                # load scaler used at training.
                self.scaler = self.load_scaler(self.params.scaler_path)

        # For image
        if self.params.net is not None:
            self.augmentation = self._make_augmentations()
            self.transform = self._make_transforms()

    def __len__(self) -> int:
        """
        Return length of DataFrame.

        Returns:
            int: length of DataFrame
        """
        return len(self.df_split)

    def _load_label(self, idx: int) -> Dict[str, Union[int, float]]:
        """
        Return labels.
        If no column of label when csv of external dataset is used,
        empty dictionary is returned.

        Args:
            idx (int): index

        Returns:
            Dict[str, Union[int, float]]: dictionary of label name and its value
        """
        # For checking if columns of labels exist when used csv for external dataset.
        label_list_in_split = list(self.df_split.columns[self.df_split.columns.str.startswith('label')])
        label_dict = dict()
        if label_list_in_split != []:
            for label_name in self.label_list:
                label_dict[label_name] = self.df_split.iat[idx, self.col_index_dict[label_name]]
        else:
            # no label
            pass
        return label_dict

    def __getitem__(self, idx: int) -> Dict:
        """
        Return data row specified by index.

        Args:
            idx (int): index

        Returns:
            Dict: dictionary of data to be passed model
        """
        uniqID = self.df_split.iat[idx, self.col_index_dict['uniqID']]
        group = self.df_split.iat[idx, self.col_index_dict['group']]
        imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
        split = self.df_split.iat[idx, self.col_index_dict['split']]
        inputs_value = self._load_input_value_if_mlp(idx)
        image = self._load_image_if_cnn(idx)
        label_dict = self._load_label(idx)
        periods = self._load_periods_if_deepsurv(idx)

        # Unused modalities stay as '' so every sample dict has the same keys.
        _data = {
            'uniqID': uniqID,
            'group': group,
            'imgpath': imgpath,
            'split': split,
            'inputs': inputs_value,
            'image': image,
            'labels': label_dict,
            'periods': periods
        }
        return _data
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _make_sampler(split_data: LoadDataSet) -> WeightedRandomSampler:
|
| 341 |
+
"""
|
| 342 |
+
Make sampler.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
split_data (LoadDataSet): dataset
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
WeightedRandomSampler: sampler
|
| 349 |
+
"""
|
| 350 |
+
_target = []
|
| 351 |
+
for _, data in enumerate(split_data):
|
| 352 |
+
_target.append(list(data['labels'].values())[0])
|
| 353 |
+
|
| 354 |
+
class_sample_count = np.array([len(np.where(_target == t)[0]) for t in np.unique(_target)])
|
| 355 |
+
weight = 1. / class_sample_count
|
| 356 |
+
samples_weight = np.array([weight[t] for t in _target])
|
| 357 |
+
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
|
| 358 |
+
return sampler
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def create_dataloader(
    params,
    split: str = None
) -> DataLoader:
    """
    Create a data loader for the given split.

    Args:
        params (ParamSet): parameter for dataloader
        split (str): split. Defaults to None.

    Returns:
        DataLoader: data loader
    """
    split_data = LoadDataSet(params, split)

    # Training shuffles with the train batch size; evaluation keeps order.
    if params.isTrain:
        batch_size, shuffle = params.batch_size, True
    else:
        batch_size, shuffle = params.test_batch_size, False

    sampler = None
    if params.sampler == 'yes':
        assert ((params.task == 'classification') or (params.task == 'deepsurv')), 'Cannot make sampler in regression.'
        assert (len(params.label_list) == 1), 'Cannot make sampler for multi-label.'
        # DataLoader forbids shuffle together with an explicit sampler.
        shuffle = False
        sampler = _make_sampler(split_data)

    return DataLoader(
        dataset=split_data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        sampler=sampler
    )
|
lib/framework.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import copy
|
| 6 |
+
from abc import ABC, abstractmethod
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from .component import create_net
|
| 10 |
+
from .logger import BaseLogger
|
| 11 |
+
from lib import ParamSet
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
# Alias of typing
|
| 15 |
+
# eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}}
|
| 16 |
+
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = BaseLogger.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BaseModel(ABC):
    """
    Class to construct model. This class is the base class to construct model.
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Class to define Model

        Args:
            params (ParamSet): parameters
        """
        self.params = params
        self.device = self.params.device

        self.network = create_net(
            mlp=self.params.mlp,
            net=self.params.net,
            num_outputs_for_label=self.params.num_outputs_for_label,
            mlp_num_inputs=self.params.mlp_num_inputs,
            in_channel=self.params.in_channel,
            vit_image_size=self.params.vit_image_size,
            pretrained=self.params.pretrained
        )
        self.network.to(self.device)

        # variables to keep temporary best_weight and best_epoch
        self.acting_best_weight = None
        self.acting_best_epoch = None

    def train(self) -> None:
        """
        Make network training mode.
        """
        self.network.train()

    def eval(self) -> None:
        """
        Make network evaluation mode.
        """
        self.network.eval()

    @abstractmethod
    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack a batch dictionary into model inputs and loss targets.
        Implemented by task-specific subclasses.
        """
        raise NotImplementedError

    def store_weight(self, at_epoch: int = None) -> None:
        """
        Store weight and epoch number when it is saved.

        Args:
            at_epoch (int): epoch number when save weight
        """
        self.acting_best_epoch = at_epoch

        # Deep-copy so subsequent training steps cannot mutate the stored weights.
        _network = copy.deepcopy(self.network)
        if hasattr(_network, 'module'):
            # When DataParallel used, move weight to CPU.
            self.acting_best_weight = copy.deepcopy(_network.module.to(torch.device('cpu')).state_dict())
        else:
            self.acting_best_weight = copy.deepcopy(_network.state_dict())

    def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None:
        """
        Save weight.

        Args:
            save_datetime_dir (str): save_datetime_dir
            as_best (bool): True if weight is saved as best, otherwise False. Defaults to None.
        """
        save_dir = Path(save_datetime_dir, 'weights')
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt'
        save_path = Path(save_dir, save_name)

        if as_best:
            save_name_as_best = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '_best' + '.pt'
            save_path_as_best = Path(save_dir, save_name_as_best)
            if save_path.exists():
                # Check if best weight already saved. If exists, rename with '_best'
                save_path.rename(save_path_as_best)
            else:
                torch.save(self.acting_best_weight, save_path_as_best)
        else:
            # (a dead reassignment of save_name, identical to the one above,
            # was removed here)
            torch.save(self.acting_best_weight, save_path)

    def load_weight(self, weight_path: Path) -> None:
        """
        Load weight from weight_path.

        Args:
            weight_path (Path): path to weight
        """
        logger.info(f"Load weight: {weight_path}.\n")
        weight = torch.load(weight_path)
        self.network.load_state_dict(weight)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class ModelMixin:
    """
    Mixin adding GPU placement and network re-initialization to a model.
    """
    def to_gpu(self, gpu_ids: List[int]) -> None:
        """
        Make model compute on the GPU.

        Args:
            gpu_ids (List[int]): GPU ids. When empty, the network is left untouched.
        """
        if gpu_ids != []:
            assert torch.cuda.is_available(), 'No available GPU on this machine.'
            # Wrap for multi-GPU data parallelism across the given devices.
            self.network = nn.DataParallel(self.network, device_ids=gpu_ids)

    def init_network(self) -> None:
        """
        Initialize network.
        This method is used at test to reset the current weight by redefining network.
        """
        self.network = create_net(
            mlp=self.params.mlp,
            net=self.params.net,
            num_outputs_for_label=self.params.num_outputs_for_label,
            mlp_num_inputs=self.params.mlp_num_inputs,
            in_channel=self.params.in_channel,
            vit_image_size=self.params.vit_image_size,
            pretrained=self.params.pretrained
        )
        self.network.to(self.device)
|
| 154 |
+
|
| 155 |
+
class ModelWidget(BaseModel, ModelMixin):
    """
    Class for a widget to inherit multiple classes simultaneously
    """
    # Intentionally empty: combines BaseModel with the GPU/init helpers in ModelMixin.
    pass
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class MLPModel(ModelWidget):
    """
    Model class for MLP (tabular-input) networks.
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params: (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack a batch for MLP forwarding and loss calculation,
        moving tensors onto the device.
        When deepsurv, periods and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
        """
        device = self.device
        in_data = {'inputs': data['inputs'].to(device)}
        labels = {'labels': {name: tensor.to(device) for name, tensor in data['labels'].items()}}

        # deepsurv batches carry non-zero periods; its loss also needs the network itself.
        if any(data['periods']):
            labels['periods'] = data['periods'].to(device)
            labels['network'] = self.network.to(device)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['inputs'])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class CVModel(ModelWidget):
    """
    Model class for CNN or ViT (image-input) networks.
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params: (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack a batch for CNN/ViT forwarding and loss calculation,
        moving tensors onto the device.
        When deepsurv, periods and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
        """
        device = self.device
        in_data = {'image': data['image'].to(device)}
        labels = {'labels': {name: tensor.to(device) for name, tensor in data['labels'].items()}}

        # deepsurv batches carry non-zero periods; its loss also needs the network itself.
        if any(data['periods']):
            labels['periods'] = data['periods'].to(device)
            labels['network'] = self.network.to(device)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['image'])
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class FusionModel(ModelWidget):
    """
    Model class for MLP+CNN or MLP+ViT (tabular + image) networks.
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params: (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack a batch for MLP+CNN or MLP+ViT forwarding and loss
        calculation, moving tensors onto the device.
        When deepsurv, periods and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
        """
        device = self.device
        in_data = {
            'inputs': data['inputs'].to(device),
            'image': data['image'].to(device)
        }
        labels = {'labels': {name: tensor.to(device) for name, tensor in data['labels'].items()}}

        # deepsurv batches carry non-zero periods; its loss also needs the network itself.
        if any(data['periods']):
            labels['periods'] = data['periods'].to(device)
            labels['network'] = self.network.to(device)
        return in_data, labels

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass: the fusion network consumes both modalities.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['inputs'], in_data['image'])
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def create_model(params: ParamSet) -> nn.Module:
    """
    Construct the model matching the modality described by params.

    Args:
        params (ParamSet): parameters

    Returns:
        nn.Module: model (MLPModel, CVModel, or FusionModel)

    Raises:
        ValueError: if neither mlp nor net is specified.
    """
    has_mlp = params.mlp is not None
    has_net = params.net is not None

    if has_mlp and not has_net:
        return MLPModel(params)
    if has_net and not has_mlp:
        return CVModel(params)
    if has_mlp and has_net:
        return FusionModel(params)
    raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")
|
lib/logger.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseLogger:
    """
    Provides the application-wide 'nervus' logger hierarchy.
    """
    # Flipped to False after the one-time handler configuration runs.
    _unexecuted_configure = True

    @classmethod
    def get_logger(cls, name: str) -> logging.Logger:
        """
        Return a child logger of the 'nervus' root, configuring handlers on first use.

        Args:
            name (str): If needed, potentially hierarchical name is desired, eg. lib.net, lib.dataloader, etc.
                        For the details, see https://docs.python.org/3/library/logging.html?highlight=logging#module-logging.
        Returns:
            logging.Logger: logger
        """
        if cls._unexecuted_configure:
            cls._init_logger()
        return logging.getLogger('nervus.{}'.format(name))

    @classmethod
    def _init_logger(cls) -> None:
        """
        Attach file and stream handlers to the 'nervus' root logger.

        Records at WARNING and above carry timestamps; records below
        WARNING are deliberately left unformatted (message only),
        giving print-like console/file output for info/debug.
        """
        root = logging.getLogger('nervus')
        root.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        log_dir = Path('logs')
        log_dir.mkdir(parents=True, exist_ok=True)
        log_path = Path(log_dir, 'log.log')

        # File handler for WARNING and above; skips pdb's BdbQuit tracebacks.
        file_warn = logging.FileHandler(log_path)
        file_warn.setLevel(logging.WARNING)
        file_warn.setFormatter(formatter)
        file_warn.addFilter(
            lambda record: not ('BdbQuit' in str(record.exc_info)) and (record.levelno >= logging.WARNING)
        )
        root.addHandler(file_warn)

        # File handler for records below WARNING (intentionally no formatter).
        file_info = logging.FileHandler(log_path)
        file_info.setLevel(logging.DEBUG)
        file_info.addFilter(lambda record: record.levelno < logging.WARNING)
        root.addHandler(file_info)

        # Console handler for WARNING and above.
        stream_warn = logging.StreamHandler()
        stream_warn.setLevel(logging.WARNING)
        stream_warn.setFormatter(formatter)
        stream_warn.addFilter(lambda record: record.levelno >= logging.WARNING)
        root.addHandler(stream_warn)

        # Console handler for records below WARNING (intentionally no formatter).
        stream_info = logging.StreamHandler()
        stream_info.setLevel(logging.DEBUG)
        stream_info.addFilter(lambda record: record.levelno < logging.WARNING)
        root.addHandler(stream_info)

        cls._unexecuted_configure = False
|
| 71 |
+
|
lib/metrics.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from sklearn import metrics
|
| 8 |
+
from sklearn.preprocessing import label_binarize
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib import colors as mcolors
|
| 11 |
+
from .logger import BaseLogger
|
| 12 |
+
from typing import Dict, Union
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
logger = BaseLogger.get_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MetricsData:
    """
    Empty attribute container for the metrics of one split.

    Attributes are attached dynamically (via setattr in
    LabelMetrics.set_label_metrics); which ones exist depends on the task:

    For ROC
        self.fpr: np.ndarray
        self.tpr: np.ndarray
        self.auc: float

    For Regression
        self.y_obs: np.ndarray
        self.y_pred: np.ndarray
        self.r2: float

    For DeepSurv
        self.c_index: float
    """
    def __init__(self) -> None:
        pass
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LabelMetrics:
    """
    Holds the metrics of each split ('val' and 'test') for one label.
    """
    def __init__(self) -> None:
        """
        Create one MetricsData container per split, ie 'val' and 'test'.
        """
        self.val = MetricsData()
        self.test = MetricsData()

    def set_label_metrics(self, split: str, attr: str, value: Union[np.ndarray, float]) -> None:
        """
        Store value under the named metric of the given split.

        Args:
            split (str): split, 'val' or 'test'
            attr (str): attribute name as follows:
                classification: 'fpr', 'tpr', or 'auc',
                regression: 'y_obs' (ground truth), 'y_pred' (prediction), or 'r2', or
                deepsurv: 'c_index'
            value (Union[np.ndarray, float]): value of attr
        """
        split_container = getattr(self, split)
        setattr(split_container, attr, value)

    def get_label_metrics(self, split: str, attr: str) -> Union[np.ndarray, float]:
        """
        Retrieve the named metric of the given split.

        Args:
            split (str): split, 'val' or 'test'
            attr (str): metrics name

        Returns:
            Union[np.ndarray, float]: value of attr
        """
        split_container = getattr(self, split)
        return getattr(split_container, attr)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ROCMixin:
    """
    Class for calculating ROC and AUC.

    Requires the concrete evaluator class to define ``self.metrics_kind``
    ('auc'), the attribute name under which the AUC value is stored.
    """
    def _set_roc(self, label_metrics: LabelMetrics, split: str, fpr: np.ndarray, tpr: np.ndarray) -> None:
        """
        Set fpr, tpr, and auc for one split.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            fpr (np.ndarray): FPR
            tpr (np.ndarray): TPR

        self.metrics_kind = 'auc' is defined in the concrete evaluator class (ClsEval).
        """
        label_metrics.set_label_metrics(split, 'fpr', fpr)
        label_metrics.set_label_metrics(split, 'tpr', tpr)
        # sklearn's metrics.auc integrates the (fpr, tpr) curve.
        label_metrics.set_label_metrics(split, self.metrics_kind, metrics.auc(fpr, tpr))

    def _cal_label_roc_binary(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC for binary class.

        Uses only the predicted probability of the positive class (1).

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        # Keep this label's ground-truth/prediction columns plus the split marker.
        required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
        df_label = df_group[required_columns]
        POSITIVE = 1
        positive_pred_name = 'pred_' + label_name + '_' + str(POSITIVE)

        # ! When splits is 'test' only, ie when external dataset, error occurs
        #   (the 'val' subset is empty and roc_curve fails).
        label_metrics = LabelMetrics()
        for split in ['val', 'test']:
            df_split = df_label.query('split == @split')
            y_true = df_split[label_name]
            y_score = df_split[positive_pred_name]
            _fpr, _tpr, _ = metrics.roc_curve(y_true, y_score)
            self._set_roc(label_metrics, split, _fpr, _tpr)
        return label_metrics

    def _cal_label_roc_multi(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC for multi-class by macro average (one-vs-rest).

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
        df_label = df_group[required_columns]

        # Recover the integer class ids from the prediction column names.
        pred_name_list = list(df_label.columns[df_label.columns.str.startswith('pred')])
        class_list = [int(pred_name.rsplit('_', 1)[-1]) for pred_name in pred_name_list]  # [pred_label_0, pred_label_1, pred_label_2] -> [0, 1, 2]
        num_classes = len(class_list)

        label_metrics = LabelMetrics()
        for split in ['val', 'test']:
            df_split = df_label.query('split == @split')
            y_true = df_split[label_name]
            y_true_bin = label_binarize(y_true, classes=class_list)  # Since y_true: List[int], should be class_list: List[int]

            # Compute ROC for each class by OneVsRest
            _fpr = dict()
            _tpr = dict()
            for i, class_name in enumerate(class_list):
                pred_name = 'pred_' + label_name + '_' + str(class_name)
                _fpr[class_name], _tpr[class_name], _ = metrics.roc_curve(y_true_bin[:, i], df_split[pred_name])

            # First aggregate all false positive rates
            all_fpr = np.unique(np.concatenate([_fpr[class_name] for class_name in class_list]))

            # Then interpolate all ROC at this points
            mean_tpr = np.zeros_like(all_fpr)
            for class_name in class_list:
                mean_tpr += np.interp(all_fpr, _fpr[class_name], _tpr[class_name])

            # Finally average it and compute AUC
            mean_tpr /= num_classes

            _fpr['macro'] = all_fpr
            _tpr['macro'] = mean_tpr
            self._set_roc(label_metrics, split, _fpr['macro'], _tpr['macro'])
        return label_metrics

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate ROC and AUC for label depending on binary or multi-class.

        Binary vs multi-class is detected from the number of prediction
        columns for this label (two columns -> binary).

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        pred_name_list = df_group.columns[df_group.columns.str.startswith('pred_' + label_name)]
        isMultiClass = (len(pred_name_list) > 2)
        if isMultiClass:
            label_metrics = self._cal_label_roc_multi(label_name, df_group)
        else:
            label_metrics = self._cal_label_roc_binary(label_name, df_group)
        return label_metrics
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class YYMixin:
    """
    Class for calculating observed-vs-predicted (YY) data and R2.
    """
    def _set_yy(self, label_metrics: LabelMetrics, split: str, y_obs: np.ndarray, y_pred: np.ndarray) -> None:
        """
        Record ground truth, prediction, and R2 for one split.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            y_obs (np.ndarray): ground truth
            y_pred (np.ndarray): prediction

        self.metrics_kind = 'r2' is defined in the concrete evaluator class (RegEval).
        """
        label_metrics.set_label_metrics(split, 'y_obs', y_obs.values)
        label_metrics.set_label_metrics(split, 'y_pred', y_pred.values)
        r2 = metrics.r2_score(y_obs, y_pred)
        label_metrics.set_label_metrics(split, self.metrics_kind, r2)

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate YY data and R2 for one label.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        relevant = [column for column in df_group.columns if label_name in column] + ['split']
        df_label = df_group[relevant]
        label_metrics = LabelMetrics()
        for split in ('val', 'test'):
            df_split = df_label.query('split == @split')
            self._set_yy(label_metrics, split, df_split[label_name], df_split['pred_' + label_name])
        return label_metrics
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class C_IndexMixin:
    """
    Class for calculating C-Index (concordance index) for survival models.
    """
    def _set_c_index(
        self,
        label_metrics: LabelMetrics,
        split: str,
        periods: pd.Series,
        preds: pd.Series,
        labels: pd.Series
    ) -> None:
        """
        Record the C-Index for one split.

        Args:
            label_metrics (LabelMetrics): metrics of 'val' and 'test'
            split (str): 'val' or 'test'
            periods (pd.Series): periods
            preds (pd.Series): prediction
            labels (pd.Series): label

        self.metrics_kind = 'c_index' is defined in the concrete evaluator class (DeepSurvEval).
        """
        # Imported lazily so lifelines is only required for survival evaluation.
        from lifelines.utils import concordance_index
        # (-1)*preds: the network outputs a risk score, whereas lifelines
        # expects higher scores to predict longer survival.
        value_c_index = concordance_index(periods, (-1)*preds, labels)
        label_metrics.set_label_metrics(split, self.metrics_kind, value_c_index)

    def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
        """
        Calculate C-Index for one label.

        Args:
            label_name (str): label name
            df_group (pd.DataFrame): likelihood for group

        Returns:
            LabelMetrics: metrics of 'val' and 'test'
        """
        relevant = [column for column in df_group.columns if label_name in column] + ['periods', 'split']
        df_label = df_group[relevant]
        label_metrics = LabelMetrics()
        for split in ('val', 'test'):
            df_split = df_label.query('split == @split')
            self._set_c_index(
                label_metrics,
                split,
                df_split['periods'],
                df_split['pred_' + label_name],
                df_split[label_name]
            )
        return label_metrics
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class MetricsMixin:
    """
    Class to calculate metrics, log them, and maintain the summary CSV.
    """
    def _cal_group_metrics(self, df_group: pd.DataFrame) -> Dict[str, LabelMetrics]:
        """
        Calculate metrics for every label of one group.

        Args:
            df_group (pd.DataFrame): likelihood for group

        Returns:
            Dict[str, LabelMetrics]: dictionary of label and its LabelMetrics
            eg. {label_1: LabelMetrics(), label_2: LabelMetrics(), ...}
        """
        label_names = list(df_group.columns[df_group.columns.str.startswith('label')])
        return {name: self.cal_label_metrics(name, df_group) for name in label_names}

    def cal_whole_metrics(self, df_likelihood: pd.DataFrame) -> Dict[str, Dict[str, LabelMetrics]]:
        """
        Calculate metrics for all groups.

        Args:
            df_likelihood (pd.DataFrame): DataFrame of likelihood

        Returns:
            Dict[str, Dict[str, LabelMetrics]]: group -> {label -> LabelMetrics}
        """
        whole_metrics = dict()
        for group in df_likelihood['group'].unique():
            df_group = df_likelihood.query('group == @group')
            whole_metrics[group] = self._cal_group_metrics(df_group)
        return whole_metrics

    def make_summary(
        self,
        whole_metrics: Dict[str, Dict[str, LabelMetrics]],
        likelihood_path: Path,
        metrics_kind: str
    ) -> pd.DataFrame:
        """
        Build a one-row-per-group summary table of the chosen metric.

        Args:
            whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
            likelihood_path (Path): path to likelihood
            metrics_kind (str): kind of metrics, ie. 'auc', 'r2', or 'c_index'

        Returns:
            pd.DataFrame: summary, sorted by group
        """
        # Directory layout encodes the training datetime two levels up.
        _datetime = likelihood_path.parents[1].name
        _weight = likelihood_path.stem.replace('likelihood_', '') + '.pt'
        partial_frames = []
        for group, group_metrics in whole_metrics.items():
            record = {
                'datetime': [_datetime],
                'weight': [_weight],
                'group': [group],
            }
            for label_name, label_metrics in group_metrics.items():
                val_value = label_metrics.get_label_metrics('val', metrics_kind)
                test_value = label_metrics.get_label_metrics('test', metrics_kind)
                record[label_name + '_val_' + metrics_kind] = [f"{val_value:.2f}"]
                record[label_name + '_test_' + metrics_kind] = [f"{test_value:.2f}"]
            partial_frames.append(pd.DataFrame(record))
        df_summary = pd.concat([pd.DataFrame()] + partial_frames, ignore_index=True)
        return df_summary.sort_values('group')

    def print_metrics(self, df_summary: pd.DataFrame, metrics_kind: str) -> None:
        """
        Log the metric of every label, group by group.

        Args:
            df_summary (pd.DataFrame): summary
            metrics_kind (str): kind of metrics, ie. 'auc', 'r2', or 'c_index'
        """
        metric_columns = list(df_summary.columns[df_summary.columns.str.startswith('label')])  # [label_1_val, label_1_test, label_2_val, label_2_test, ...]
        num_splits = len(['val', 'test'])
        # Pair up the val/test columns belonging to the same label.
        paired_columns = [metric_columns[i:i + num_splits] for i in range(0, len(metric_columns), num_splits)]
        for _, row in df_summary.iterrows():
            logger.info(row['group'])
            for val_column, test_column in paired_columns:
                _label_name = val_column.replace('_val', '')
                logger.info(f"{_label_name:<25} val_{metrics_kind}: {row[val_column]:>7}, test_{metrics_kind}: {row[test_column]:>7}")

    def update_summary(self, df_summary: pd.DataFrame, likelihood_path: Path) -> None:
        """
        Append df_summary to the project-wide summary CSV, creating it if absent.

        Args:
            df_summary (pd.DataFrame): summary to be added to the previous summary
            likelihood_path (Path): path to likelihood
        """
        summary_dir = Path(likelihood_path.parents[3], 'summary')
        summary_path = Path(summary_dir, 'summary.csv')
        if summary_path.exists():
            df_updated = pd.concat([pd.read_csv(summary_path), df_summary], axis=0)
        else:
            summary_dir.mkdir(parents=True, exist_ok=True)
            df_updated = df_summary
        df_updated.to_csv(summary_path, index=False)

    def make_metrics(self, likelihood_path: Path) -> None:
        """
        Compute, plot, log, and persist metrics for one likelihood file.

        Args:
            likelihood_path (Path): path to likelihood
        """
        df_likelihood = pd.read_csv(likelihood_path)
        whole_metrics = self.cal_whole_metrics(df_likelihood)
        self.make_save_fig(whole_metrics, likelihood_path, self.fig_kind)
        df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
        self.print_metrics(df_summary, self.metrics_kind)
        self.update_summary(df_summary, likelihood_path)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class FigROCMixin:
    """
    Class to plot ROC curves.
    """
    def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
        """
        Plot the val/test ROC curves of every label in one group,
        one subplot per label, side by side.

        Args:
            group (str): group
            group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics

        Returns:
            plt: ROC figure
        """
        label_list = group_metrics.keys()
        base_size = 7
        num_rows = 1
        num_cols = len(label_list)
        height = num_rows * base_size
        width = num_cols * height
        fig = plt.figure(figsize=(width, height))

        for position, label_name in enumerate(label_list, start=1):
            label_metrics = group_metrics[label_name]
            ax = fig.add_subplot(
                num_rows,
                num_cols,
                position,
                title=group + ': ' + label_name,
                xlabel='1 - Specificity',
                ylabel='Sensitivity',
                xmargin=0,
                ymargin=0
            )
            ax.plot(label_metrics.val.fpr, label_metrics.val.tpr, label=f"AUC_val = {label_metrics.val.auc:.2f}", marker='x')
            ax.plot(label_metrics.test.fpr, label_metrics.test.tpr, label=f"AUC_test = {label_metrics.test.auc:.2f}", marker='o')
            ax.grid()
            ax.legend()
        fig.tight_layout()
        return fig
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class FigYYMixin:
    """
    Class to plot YY-graph (observed vs. predicted scatter, one val panel
    and one test panel per label).
    """
    def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
        """
        Plot yy (observed-predicted) panels for every label of a group.

        Args:
            group (str): group
            group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics

        Returns:
            plt: YY-graph
        """
        label_list = group_metrics.keys()
        num_splits = len(['val', 'test'])  # two panels (val, test) per label
        num_rows = 1
        num_cols = len(label_list) * num_splits
        base_size = 7
        height = num_rows * base_size
        width = num_cols * height
        fig = plt.figure(figsize=(width, height))

        for i, label_name in enumerate(label_list):
            label_metrics = group_metrics[label_name]
            # 1-based subplot positions: a label's val panel, then its test panel.
            val_offset = (i * num_splits) + 1
            test_offset = val_offset + 1

            val_ax = fig.add_subplot(
                num_rows,
                num_cols,
                val_offset,
                title=group + ': ' + label_name + '\n' + 'val: Observed-Predicted Plot',
                xlabel='Observed',
                ylabel='Predicted',
                xmargin=0,
                ymargin=0
            )

            test_ax = fig.add_subplot(
                num_rows,
                num_cols,
                test_offset,
                title=group + ': ' + label_name + '\n' + 'test: Observed-Predicted Plot',
                xlabel='Observed',
                ylabel='Predicted',
                xmargin=0,
                ymargin=0
            )

            # NOTE(review): y_obs/y_pred appear to be numpy arrays
            # (flatten() is called below) — confirm against LabelMetrics.
            y_obs_val = label_metrics.val.y_obs
            y_pred_val = label_metrics.val.y_pred

            y_obs_test = label_metrics.test.y_obs
            y_pred_test = label_metrics.test.y_pred

            # Plot
            color = mcolors.TABLEAU_COLORS
            val_ax.scatter(y_obs_val, y_pred_val, color=color['tab:blue'], label='val')
            test_ax.scatter(y_obs_test, y_pred_test, color=color['tab:orange'], label='test')

            # Draw diagonal line, extended 1% beyond the pooled data range
            # on both ends so extreme points do not sit on the frame.
            y_values_val = np.concatenate([y_obs_val.flatten(), y_pred_val.flatten()])
            y_values_test = np.concatenate([y_obs_test.flatten(), y_pred_test.flatten()])

            y_values_val_min, y_values_val_max, y_values_val_range = np.amin(y_values_val), np.amax(y_values_val), np.ptp(y_values_val)
            y_values_test_min, y_values_test_max, y_values_test_range = np.amin(y_values_test), np.amax(y_values_test), np.ptp(y_values_test)

            val_ax.plot([y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)],
                        [y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)], color='red')

            test_ax.plot([y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)],
                         [y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)], color='red')

        fig.tight_layout()
        return fig
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class FigMixin:
    """
    Mixin that renders a figure per group (ROC or YY-graph) and writes it to disk.
    """
    def make_save_fig(self, whole_metrics: Dict[str, Dict[str, LabelMetrics]], likelihood_path: Path, fig_kind: str) -> None:
        """
        Make and save one figure per group.

        Args:
            whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
            likelihood_path (Path): path to likelihood
            fig_kind (str): kind of figure, ie. 'roc' or 'yy'
        """
        save_dir = likelihood_path.parents[1] / fig_kind
        save_dir.mkdir(parents=True, exist_ok=True)
        # eg. 'likelihood_weight_epoch-011_best' -> 'roc_weight_epoch-011_best'
        fig_stem = fig_kind + '_' + likelihood_path.stem.replace('likelihood_', '')
        for group, group_metrics in whole_metrics.items():
            fig = self._plot_fig_group_metrics(group, group_metrics)
            fig.savefig(save_dir / (group + '_' + fig_stem + '.png'))
            plt.close()
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class ClsEval(MetricsMixin, ROCMixin, FigMixin, FigROCMixin):
    """
    Evaluator for classification: ROC figures, AUC summaries.
    """
    def __init__(self) -> None:
        # Classification is summarized by AUC and visualized with ROC curves.
        self.metrics_kind = 'auc'
        self.fig_kind = 'roc'
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class RegEval(MetricsMixin, YYMixin, FigMixin, FigYYMixin):
    """
    Evaluator for regression: YY-graphs, R2 summaries.
    """
    def __init__(self) -> None:
        # Regression is summarized by R2 and visualized with observed-predicted plots.
        self.metrics_kind = 'r2'
        self.fig_kind = 'yy'
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class DeepSurvEval(MetricsMixin, C_IndexMixin):
    """
    Evaluator for DeepSurv: C-index summaries, no figures.
    """
    def __init__(self) -> None:
        # No figure is produced for DeepSurv; only the C-index is reported.
        self.metrics_kind = 'c_index'
        self.fig_kind = None

    def make_metrics(self, likelihood_path: Path) -> None:
        """
        Make metrics, substantially this method handles everything.

        Overrides MetricsMixin.make_metrics: the figure-saving step is dropped
        because there is nothing to plot for DeepSurv.

        Args:
            likelihood_path (Path): path to likelihood
        """
        likelihood = pd.read_csv(likelihood_path)
        metrics = self.cal_whole_metrics(likelihood)
        summary = self.make_summary(metrics, likelihood_path, self.metrics_kind)
        self.print_metrics(summary, self.metrics_kind)
        self.update_summary(summary, likelihood_path)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def set_eval(task: str) -> Union[ClsEval, RegEval, DeepSurvEval]:
    """
    Set class for evaluation depending on task.

    Args:
        task (str): task

    Returns:
        Union[ClsEval, RegEval, DeepSurvEval]: evaluator instance for the task

    Raises:
        ValueError: when task is none of the supported ones
    """
    if task == 'classification':
        evaluator = ClsEval()
    elif task == 'regression':
        evaluator = RegEval()
    elif task == 'deepsurv':
        evaluator = DeepSurvEval()
    else:
        raise ValueError(f"Invalid task: {task}.")
    return evaluator
|
lib/options.py
ADDED
|
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
from distutils.util import strtobool
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
from .logger import BaseLogger
|
| 11 |
+
from typing import List, Dict, Tuple, Union
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
logger = BaseLogger.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Options:
    """
    Class for options.

    Builds an argparse parser whose flags differ between training and test,
    parses sys.argv immediately on construction, and exposes the result
    via get_args().
    """
    def __init__(self, datetime: str = None, isTrain: bool = None) -> None:
        """
        Args:
            datetime (str, optional): date time string identifying this run. Defaults to None.
            isTrain (bool, optional): Variable indicating whether training or not. Defaults to None.
        """
        self.parser = argparse.ArgumentParser(description='Options for training or test')

        # CSV
        self.parser.add_argument('--csvpath', type=str, required=True, help='path to csv for training or test')

        # GPU Ids
        self.parser.add_argument('--gpu_ids', type=str, default='cpu', help='gpu ids: e.g. 0, 0-1-2, 0-2. Use cpu for CPU (Default: cpu)')

        # Training-only flags vs test-only flags.
        if isTrain:
            # Task
            self.parser.add_argument('--task', type=str, required=True, choices=['classification', 'regression', 'deepsurv'], help='Task')

            # Model
            self.parser.add_argument('--model', type=str, required=True, help='model: MLP, CNN, ViT, or MLP+(CNN or ViT)')
            # NOTE(review): distutils.util.strtobool is deprecated (PEP 632) — consider replacing.
            self.parser.add_argument('--pretrained', type=strtobool, default=False, help='For use of pretrained model(CNN or ViT)')

            # Training and Internal validation
            self.parser.add_argument('--criterion', type=str, required=True, choices=['CEL', 'MSE', 'RMSE', 'MAE', 'NLL'], help='criterion')
            self.parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adadelta', 'RMSprop', 'Adam', 'RAdam'], help='optimizer')
            self.parser.add_argument('--lr', type=float, metavar='N', help='learning rate')
            self.parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs (Default: 10)')

            # Batch size
            self.parser.add_argument('--batch_size', type=int, required=True, metavar='N', help='batch size in training')

            # Preprocess for image
            self.parser.add_argument('--augmentation', type=str, default='no', choices=['xrayaug', 'trivialaugwide', 'randaug', 'no'], help='kind of augmentation')
            self.parser.add_argument('--normalize_image', type=str, choices=['yes', 'no'], default='yes', help='image normalization: yes, no (Default: yes)')

            # Sampler
            self.parser.add_argument('--sampler', type=str, default='no', choices=['yes', 'no'], help='sample data in training or not, yes or no')

            # Input channel
            self.parser.add_argument('--in_channel', type=int, required=True, choices=[1, 3], help='channel of input image')
            self.parser.add_argument('--vit_image_size', type=int, default=0, help='input image size for ViT. Set 0 if not used ViT (Default: 0)')

            # Weight saving strategy
            self.parser.add_argument('--save_weight_policy', type=str, choices=['best', 'each'], default='best', help='Save weight policy: best, or each(ie. save each time loss decreases when multi-label output) (Default: best)')

        else:
            # Directory of weight at training
            self.parser.add_argument('--weight_dir', type=str, default=None, help='directory of weight to be used when test. If None, the latest one is selected')

            # Test batch size
            self.parser.add_argument('--test_batch_size', type=int, default=1, metavar='N', help='batch size for test (Default: 1)')

            # Splits for test
            self.parser.add_argument('--test_splits', type=str, default='train-val-test', help='splits for test: e.g. test, val-test, train-val-test. (Default: train-val-test)')

        # Parses sys.argv right away; construction therefore has a CLI side effect.
        self.args = self.parser.parse_args()

        if datetime is not None:
            self.args.datetime = datetime

        # isTrain is mandatory in practice despite the None default.
        assert isinstance(isTrain, bool), 'isTrain should be bool.'
        self.args.isTrain = isTrain

    def get_args(self) -> argparse.Namespace:
        """
        Return arguments.

        Returns:
            argparse.Namespace: arguments
        """
        return self.args
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class CSVParser:
    """
    Class to get information of csv and cast csv.

    Reads the csv, drops rows whose split is 'exclude', detects input/label
    (and, for deepsurv, period) columns, casts them to the dtypes the task
    needs, and ensures a 'group' column exists.
    """

    # dtype for label columns per task; its keys are also the valid tasks.
    _LABEL_DTYPES = {'classification': int, 'regression': float, 'deepsurv': int}

    def __init__(self, csvpath: str, task: str, isTrain: bool = None) -> None:
        """
        Args:
            csvpath (str): path to csv
            task (str): task
            isTrain (bool): if training or not
        """
        self.csvpath = csvpath
        self.task = task

        _df_source = pd.read_csv(self.csvpath)
        # Rows marked 'exclude' take part in neither training nor test.
        _df_source = _df_source[_df_source['split'] != 'exclude']

        self.input_list = list(_df_source.columns[_df_source.columns.str.startswith('input')])
        self.label_list = list(_df_source.columns[_df_source.columns.str.startswith('label')])
        if self.task == 'deepsurv':
            _period_name_list = list(_df_source.columns[_df_source.columns.str.startswith('period')])
            assert (len(_period_name_list) == 1), f"One column of period should be contained in {self.csvpath} when deepsurv."
            self.period_name = _period_name_list[0]

        _df_source = self._cast(_df_source, self.task)

        # If no column of group, add it.
        if 'group' not in _df_source.columns:
            _df_source = _df_source.assign(group='all')

        self.df_source = _df_source

        if isTrain:
            # Only training needs the model-shape information.
            self.mlp_num_inputs = len(self.input_list)
            self.num_outputs_for_label = self._define_num_outputs_for_label(self.df_source, self.label_list, self.task)

    def _cast(self, df_source: pd.DataFrame, task: str) -> pd.DataFrame:
        """
        Cast columns depending on task.

        Inputs are always float; label dtype comes from _LABEL_DTYPES; the
        period column (deepsurv only) is int. Previously the three tasks
        repeated the same astype dance and the error message referenced
        self.task instead of the task argument.

        Args:
            df_source (pd.DataFrame): excluded DataFrame
            task (str): task

        Returns:
            pd.DataFrame: csv excluded and cast depending on task

        Raises:
            ValueError: when task is unknown
        """
        if task not in self._LABEL_DTYPES:
            raise ValueError(f"Invalid task: {task}.")

        _casts = {input_name: float for input_name in self.input_list}
        _label_dtype = self._LABEL_DTYPES[task]
        _casts.update({label_name: _label_dtype for label_name in self.label_list})
        if task == 'deepsurv':
            # period column existence was already asserted in __init__.
            _casts[self.period_name] = int
        return df_source.astype(_casts)

    def _define_num_outputs_for_label(self, df_source: pd.DataFrame, label_list: List[str], task: str) -> Dict[str, int]:
        """
        Define the number of outputs for each label.

        Args:
            df_source (pd.DataFrame): DataFrame of csv
            label_list (List[str]): list of labels
            task (str): task

        Returns:
            Dict[str, int]: dictionary of the number of outputs for each label
            eg.
            classification: {label_A: 2, label_B: 3, ...}
            regression, deepsurv: {label_A: 1, label_B: 1, ...}

        Raises:
            ValueError: when task is unknown
        """
        if task == 'classification':
            # One output unit per distinct class observed in the data.
            return {label_name: df_source[label_name].nunique() for label_name in label_list}
        elif task in ('regression', 'deepsurv'):
            return {label_name: 1 for label_name in label_list}
        else:
            raise ValueError(f"Invalid task: {task}.")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _parse_model(model_name: str) -> Tuple[Union[str, None], Union[str, None]]:
|
| 194 |
+
"""
|
| 195 |
+
Parse model name.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
model_name (str): model name (eg. MLP, ResNey18, or MLP+ResNet18)
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
Tuple[str, str]: MLP, CNN or Vision Transformer name
|
| 202 |
+
eg. 'MLP', 'ResNet18', 'MLP+ResNet18' ->
|
| 203 |
+
['MLP'], ['ResNet18'], ['MLP', 'ResNet18']
|
| 204 |
+
"""
|
| 205 |
+
_model = model_name.split('+')
|
| 206 |
+
mlp = 'MLP' if 'MLP' in _model else None
|
| 207 |
+
_net = [_n for _n in _model if _n != 'MLP']
|
| 208 |
+
net = _net[0] if _net != [] else None
|
| 209 |
+
return mlp, net
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _parse_gpu_ids(gpu_ids: str) -> List[int]:
|
| 213 |
+
"""
|
| 214 |
+
Parse GPU ids concatenated with '-' to list of integers of GPU ids.
|
| 215 |
+
eg. '0-1-2' -> [0, 1, 2], '-1' -> []
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
gpu_ids (str): GPU Ids
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
List[int]: list of GPU ids
|
| 222 |
+
"""
|
| 223 |
+
if (gpu_ids == 'cpu') or (gpu_ids == 'cpu\r'):
|
| 224 |
+
str_ids = []
|
| 225 |
+
else:
|
| 226 |
+
str_ids = gpu_ids.split('-')
|
| 227 |
+
_gpu_ids = []
|
| 228 |
+
for str_id in str_ids:
|
| 229 |
+
id = int(str_id)
|
| 230 |
+
if id >= 0:
|
| 231 |
+
_gpu_ids.append(id)
|
| 232 |
+
return _gpu_ids
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def _get_latest_weight_dir() -> str:
    """
    Return the most recently modified weight directory produced by training.

    Returns:
        str: path to directory of the latest weight
             eg. 'results/<project>/trials/2022-09-30-15-56-60/weights'
    """
    candidates = list(Path('results').glob('*/trials/*/weights'))
    assert (candidates != []), 'No directory of weight.'
    # Most recent mtime wins across all projects and trials.
    newest = max(candidates, key=lambda directory: directory.stat().st_mtime)
    return str(newest)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _collect_weight_paths(weight_dir: str) -> List[str]:
|
| 250 |
+
"""
|
| 251 |
+
Return list of weight paths.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
weight_dir (str): path to directory of weights
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
List[str]: list of weight paths
|
| 258 |
+
"""
|
| 259 |
+
_weight_paths = list(Path(weight_dir).glob('*.pt'))
|
| 260 |
+
assert _weight_paths != [], f"No weight in {weight_dir}."
|
| 261 |
+
_weight_paths.sort(key=lambda path: path.stat().st_mtime)
|
| 262 |
+
_weight_paths = [str(weight_path) for weight_path in _weight_paths]
|
| 263 |
+
return _weight_paths
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class ParamTable:
    """
    Class to make table to dispatch parameters by group.

    Each parameter is mapped to one or more of eight groups (model,
    dataloader, train/test config, save, load, train/test print); the
    mapping is materialized as a yes/no DataFrame in self.table.
    """
    def __init__(self) -> None:
        # groups
        # key is abbreviation, value is group name
        self.groups = {
            'mo': 'model',
            'dl': 'dataloader',
            'trc': 'train_conf',
            'tsc': 'test_conf',
            'sa': 'save',
            'lo': 'load',
            'trp': 'train_print',
            'tsp': 'test_print'
        }

        # Short locals so the dispatch table below stays readable.
        mo = self.groups['mo']
        dl = self.groups['dl']
        trc = self.groups['trc']
        tsc = self.groups['tsc']
        sa = self.groups['sa']
        lo = self.groups['lo']
        trp = self.groups['trp']
        tsp = self.groups['tsp']

        # The below shows that which group each parameter dispatches to.
        self.dispatch = {
            'datetime': [sa],
            'project': [sa, trp, tsp],
            'csvpath': [sa, trp, tsp],
            'task': [dl, tsc, sa, lo, trp, tsp],
            'isTrain': [dl, trp, tsp],

            'model': [sa, lo, trp, tsp],
            'vit_image_size': [mo, sa, lo, trp, tsp],
            'pretrained': [mo, sa, trp],
            'mlp': [mo, dl],
            'net': [mo, dl],

            'weight_dir': [tsc, tsp],
            'weight_paths': [tsc],

            'criterion': [trc, sa, trp],
            'optimizer': [trc, sa, trp],
            'lr': [trc, sa, trp],
            'epochs': [trc, sa, trp],

            'batch_size': [dl, sa, trp],
            'test_batch_size': [dl, tsp],
            'test_splits': [tsc, tsp],

            'in_channel': [mo, dl, sa, lo, trp, tsp],
            'normalize_image': [dl, sa, lo, trp, tsp],
            'augmentation': [dl, sa, trp],
            'sampler': [dl, sa, trp],

            'df_source': [dl],
            'label_list': [dl, trc, sa, lo],
            'input_list': [dl, sa, lo],
            'period_name': [dl, sa, lo],
            'mlp_num_inputs': [mo, sa, lo],
            'num_outputs_for_label': [mo, sa, lo, tsc],

            'save_weight_policy': [sa, trp, trc],
            'scaler_path': [dl, tsp],
            'save_datetime_dir': [trc, tsc, trp, tsp],

            'gpu_ids': [trc, tsc, sa, trp, tsp],
            'device': [mo, trc, tsc],
            'dataset_info': [trc, sa, trp, tsp]
        }

        # Materialized parameter-by-group membership table.
        self.table = self._make_table()

    def _make_table(self) -> pd.DataFrame:
        """
        Make table to dispatch parameters by group.

        Returns:
            pd.DataFrame: table which shows that which group each parameter belongs to.
                Columns are 'parameter' plus one 'yes'/'no' column per group.
        """
        df_table = pd.DataFrame([], index=self.dispatch.keys(), columns=self.groups.values()).fillna('no')
        for param, grps in self.dispatch.items():
            for grp in grps:
                df_table.loc[param, grp] = 'yes'

        # Expose the parameter names as an ordinary column, not the index.
        df_table = df_table.reset_index()
        df_table = df_table.rename(columns={'index': 'parameter'})
        return df_table

    def get_by_group(self, group_name: str) -> List[str]:
        """
        Return list of parameters which belong to group

        Args:
            group_name (str): group name

        Returns:
            List[str]: list of parameters
        """
        _df_table = self.table
        _param_names = _df_table[_df_table[group_name] == 'yes']['parameter'].tolist()
        return _param_names
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# Module-level singleton; shared by _dispatch_by_group() and _retrieve_parameter().
Param_Table = ParamTable()
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class ParamSet:
    """
    Class to store required parameters for each group.

    A bare namespace: instances start empty and attributes are attached
    dynamically, one per dispatched parameter.
    """
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def _dispatch_by_group(args: argparse.Namespace, group_name: str) -> ParamSet:
    """
    Dispatch parameters depending on group.

    Args:
        args (argparse.Namespace): arguments
        group_name (str): group

    Returns:
        ParamSet: object carrying only the parameters registered for the group
    """
    param_set = ParamSet()
    # Copy over each registered parameter that is actually present on args.
    for _name in Param_Table.get_by_group(group_name):
        if hasattr(args, _name):
            setattr(param_set, _name, getattr(args, _name))
    return param_set
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def save_parameter(params: ParamSet, save_path: str) -> None:
    """
    Save parameters as JSON.

    Args:
        params (ParamSet): parameters
        save_path (str): save path for parameters
    """
    _save_path = Path(save_path)
    # .parent replaces the more roundabout parents[0].
    _save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(_save_path, 'w') as f:
        # dict(vars(...)) snapshots the attributes; the previous identity
        # dict-comprehension rebuilt the same mapping item by item.
        json.dump(dict(vars(params)), f, indent=4)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _retrieve_parameter(parameter_path: str) -> Dict[str, Union[str, int, float]]:
    """
    Retrieve only parameters required at test from parameters at training.

    Args:
        parameter_path (str): path to the JSON file saved at training

    Returns:
        Dict[str, Union[str, int, float]]: parameters at training, filtered
        down to the 'load' group
    """
    with open(parameter_path) as f:
        saved = json.load(f)

    # Keep only the parameters registered under the 'load' group.
    required = set(Param_Table.get_by_group('load'))
    return {name: value for name, value in saved.items() if name in required}
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def print_parameter(params: ParamSet) -> None:
    """
    Print parameters via the module logger, framed by a header and footer.

    Args:
        params (ParamSet): parameters
    """

    LINE_LENGTH = 82

    if params.isTrain:
        phase = 'Training'
    else:
        phase = 'Test'

    _header = f" Configuration of {phase} "
    _padding = (LINE_LENGTH - len(_header) + 1) // 2  # round up
    _header = ('-' * _padding) + _header + ('-' * _padding) + '\n'

    _footer = ' End '
    _padding = (LINE_LENGTH - len(_footer) + 1) // 2
    _footer = ('-' * _padding) + _footer + ('-' * _padding) + '\n'

    message = ''
    message += _header

    # Bug fix: vars() returns the live __dict__, so deleting from it directly
    # silently stripped 'isTrain' off the caller's params object. Copy first.
    _params_dict = dict(vars(params))
    del _params_dict['isTrain']
    for _param, _arg in _params_dict.items():
        _str_arg = _arg2str(_param, _arg)
        message += f"{_param:>30}: {_str_arg:<40}\n"

    message += _footer
    logger.info(message)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def _arg2str(param: str, arg: Union[str, int, float]) -> str:
|
| 474 |
+
"""
|
| 475 |
+
Convert argument to string.
|
| 476 |
+
|
| 477 |
+
Args:
|
| 478 |
+
param (str): parameter
|
| 479 |
+
arg (Union[str, int, float]): argument
|
| 480 |
+
|
| 481 |
+
Returns:
|
| 482 |
+
str: strings of argument
|
| 483 |
+
"""
|
| 484 |
+
if param == 'lr':
|
| 485 |
+
if arg is None:
|
| 486 |
+
str_arg = 'Default'
|
| 487 |
+
else:
|
| 488 |
+
str_arg = str(param)
|
| 489 |
+
return str_arg
|
| 490 |
+
elif param == 'gpu_ids':
|
| 491 |
+
if arg == []:
|
| 492 |
+
str_arg = 'CPU selected'
|
| 493 |
+
else:
|
| 494 |
+
str_arg = f"{arg} (Primary GPU:{arg[0]})"
|
| 495 |
+
return str_arg
|
| 496 |
+
elif param == 'test_splits':
|
| 497 |
+
str_arg = ', '.join(arg)
|
| 498 |
+
return str_arg
|
| 499 |
+
elif param == 'dataset_info':
|
| 500 |
+
str_arg = ''
|
| 501 |
+
for i, (split, total) in enumerate(arg.items()):
|
| 502 |
+
if i < len(arg) - 1:
|
| 503 |
+
str_arg += (f"{split}_data={total}, ")
|
| 504 |
+
else:
|
| 505 |
+
str_arg += (f"{split}_data={total}")
|
| 506 |
+
return str_arg
|
| 507 |
+
else:
|
| 508 |
+
if arg is None:
|
| 509 |
+
str_arg = 'No need'
|
| 510 |
+
else:
|
| 511 |
+
str_arg = str(arg)
|
| 512 |
+
return str_arg
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _check_if_valid_criterion(task: str = None, criterion: str = None) -> None:
|
| 516 |
+
"""
|
| 517 |
+
Check if criterion is valid.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
task (str): task
|
| 521 |
+
criterion (str): criterion
|
| 522 |
+
"""
|
| 523 |
+
valid_criterion = {
|
| 524 |
+
'classification': ['CEL'],
|
| 525 |
+
'regression': ['MSE', 'RMSE', 'MAE'],
|
| 526 |
+
'deepsurv': ['NLL']
|
| 527 |
+
}
|
| 528 |
+
if criterion in valid_criterion[task]:
|
| 529 |
+
pass
|
| 530 |
+
else:
|
| 531 |
+
raise ValueError(f"Invalid criterion for task: task={task}, criterion={criterion}.")
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
def _train_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
    """
    Parse and normalize the parameters required at training time.

    Args:
        args (argparse.Namespace): raw command-line arguments

    Returns:
        Dict[str, ParamSet]: parameters dispatched by group
    """
    # Fail fast on a task/criterion mismatch.
    _check_if_valid_criterion(task=args.task, criterion=args.criterion)

    args.project = Path(args.csvpath).stem
    args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
    if args.gpu_ids:
        args.device = torch.device(f"cuda:{args.gpu_ids[0]}")
    else:
        args.device = torch.device('cpu')
    args.mlp, args.net = _parse_model(args.model)
    args.pretrained = bool(args.pretrained)  # strtobool('False') = 0 (== False)
    args.save_datetime_dir = str(Path('results', args.project, 'trials', args.datetime))

    # Derive dataset-dependent settings from the csv.
    _parser = CSVParser(args.csvpath, args.task, args.isTrain)
    args.df_source = _parser.df_source
    args.dataset_info = {
        split: len(args.df_source[args.df_source['split'] == split])
        for split in ['train', 'val']
    }
    args.input_list = _parser.input_list
    args.label_list = _parser.label_list
    args.mlp_num_inputs = _parser.mlp_num_inputs
    args.num_outputs_for_label = _parser.num_outputs_for_label
    if args.task == 'deepsurv':
        args.period_name = _parser.period_name

    # Dispatch parameters into per-consumer groups.
    _groups = {
        'args_model': 'model',
        'args_dataloader': 'dataloader',
        'args_conf': 'train_conf',
        'args_print': 'train_print',
        'args_save': 'save',
    }
    return {key: _dispatch_by_group(args, group) for key, group in _groups.items()}
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _test_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
    """
    Parse and normalize the parameters required at test time.

    Args:
        args (argparse.Namespace): raw command-line arguments

    Returns:
        Dict[str, ParamSet]: parameters dispatched by group
    """
    args.project = Path(args.csvpath).stem
    args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
    if args.gpu_ids:
        args.device = torch.device(f"cuda:{args.gpu_ids[0]}")
    else:
        args.device = torch.device('cpu')

    # Resolve the directory holding trained weights, defaulting to the latest.
    if args.weight_dir is None:
        args.weight_dir = _get_latest_weight_dir()
    args.weight_paths = _collect_weight_paths(args.weight_dir)

    # Recover the training datetime from the weight directory layout.
    _train_dir = Path(args.weight_dir).parents[0]
    args.save_datetime_dir = str(Path('results', args.project, 'trials', _train_dir.name))

    # Restore onto args only the parameters saved at training time.
    _saved = _retrieve_parameter(str(Path(_train_dir, 'parameters.json')))
    for _name, _value in _saved.items():
        setattr(args, _name, _value)

    # These are always fixed at test time.
    args.augmentation = 'no'
    args.sampler = 'no'
    args.pretrained = False

    args.mlp, args.net = _parse_model(args.model)
    if args.mlp is not None:
        args.scaler_path = str(Path(_train_dir, 'scaler.pkl'))

    # Derive dataset-dependent settings from the csv.
    _parser = CSVParser(args.csvpath, args.task)
    args.df_source = _parser.df_source

    # Align test_splits with the splits actually present in the csv.
    args.test_splits = args.test_splits.split('-')
    _available = args.df_source['split'].unique().tolist()
    # NOTE(review): strict-subset check — when the csv holds fewer splits
    # than requested, fall back to whatever is available; confirm intended.
    if set(_available) < set(args.test_splits):
        args.test_splits = _available

    args.dataset_info = {
        split: len(args.df_source[args.df_source['split'] == split])
        for split in args.test_splits
    }

    # Dispatch parameters into per-consumer groups.
    _groups = {
        'args_model': 'model',
        'args_dataloader': 'dataloader',
        'args_conf': 'test_conf',
        'args_print': 'test_print',
    }
    return {key: _dispatch_by_group(args, group) for key, group in _groups.items()}
|
| 634 |
+
|
| 635 |
+
def set_options(datetime_name: str = None, phase: str = None) -> Dict[str, ParamSet]:
    """
    Parse options for training or test.

    Any phase other than 'train' (including None) is treated as test.

    Args:
        datetime_name (str, optional): datetime name. Defaults to None.
        phase (str, optional): 'train' or 'test'. Defaults to None.

    Returns:
        Dict[str, ParamSet]: parameters dispatched by group

    Note:
        Bug fix: the return annotation previously claimed
        argparse.Namespace, but both _train_parse and _test_parse
        return Dict[str, ParamSet].
    """
    if phase == 'train':
        opt = Options(datetime=datetime_name, isTrain=True)
        return _train_parse(opt.get_args())
    else:
        opt = Options(isTrain=False)
        return _test_parse(opt.get_args())
|
parameters.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"datetime": "2025-01-25-04-52-01",
|
| 3 |
+
"project": "CXp_checker_patientbase",
|
| 4 |
+
"csvpath": "materials/docs/CXp_checker_patientbase.csv",
|
| 5 |
+
"task": "classification",
|
| 6 |
+
"model": "EfficientNetB4",
|
| 7 |
+
"vit_image_size": 0,
|
| 8 |
+
"pretrained": true,
|
| 9 |
+
"criterion": "CEL",
|
| 10 |
+
"optimizer": "Adam",
|
| 11 |
+
"lr": null,
|
| 12 |
+
"epochs": 100,
|
| 13 |
+
"batch_size": 32,
|
| 14 |
+
"bit_depth": 8,
|
| 15 |
+
"in_channel": 1,
|
| 16 |
+
"augmentation": "trivialaugwide",
|
| 17 |
+
"normalize_image": "yes",
|
| 18 |
+
"sampler": "distributed",
|
| 19 |
+
"label_list": [
|
| 20 |
+
"label_round",
|
| 21 |
+
"label_APorPA"
|
| 22 |
+
],
|
| 23 |
+
"input_list": [],
|
| 24 |
+
"mlp_num_inputs": 0,
|
| 25 |
+
"num_outputs_for_label": {
|
| 26 |
+
"label_round": 4,
|
| 27 |
+
"label_APorPA": 3
|
| 28 |
+
},
|
| 29 |
+
"save_weight_policy": "best",
|
| 30 |
+
"gpu_ids": [
|
| 31 |
+
0,
|
| 32 |
+
1,
|
| 33 |
+
2,
|
| 34 |
+
3
|
| 35 |
+
]
|
| 36 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
gradio
|
| 4 |
+
matplotlib
|
| 5 |
+
scikit-learn
|
sample/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
sample/sample_AP_inverted.png
ADDED
|
sample/sample_PA_right.png
ADDED
|
sample/sample_lateral_upright.png
ADDED
|
weight_epoch-011_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:edf393c0e7264242cfde729bbf8ccdc899933efa42330d253af94a10ca4b1bfb
|
| 3 |
+
size 71016653
|