Update app.py
Browse files
app.py
CHANGED
|
@@ -13,16 +13,18 @@ from lib.dataloader import ImageMixin
|
|
| 13 |
# ===========================================
|
| 14 |
# 1) パスなど(修正があれば適宜変更)
|
| 15 |
# ===========================================
|
| 16 |
-
test_weight = './weight_epoch-
|
| 17 |
parameter = './parameters.json'
|
| 18 |
|
| 19 |
# ===========================================
|
| 20 |
# 2) クラスラベルの定義
|
| 21 |
-
# -
|
| 22 |
# ===========================================
|
| 23 |
LABEL_APorPA = [
|
| 24 |
-
"AP",
|
| 25 |
-
"PA"
|
|
|
|
|
|
|
| 26 |
]
|
| 27 |
|
| 28 |
LABEL_ROUND = [
|
|
@@ -39,15 +41,12 @@ LABEL_ROUND = [
|
|
| 39 |
class ImageHandler(ImageMixin):
|
| 40 |
def __init__(self, params):
|
| 41 |
self.params = params
|
| 42 |
-
# ここでリサイズは省略(推論が重いので)
|
| 43 |
-
# 入力画像は既に256×256であることを想定
|
| 44 |
self.transform = T.Compose([
|
| 45 |
-
# T.Resize((256, 256)), #
|
| 46 |
-
T.ToTensor(),
|
| 47 |
])
|
| 48 |
|
| 49 |
def set_image(self, image):
|
| 50 |
-
# PIL画像 -> transform -> バッチ次元を付ける
|
| 51 |
image = self.transform(image)
|
| 52 |
image = {'image': image.unsqueeze(0)}
|
| 53 |
return image
|
|
@@ -89,8 +88,8 @@ model.eval() # 推論モード
|
|
| 89 |
def classify_APorPA_and_round(image):
|
| 90 |
"""
|
| 91 |
モデルが以下を出力する想定:
|
| 92 |
-
outputs["label_APorPA"] -> shape=[1,
|
| 93 |
-
outputs["label_round"] -> shape=[1, 4] (4
|
| 94 |
"""
|
| 95 |
image_handler = ImageHandler(args_dataloader)
|
| 96 |
image_tensor = image_handler.set_image(image)
|
|
@@ -98,13 +97,22 @@ def classify_APorPA_and_round(image):
|
|
| 98 |
with torch.no_grad():
|
| 99 |
outputs = model(image_tensor)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# --- label_APorPA ---
|
| 102 |
if "label_APorPA" not in outputs:
|
| 103 |
print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 104 |
return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
|
| 105 |
|
| 106 |
-
scores_APorPA = outputs["label_APorPA"] # shape=[1,
|
| 107 |
pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
|
|
|
|
|
|
|
| 108 |
predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
|
| 109 |
|
| 110 |
# --- label_round ---
|
|
@@ -125,7 +133,7 @@ html_content = """
|
|
| 125 |
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
|
| 126 |
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
|
| 127 |
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
|
| 128 |
-
<p>胸部レントゲン画像に対して、撮像方向(
|
| 129 |
</div>
|
| 130 |
"""
|
| 131 |
|
|
@@ -135,7 +143,7 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
|
|
| 135 |
|
| 136 |
with gr.Row():
|
| 137 |
input_image = gr.Image(type="pil", image_mode="L")
|
| 138 |
-
output_APorPA = gr.Label(label="Predicted AP
|
| 139 |
output_round = gr.Label(label="Predicted Rotation")
|
| 140 |
|
| 141 |
send_btn = gr.Button("Inference")
|
|
@@ -149,9 +157,9 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
|
|
| 149 |
# サンプルファイルは実際のパスに置き換えてください
|
| 150 |
gr.Examples(
|
| 151 |
examples=[
|
| 152 |
-
'./
|
| 153 |
-
'./
|
| 154 |
-
'./
|
| 155 |
],
|
| 156 |
inputs=input_image
|
| 157 |
)
|
|
|
|
| 13 |
# ===========================================
# 1) Paths (adjust as needed)
# ===========================================
test_weight = './weight_epoch-003_best.pt'
parameter = './parameters.json'

# ===========================================
# 2) Class-label definitions
# - view direction (3 classes); rotation labels are defined below
# ===========================================
LABEL_APorPA = [
    "AP",       # class 0
    "PA",       # class 1
    "Lateral",  # class 2
    # NOTE(review): adjust to the actual class names used at training time
]
|
| 29 |
|
| 30 |
LABEL_ROUND = [
|
|
|
|
| 41 |
class ImageHandler(ImageMixin):
    """Preprocess a single PIL image into the model's expected input dict."""

    def __init__(self, params):
        self.params = params
        # Input images are assumed to already be 256x256, so no resize here.
        self.transform = T.Compose([
            # T.Resize((256, 256)),  # uncomment if resizing becomes necessary
            T.ToTensor(),
        ])

    def set_image(self, image):
        """Apply the transform and prepend a batch dimension.

        Returns a dict of the form ``{'image': tensor}`` with shape [1, C, H, W].
        """
        tensor = self.transform(image)
        return {'image': tensor.unsqueeze(0)}
|
|
|
|
| 88 |
def classify_APorPA_and_round(image):
|
| 89 |
"""
|
| 90 |
モデルが以下を出力する想定:
|
| 91 |
+
outputs["label_APorPA"] -> shape=[1, 3] (3クラス)
|
| 92 |
+
outputs["label_round"] -> shape=[1, 4] (4クラス)
|
| 93 |
"""
|
| 94 |
image_handler = ImageHandler(args_dataloader)
|
| 95 |
image_tensor = image_handler.set_image(image)
|
|
|
|
| 97 |
with torch.no_grad():
|
| 98 |
outputs = model(image_tensor)
|
| 99 |
|
| 100 |
+
# デバッグ用の出力チェック
|
| 101 |
+
print("keys in outputs =", outputs.keys())
|
| 102 |
+
if "label_APorPA" in outputs:
|
| 103 |
+
print("label_APorPA shape =", outputs["label_APorPA"].shape)
|
| 104 |
+
if "label_round" in outputs:
|
| 105 |
+
print("label_round shape =", outputs["label_round"].shape)
|
| 106 |
+
|
| 107 |
# --- label_APorPA ---
|
| 108 |
if "label_APorPA" not in outputs:
|
| 109 |
print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 110 |
return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
|
| 111 |
|
| 112 |
+
scores_APorPA = outputs["label_APorPA"] # shape=[1,3]想定
|
| 113 |
pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
|
| 114 |
+
|
| 115 |
+
# IndexError が発生する場合は、pred_APorPA_idx が 0~2 の範囲外になる
|
| 116 |
predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
|
| 117 |
|
| 118 |
# --- label_round ---
|
|
|
|
| 133 |
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
|
| 134 |
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
|
| 135 |
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
|
| 136 |
+
<p>胸部レントゲン画像に対して、撮像方向(3クラス)と回転方向(4クラス)を同時に推定します。</p>
|
| 137 |
</div>
|
| 138 |
"""
|
| 139 |
|
|
|
|
| 143 |
|
| 144 |
with gr.Row():
|
| 145 |
input_image = gr.Image(type="pil", image_mode="L")
|
| 146 |
+
output_APorPA = gr.Label(label="Predicted AP/PA/Lateral") # 3クラス想定に合わせた名前
|
| 147 |
output_round = gr.Label(label="Predicted Rotation")
|
| 148 |
|
| 149 |
send_btn = gr.Button("Inference")
|
|
|
|
| 157 |
# サンプルファイルは実際のパスに置き換えてください
|
| 158 |
gr.Examples(
|
| 159 |
examples=[
|
| 160 |
+
'./samples/sample_chest_AP_0deg.png',
|
| 161 |
+
'./samples/sample_chest_PA_90deg.png',
|
| 162 |
+
'./samples/sample_chest_LAT_180deg.png'
|
| 163 |
],
|
| 164 |
inputs=input_image
|
| 165 |
)
|