Update app.py
Browse files
app.py
CHANGED
|
@@ -18,31 +18,31 @@ parameter = './parameters.json'
|
|
| 18 |
|
| 19 |
# ===========================================
|
| 20 |
# 2) クラスラベルの定義
|
| 21 |
-
# -
|
|
|
|
| 22 |
# ===========================================
|
| 23 |
LABEL_APorPA = [
|
| 24 |
-
"AP",
|
| 25 |
-
"PA",
|
| 26 |
-
"Lateral",
|
| 27 |
-
# ↑ 実際のクラス名に合わせて変更してください
|
| 28 |
]
|
| 29 |
|
| 30 |
LABEL_ROUND = [
|
| 31 |
-
"
|
| 32 |
-
"
|
| 33 |
-
"
|
| 34 |
-
"
|
| 35 |
]
|
| 36 |
|
| 37 |
# ===========================================
|
| 38 |
# 3) 前処理用の ImageHandlerクラス
|
| 39 |
-
# - 画像が既に256×256
|
| 40 |
# ===========================================
|
| 41 |
class ImageHandler(ImageMixin):
|
| 42 |
def __init__(self, params):
|
| 43 |
self.params = params
|
| 44 |
self.transform = T.Compose([
|
| 45 |
-
# T.Resize((256, 256)), #
|
| 46 |
T.ToTensor(),
|
| 47 |
])
|
| 48 |
|
|
@@ -88,8 +88,8 @@ model.eval() # 推論モード
|
|
| 88 |
def classify_APorPA_and_round(image):
|
| 89 |
"""
|
| 90 |
モデルが以下を出力する想定:
|
| 91 |
-
outputs["label_APorPA"] -> shape=[1, 3] (3
|
| 92 |
-
outputs["label_round"] -> shape=[1, 4] (4
|
| 93 |
"""
|
| 94 |
image_handler = ImageHandler(args_dataloader)
|
| 95 |
image_tensor = image_handler.set_image(image)
|
|
@@ -109,10 +109,8 @@ def classify_APorPA_and_round(image):
|
|
| 109 |
print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 110 |
return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
|
| 111 |
|
| 112 |
-
scores_APorPA = outputs["label_APorPA"] # shape=[1,3]
|
| 113 |
pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
|
| 114 |
-
|
| 115 |
-
# IndexError が発生する場合は、pred_APorPA_idx が 0~2 の範囲外になる
|
| 116 |
predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
|
| 117 |
|
| 118 |
# --- label_round ---
|
|
@@ -120,7 +118,7 @@ def classify_APorPA_and_round(image):
|
|
| 120 |
print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 121 |
return predicted_APorPA, "ERROR: Missing 'label_round'"
|
| 122 |
|
| 123 |
-
scores_round = outputs["label_round"] # shape=[1,4]
|
| 124 |
pred_round_idx = torch.argmax(scores_round, dim=1).item()
|
| 125 |
predicted_round = LABEL_ROUND[pred_round_idx]
|
| 126 |
|
|
@@ -133,7 +131,8 @@ html_content = """
|
|
| 133 |
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
|
| 134 |
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
|
| 135 |
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
|
| 136 |
-
<p>胸部レントゲン画像に対して、撮像方向(3
|
|
|
|
| 137 |
</div>
|
| 138 |
"""
|
| 139 |
|
|
@@ -143,7 +142,7 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
|
|
| 143 |
|
| 144 |
with gr.Row():
|
| 145 |
input_image = gr.Image(type="pil", image_mode="L")
|
| 146 |
-
output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
|
| 147 |
output_round = gr.Label(label="Predicted Rotation")
|
| 148 |
|
| 149 |
send_btn = gr.Button("Inference")
|
|
@@ -157,9 +156,9 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
|
|
| 157 |
# サンプルファイルは実際のパスに置き換えてください
|
| 158 |
gr.Examples(
|
| 159 |
examples=[
|
| 160 |
-
'./sample/
|
| 161 |
-
'./sample/
|
| 162 |
-
'./sample/
|
| 163 |
],
|
| 164 |
inputs=input_image
|
| 165 |
)
|
|
|
|
| 18 |
|
| 19 |
# ===========================================
|
| 20 |
# 2) クラスラベルの定義
|
| 21 |
+
# - label_APorPA (3クラス): 0=AP, 1=PA, 2=Lateral
|
| 22 |
+
# - label_round (4クラス): 0=Upright, 1=Inverted, 2=Left rotation, 3=Right rotation
|
| 23 |
# ===========================================
|
| 24 |
LABEL_APorPA = [
|
| 25 |
+
"AP", # class 0
|
| 26 |
+
"PA", # class 1
|
| 27 |
+
"Lateral", # class 2
|
|
|
|
| 28 |
]
|
| 29 |
|
| 30 |
LABEL_ROUND = [
|
| 31 |
+
"Upright", # class 0
|
| 32 |
+
"Inverted", # class 1
|
| 33 |
+
"Left rotation", # class 2
|
| 34 |
+
"Right rotation" # class 3
|
| 35 |
]
|
| 36 |
|
| 37 |
# ===========================================
|
| 38 |
# 3) 前処理用の ImageHandlerクラス
|
| 39 |
+
# - 画像が既に256×256前提
|
| 40 |
# ===========================================
|
| 41 |
class ImageHandler(ImageMixin):
|
| 42 |
def __init__(self, params):
|
| 43 |
self.params = params
|
| 44 |
self.transform = T.Compose([
|
| 45 |
+
# T.Resize((256, 256)), # 必要であればコメントアウトを外す
|
| 46 |
T.ToTensor(),
|
| 47 |
])
|
| 48 |
|
|
|
|
| 88 |
def classify_APorPA_and_round(image):
|
| 89 |
"""
|
| 90 |
モデルが以下を出力する想定:
|
| 91 |
+
outputs["label_APorPA"] -> shape=[1, 3] (3クラス: AP, PA, Lateral)
|
| 92 |
+
outputs["label_round"] -> shape=[1, 4] (4クラス: Upright, Inverted, Left rotation, Right rotation)
|
| 93 |
"""
|
| 94 |
image_handler = ImageHandler(args_dataloader)
|
| 95 |
image_tensor = image_handler.set_image(image)
|
|
|
|
| 109 |
print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 110 |
return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
|
| 111 |
|
| 112 |
+
scores_APorPA = outputs["label_APorPA"] # shape=[1,3]
|
| 113 |
pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
|
|
|
|
|
|
|
| 114 |
predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
|
| 115 |
|
| 116 |
# --- label_round ---
|
|
|
|
| 118 |
print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
|
| 119 |
return predicted_APorPA, "ERROR: Missing 'label_round'"
|
| 120 |
|
| 121 |
+
scores_round = outputs["label_round"] # shape=[1,4]
|
| 122 |
pred_round_idx = torch.argmax(scores_round, dim=1).item()
|
| 123 |
predicted_round = LABEL_ROUND[pred_round_idx]
|
| 124 |
|
|
|
|
| 131 |
<div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
|
| 132 |
<h3>Chest X-ray: AP/PA & Rotation Classification</h3>
|
| 133 |
<p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
|
| 134 |
+
<p>胸部レントゲン画像に対して、撮像方向(3クラス: AP, PA, Lateral) と
|
| 135 |
+
回転方向(4クラス: Upright, Inverted, Left rotation, Right rotation)を同時に推定します。</p>
|
| 136 |
</div>
|
| 137 |
"""
|
| 138 |
|
|
|
|
| 142 |
|
| 143 |
with gr.Row():
|
| 144 |
input_image = gr.Image(type="pil", image_mode="L")
|
| 145 |
+
output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
|
| 146 |
output_round = gr.Label(label="Predicted Rotation")
|
| 147 |
|
| 148 |
send_btn = gr.Button("Inference")
|
|
|
|
| 156 |
# サンプルファイルは実際のパスに置き換えてください
|
| 157 |
gr.Examples(
|
| 158 |
examples=[
|
| 159 |
+
'./sample/sample_AP_upright.png',
|
| 160 |
+
'./sample/sample_PA_inverted.png',
|
| 161 |
+
'./sample/sample_lateral_left.png'
|
| 162 |
],
|
| 163 |
inputs=input_image
|
| 164 |
)
|