MedicalAILabo commited on
Commit
5113ba3
·
verified ·
1 Parent(s): 4bb0fa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -13,16 +13,18 @@ from lib.dataloader import ImageMixin
13
  # ===========================================
14
  # 1) パスなど(修正があれば適宜変更)
15
  # ===========================================
16
- test_weight = './weight_epoch-011_best.pt'
17
  parameter = './parameters.json'
18
 
19
  # ===========================================
20
  # 2) クラスラベルの定義
21
- # - 撮像方向(label_APorPA)と回転方向(label_round)のクラス
22
  # ===========================================
23
  LABEL_APorPA = [
24
- "AP", # class 0
25
- "PA" # class 1
 
 
26
  ]
27
 
28
  LABEL_ROUND = [
@@ -39,15 +41,12 @@ LABEL_ROUND = [
39
  class ImageHandler(ImageMixin):
40
  def __init__(self, params):
41
  self.params = params
42
- # ここでリサイズは省略(推論が重いので)
43
- # 入力画像は既に256×256であることを想定
44
  self.transform = T.Compose([
45
- # T.Resize((256, 256)), # コメントアウト: 画像を256×256にリサイズ
46
- T.ToTensor(), # Tensor化 (0~1, shape: C,H,W)
47
  ])
48
 
49
  def set_image(self, image):
50
- # PIL画像 -> transform -> バッチ次元を付ける
51
  image = self.transform(image)
52
  image = {'image': image.unsqueeze(0)}
53
  return image
@@ -89,8 +88,8 @@ model.eval() # 推論モード
89
  def classify_APorPA_and_round(image):
90
  """
91
  モデルが以下を出力する想定:
92
- outputs["label_APorPA"] -> shape=[1, 2] (2クラス: AP/PA)
93
- outputs["label_round"] -> shape=[1, 4] (4クラス: 0°, 90°, 180°, 270°)
94
  """
95
  image_handler = ImageHandler(args_dataloader)
96
  image_tensor = image_handler.set_image(image)
@@ -98,13 +97,22 @@ def classify_APorPA_and_round(image):
98
  with torch.no_grad():
99
  outputs = model(image_tensor)
100
 
 
 
 
 
 
 
 
101
  # --- label_APorPA ---
102
  if "label_APorPA" not in outputs:
103
  print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
104
  return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
105
 
106
- scores_APorPA = outputs["label_APorPA"] # shape=[1,2]想定
107
  pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
 
 
108
  predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
109
 
110
  # --- label_round ---
@@ -125,7 +133,7 @@ html_content = """
125
  <div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
126
  <h3>Chest X-ray: AP/PA & Rotation Classification</h3>
127
  <p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
128
- <p>胸部レントゲン画像に対して、撮像方向(AP or PA)と回転方向(0°, 90°, 180°, 270°)を同時に推定します。</p>
129
  </div>
130
  """
131
 
@@ -135,7 +143,7 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
135
 
136
  with gr.Row():
137
  input_image = gr.Image(type="pil", image_mode="L")
138
- output_APorPA = gr.Label(label="Predicted AP or PA")
139
  output_round = gr.Label(label="Predicted Rotation")
140
 
141
  send_btn = gr.Button("Inference")
@@ -149,9 +157,9 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
149
  # サンプルファイルは実際のパスに置き換えてください
150
  gr.Examples(
151
  examples=[
152
- './sample/sample_AP_inverted.png',
153
- './sample/sample_PA_right.png',
154
- './sample/sample_lateral_upright.png'
155
  ],
156
  inputs=input_image
157
  )
 
13
  # ===========================================
14
  # 1) パスなど(修正があれば適宜変更)
15
  # ===========================================
16
+ test_weight = './weight_epoch-003_best.pt'
17
  parameter = './parameters.json'
18
 
19
  # ===========================================
20
  # 2) クラスラベルの定義
21
+ # - 今回は撮像方向 (3クラス) と回転方向 (4クラス)
22
  # ===========================================
23
  LABEL_APorPA = [
24
+ "AP", # class 0
25
+ "PA", # class 1
26
+ "Lateral", # class 2
27
+ # ↑ 実際のクラス名に合わせて変更してください
28
  ]
29
 
30
  LABEL_ROUND = [
 
41
  class ImageHandler(ImageMixin):
42
  def __init__(self, params):
43
  self.params = params
 
 
44
  self.transform = T.Compose([
45
+ # T.Resize((256, 256)), # 必要ならコメントアウトを外す
46
+ T.ToTensor(),
47
  ])
48
 
49
  def set_image(self, image):
 
50
  image = self.transform(image)
51
  image = {'image': image.unsqueeze(0)}
52
  return image
 
88
  def classify_APorPA_and_round(image):
89
  """
90
  モデルが以下を出力する想定:
91
+ outputs["label_APorPA"] -> shape=[1, 3] (3クラス)
92
+ outputs["label_round"] -> shape=[1, 4] (4クラス)
93
  """
94
  image_handler = ImageHandler(args_dataloader)
95
  image_tensor = image_handler.set_image(image)
 
97
  with torch.no_grad():
98
  outputs = model(image_tensor)
99
 
100
+ # デバッグ用の出力チェック
101
+ print("keys in outputs =", outputs.keys())
102
+ if "label_APorPA" in outputs:
103
+ print("label_APorPA shape =", outputs["label_APorPA"].shape)
104
+ if "label_round" in outputs:
105
+ print("label_round shape =", outputs["label_round"].shape)
106
+
107
  # --- label_APorPA ---
108
  if "label_APorPA" not in outputs:
109
  print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
110
  return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
111
 
112
+ scores_APorPA = outputs["label_APorPA"] # shape=[1,3]想定
113
  pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
114
+
115
+ # IndexError が発生する場合は、pred_APorPA_idx が 0~2 の範囲外になる
116
  predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
117
 
118
  # --- label_round ---
 
133
  <div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
134
  <h3>Chest X-ray: AP/PA & Rotation Classification</h3>
135
  <p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
136
+ <p>胸部レントゲン画像に対して、撮像方向(3クラス)と回転方向(4クラス)を同時に推定します。</p>
137
  </div>
138
  """
139
 
 
143
 
144
  with gr.Row():
145
  input_image = gr.Image(type="pil", image_mode="L")
146
+ output_APorPA = gr.Label(label="Predicted AP/PA/Lateral") # 3クラス想定に合わせた名前
147
  output_round = gr.Label(label="Predicted Rotation")
148
 
149
  send_btn = gr.Button("Inference")
 
157
  # サンプルファイルは実際のパスに置き換えてください
158
  gr.Examples(
159
  examples=[
160
+ './samples/sample_chest_AP_0deg.png',
161
+ './samples/sample_chest_PA_90deg.png',
162
+ './samples/sample_chest_LAT_180deg.png'
163
  ],
164
  inputs=input_image
165
  )