MedicalAILabo commited on
Commit
58487cb
·
verified ·
1 Parent(s): afba24e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -44
app.py CHANGED
@@ -1,48 +1,43 @@
1
  import os
2
  import numpy as np
3
- import onnxruntime as ort
 
4
  import gradio as gr
5
  from PIL import Image
6
 
7
- # モデルをONNX Runtimeで読み込む
8
- session = ort.InferenceSession("cxp_projection_rotation.pt", None)
9
- input_name = session.get_inputs()[0].name
10
- # 出力ヘッド名(順序に注意)
11
- projection_name = session.get_outputs()[0].name
12
- rotation_name = session.get_outputs()[1].name
13
 
14
- # クラスラベルの定義
15
  PROJ_LABELS = ["AP", "PA", "Lateral"]
16
  ROT_LABELS = ["Upright", "Inverted", "Left90", "Right90"]
17
 
18
- def softmax(x: np.ndarray) -> np.ndarray:
19
- """数値安定版ソフトマックス"""
20
- e = np.exp(x - np.max(x))
21
- return e / np.sum(e)
22
-
23
  def predict(image_path: str) -> str:
24
  # 画像をグレースケールで読み込み(Lモード)
25
  img = Image.open(image_path).convert("L")
26
- arr = np.array(img, dtype=np.float32) / 255.0 # 正規化
27
- arr = arr[np.newaxis, np.newaxis, :, :] # [1,1,H,W] バッチ&チャネル次元追加
 
28
 
29
  # 推論
30
- proj_logits, rot_logits = session.run(
31
- [projection_name, rotation_name],
32
- {input_name: arr}
33
- )
34
- proj_probs = softmax(proj_logits[0])
35
- rot_probs = softmax(rot_logits[0])
36
 
37
- # 予測結果と確率
38
- proj_idx = int(np.argmax(proj_probs))
39
- rot_idx = int(np.argmax(rot_probs))
40
- proj_lbl = PROJ_LABELS[proj_idx]
41
- rot_lbl = ROT_LABELS[rot_idx]
42
- proj_p = proj_probs[proj_idx]
43
- rot_p = rot_probs[rot_idx]
44
 
45
- # ファイル名から元ラベルを推定(例: "AP_Upright.png" → ["AP","Upright"])
 
 
 
 
 
 
 
 
46
  base = os.path.splitext(os.path.basename(image_path))[0]
47
  try:
48
  orig_proj, orig_rot = base.split("_", 1)
@@ -56,7 +51,7 @@ def predict(image_path: str) -> str:
56
  if orig_rot and orig_rot != rot_lbl:
57
  warnings_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
58
 
59
- # HTML形式で結果を返す
60
  html = (
61
  f"<p><strong>Projection :</strong> {proj_lbl} (p={proj_p:.3f})</p>"
62
  f"<p><strong>Rotation :</strong> {rot_lbl} (p={rot_p:.3f})</p>"
@@ -64,35 +59,33 @@ def predict(image_path: str) -> str:
64
  )
65
  return html
66
 
67
- # Gradio UI構築
68
  with gr.Blocks() as demo:
69
  with gr.Row():
70
  with gr.Column():
71
- # 画像アップロード(PNG/L, 256×256前提)
72
  image_input = gr.Image(
73
  label="Upload PNG (256×256)",
74
  type="filepath",
75
  tool=None
76
  )
77
- # sample_imagesフォルダから4枚まで例示
78
- sample_list = sorted(
79
- [os.path.join("sample_images", f)
80
- for f in os.listdir("sample_images")
81
- if f.lower().endswith(".png")]
82
- )
83
  gr.Examples(
84
- examples=sample_list,
85
- inputs=image_input,
86
  label="Sample Images"
87
  )
88
  with gr.Column():
89
- # 推論結果用HTML表示エリア
90
  result = gr.HTML()
91
 
92
- # 画像選択・アップロード時に自動でpredictを実行
93
  image_input.change(
94
- fn=predict,
95
- inputs=image_input,
96
  outputs=result
97
  )
98
 
 
1
  import os
2
  import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
  import gradio as gr
6
  from PIL import Image
7
 
8
+ # PyTorch モデルをロード
9
+ # map_location="cpu" で CPU 上にロード
10
+ model = torch.load("cxp_projection_rotation.pt", map_location="cpu")
11
+ model.eval() # 評価モードに切り替え
 
 
12
 
13
+ # クラスラベル定義
14
  PROJ_LABELS = ["AP", "PA", "Lateral"]
15
  ROT_LABELS = ["Upright", "Inverted", "Left90", "Right90"]
16
 
 
 
 
 
 
17
  def predict(image_path: str) -> str:
18
  # 画像をグレースケールで読み込み(Lモード)
19
  img = Image.open(image_path).convert("L")
20
+ arr = np.array(img, dtype=np.float32) / 255.0 # 0-1 正規化
21
+ # バッチ&チャンネル次元を追加して Tensor
22
+ tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0) # [1,1,256,256]
23
 
24
  # 推論
25
+ with torch.no_grad():
26
+ proj_logits, rot_logits = model(tensor) # 2 ヘッド出力
 
 
 
 
27
 
28
+ # ソフトマックスで確率化
29
+ proj_probs = F.softmax(proj_logits, dim=1)[0].cpu().numpy()
30
+ rot_probs = F.softmax(rot_logits, dim=1)[0].cpu().numpy()
 
 
 
 
31
 
32
+ # 最も確率の高いラベルと確率
33
+ proj_idx = int(np.argmax(proj_probs))
34
+ rot_idx = int(np.argmax(rot_probs))
35
+ proj_lbl = PROJ_LABELS[proj_idx]
36
+ rot_lbl = ROT_LABELS[rot_idx]
37
+ proj_p = proj_probs[proj_idx]
38
+ rot_p = rot_probs[rot_idx]
39
+
40
+ # ファイル名から元ラベルを取得(例: "AP_Upright.png" → ["AP","Upright"])
41
  base = os.path.splitext(os.path.basename(image_path))[0]
42
  try:
43
  orig_proj, orig_rot = base.split("_", 1)
 
51
  if orig_rot and orig_rot != rot_lbl:
52
  warnings_html += "<p style='color:red'>⚠ Potential mislabeled rotation</p>"
53
 
54
+ # 結果を HTML 形式で返す
55
  html = (
56
  f"<p><strong>Projection :</strong> {proj_lbl} (p={proj_p:.3f})</p>"
57
  f"<p><strong>Rotation :</strong> {rot_lbl} (p={rot_p:.3f})</p>"
 
59
  )
60
  return html
61
 
62
+ # Gradio UI 定義
63
  with gr.Blocks() as demo:
64
  with gr.Row():
65
  with gr.Column():
 
66
  image_input = gr.Image(
67
  label="Upload PNG (256×256)",
68
  type="filepath",
69
  tool=None
70
  )
71
+ # sample_images フォルダからサンプル画像を読み込む
72
+ sample_list = sorted([
73
+ os.path.join("sample_images", f)
74
+ for f in os.listdir("sample_images")
75
+ if f.lower().endswith(".png")
76
+ ])
77
  gr.Examples(
78
+ examples=sample_list,
79
+ inputs=image_input,
80
  label="Sample Images"
81
  )
82
  with gr.Column():
 
83
  result = gr.HTML()
84
 
85
+ # 画像を選択/アップロードしたら自動で predict を実行
86
  image_input.change(
87
+ fn=predict,
88
+ inputs=image_input,
89
  outputs=result
90
  )
91