MedicalAILabo commited on
Commit
ad18100
·
verified ·
1 Parent(s): 80c8afb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -22
app.py CHANGED
@@ -18,31 +18,31 @@ parameter = './parameters.json'
18
 
19
  # ===========================================
20
  # 2) クラスラベルの定義
21
- # - 今回は撮像方向 (3クラス) と回転方向 (4クラス)
 
22
  # ===========================================
23
  LABEL_APorPA = [
24
- "AP", # class 0
25
- "PA", # class 1
26
- "Lateral", # class 2
27
- # ↑ 実際のクラス名に合わせて変更してください
28
  ]
29
 
30
  LABEL_ROUND = [
31
- "0° Rotation", # class 0
32
- "90° Rotation", # class 1
33
- "180° Rotation", # class 2
34
- "270° Rotation" # class 3
35
  ]
36
 
37
  # ===========================================
38
  # 3) 前処理用の ImageHandlerクラス
39
- # - 画像が既に256×256前提。Resizeはコメントアウトで残す
40
  # ===========================================
41
  class ImageHandler(ImageMixin):
42
  def __init__(self, params):
43
  self.params = params
44
  self.transform = T.Compose([
45
- # T.Resize((256, 256)), # 必要ならコメントアウトを外す
46
  T.ToTensor(),
47
  ])
48
 
@@ -88,8 +88,8 @@ model.eval() # 推論モード
88
  def classify_APorPA_and_round(image):
89
  """
90
  モデルが以下を出力する想定:
91
- outputs["label_APorPA"] -> shape=[1, 3] (3クラス)
92
- outputs["label_round"] -> shape=[1, 4] (4クラス)
93
  """
94
  image_handler = ImageHandler(args_dataloader)
95
  image_tensor = image_handler.set_image(image)
@@ -109,10 +109,8 @@ def classify_APorPA_and_round(image):
109
  print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
110
  return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
111
 
112
- scores_APorPA = outputs["label_APorPA"] # shape=[1,3]想定
113
  pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
114
-
115
- # IndexError が発生する場合は、pred_APorPA_idx が 0~2 の範囲外になる
116
  predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
117
 
118
  # --- label_round ---
@@ -120,7 +118,7 @@ def classify_APorPA_and_round(image):
120
  print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
121
  return predicted_APorPA, "ERROR: Missing 'label_round'"
122
 
123
- scores_round = outputs["label_round"] # shape=[1,4]想定
124
  pred_round_idx = torch.argmax(scores_round, dim=1).item()
125
  predicted_round = LABEL_ROUND[pred_round_idx]
126
 
@@ -133,7 +131,8 @@ html_content = """
133
  <div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
134
  <h3>Chest X-ray: AP/PA & Rotation Classification</h3>
135
  <p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
136
- <p>胸部レントゲン画像に対して、撮像方向(3クラス)と回転方向(4クラス)を同時に推定します。</p>
 
137
  </div>
138
  """
139
 
@@ -143,7 +142,7 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
143
 
144
  with gr.Row():
145
  input_image = gr.Image(type="pil", image_mode="L")
146
- output_APorPA = gr.Label(label="Predicted AP/PA/Lateral") # 3クラス想定に合わせた名前
147
  output_round = gr.Label(label="Predicted Rotation")
148
 
149
  send_btn = gr.Button("Inference")
@@ -157,9 +156,9 @@ with gr.Blocks(title="Chest X-ray: AP/PA & Rotation Classification") as demo:
157
  # サンプルファイルは実際のパスに置き換えてください
158
  gr.Examples(
159
  examples=[
160
- './sample/sample_AP_inverted.png',
161
- './sample/sample_PA_right.png',
162
- './sample/sample_lateral_upright.png'
163
  ],
164
  inputs=input_image
165
  )
 
18
 
19
  # ===========================================
20
  # 2) クラスラベルの定義
21
+ # - label_APorPA (3クラス): 0=AP, 1=PA, 2=Lateral
22
+ # - label_round (4クラス): 0=Upright, 1=Inverted, 2=Left rotation, 3=Right rotation
23
  # ===========================================
24
  LABEL_APorPA = [
25
+ "AP", # class 0
26
+ "PA", # class 1
27
+ "Lateral", # class 2
 
28
  ]
29
 
30
  LABEL_ROUND = [
31
+ "Upright", # class 0
32
+ "Inverted", # class 1
33
+ "Left rotation", # class 2
34
+ "Right rotation" # class 3
35
  ]
36
 
37
  # ===========================================
38
  # 3) 前処理用の ImageHandlerクラス
39
+ # - 画像が既に256×256前提
40
  # ===========================================
41
  class ImageHandler(ImageMixin):
42
  def __init__(self, params):
43
  self.params = params
44
  self.transform = T.Compose([
45
+ # T.Resize((256, 256)), # 必要であればコメントアウトを外す
46
  T.ToTensor(),
47
  ])
48
 
 
88
  def classify_APorPA_and_round(image):
89
  """
90
  モデルが以下を出力する想定:
91
+ outputs["label_APorPA"] -> shape=[1, 3] (3クラス: AP, PA, Lateral)
92
+ outputs["label_round"] -> shape=[1, 4] (4クラス: Upright, Inverted, Left rotation, Right rotation)
93
  """
94
  image_handler = ImageHandler(args_dataloader)
95
  image_tensor = image_handler.set_image(image)
 
109
  print(f"[ERROR] 'label_APorPA' not found in outputs. Actual keys: {list(outputs.keys())}")
110
  return "ERROR: Missing 'label_APorPA'", "ERROR: Missing 'label_round'"
111
 
112
+ scores_APorPA = outputs["label_APorPA"] # shape=[1,3]
113
  pred_APorPA_idx = torch.argmax(scores_APorPA, dim=1).item()
 
 
114
  predicted_APorPA = LABEL_APorPA[pred_APorPA_idx]
115
 
116
  # --- label_round ---
 
118
  print(f"[ERROR] 'label_round' not found in outputs. Actual keys: {list(outputs.keys())}")
119
  return predicted_APorPA, "ERROR: Missing 'label_round'"
120
 
121
+ scores_round = outputs["label_round"] # shape=[1,4]
122
  pred_round_idx = torch.argmax(scores_round, dim=1).item()
123
  predicted_round = LABEL_ROUND[pred_round_idx]
124
 
 
131
  <div style="padding: 15px; border: 1px solid #e0e0e0; border-radius: 5px;">
132
  <h3>Chest X-ray: AP/PA & Rotation Classification</h3>
133
  <p>入力画像は既に256×256(グレースケール)であることを想定し、内部でのリサイズは行いません。</p>
134
+ <p>胸部レントゲン画像に対して、撮像方向(3クラス: AP, PA, Lateral)
135
+ 回転方向(4クラス: Upright, Inverted, Left rotation, Right rotation)を同時に推定します。</p>
136
  </div>
137
  """
138
 
 
142
 
143
  with gr.Row():
144
  input_image = gr.Image(type="pil", image_mode="L")
145
+ output_APorPA = gr.Label(label="Predicted AP/PA/Lateral")
146
  output_round = gr.Label(label="Predicted Rotation")
147
 
148
  send_btn = gr.Button("Inference")
 
156
  # サンプルファイルは実際のパスに置き換えてください
157
  gr.Examples(
158
  examples=[
159
+ './sample/sample_AP_upright.png',
160
+ './sample/sample_PA_inverted.png',
161
+ './sample/sample_lateral_left.png'
162
  ],
163
  inputs=input_image
164
  )