lmedz commited on
Commit
66566b0
·
verified ·
1 Parent(s): 77092c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,33 +1,39 @@
1
  import gradio as gr
2
  import torch
3
  import timm
 
4
  from torchvision import transforms
5
  from PIL import Image
6
-
7
- # モデル読み込み
8
  from huggingface_hub import hf_hub_download
9
- import torch
10
- import timm
11
 
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  model_path = hf_hub_download(
15
- repo_id="lmedz/ips-model-weights", # ←作ったModel Hubのリポジトリ名
16
  filename="model.pth"
17
  )
18
 
19
-
20
- # 例:保存時が nn.Sequential(model) だった場合
21
- model_base = timm.create_model('convnext_small', pretrained=False, num_classes=2)
22
- model = torch.nn.Sequential(model_base) # ← wrapperを追加
23
-
24
  state_dict = torch.load(model_path, map_location=device)
25
  model.load_state_dict(state_dict)
26
  model.to(device)
27
  model.eval()
28
 
29
-
30
-
31
  # 前処理
32
  transform = transforms.Compose([
33
  transforms.Resize((224, 224)),
@@ -37,17 +43,17 @@ transform = transforms.Compose([
37
  std=[0.229, 0.224, 0.225])
38
  ])
39
 
40
- class_names = ['Low CPM Score', 'High CPM Score'] # 適宜修正
41
-
42
  def predict(img):
43
  img_tensor = transform(img).unsqueeze(0).to(device)
44
  with torch.no_grad():
45
- logits = model(img_tensor)
46
- probs = torch.softmax(logits, dim=1).cpu().numpy().flatten()
47
- result = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
48
- return img, result # ← 画像とスコアを返す
49
-
50
 
 
51
  demo = gr.Interface(
52
  fn=predict,
53
  inputs=gr.Image(type="pil", label="Input Image"),
@@ -56,9 +62,9 @@ demo = gr.Interface(
56
  gr.Label(num_top_classes=2, label="Prediction")
57
  ],
58
  title="iPS Cell Quality Classifier",
59
- description="Upload a microscopy image. The image and predicted cell quality will be displayed."
60
  )
61
 
62
-
63
  if __name__ == "__main__":
64
  demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  import timm
4
+ import torch.nn as nn
5
  from torchvision import transforms
6
  from PIL import Image
 
 
7
  from huggingface_hub import hf_hub_download
 
 
8
 
9
+ # デバイス設定
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # モデル構築(学習時と同じ構造を再現)
13
+ def build_model():
14
+ backbone = timm.create_model('convnext_small', pretrained=False, num_classes=0, global_pool='avg')
15
+ model = nn.Sequential(
16
+ backbone,
17
+ nn.Linear(backbone.num_features, 128),
18
+ nn.ReLU(),
19
+ nn.Dropout(0.4),
20
+ nn.Linear(128, 1)
21
+ )
22
+ return model
23
+
24
+ # モデルの重みを Model Hub からダウンロード
25
  model_path = hf_hub_download(
26
+ repo_id="ryumiyake/ips-model-weights", # あなたのModel Hubリポジトリ
27
  filename="model.pth"
28
  )
29
 
30
+ # モデル初期化・重み読み込み
31
+ model = build_model()
 
 
 
32
  state_dict = torch.load(model_path, map_location=device)
33
  model.load_state_dict(state_dict)
34
  model.to(device)
35
  model.eval()
36
 
 
 
37
  # 前処理
38
  transform = transforms.Compose([
39
  transforms.Resize((224, 224)),
 
43
  std=[0.229, 0.224, 0.225])
44
  ])
45
 
46
+ # 推論関数
47
+ THRESHOLD = 0.5 # 必要に応じて変更可能
48
  def predict(img):
49
  img_tensor = transform(img).unsqueeze(0).to(device)
50
  with torch.no_grad():
51
+ logit = model(img_tensor)
52
+ prob = torch.sigmoid(logit).item() # 出力: 0〜1の確率
53
+ label = "High CPM Score" if prob >= THRESHOLD else "Low CPM Score"
54
+ return img, {label: prob}
 
55
 
56
+ # Gradioインターフェース
57
  demo = gr.Interface(
58
  fn=predict,
59
  inputs=gr.Image(type="pil", label="Input Image"),
 
62
  gr.Label(num_top_classes=2, label="Prediction")
63
  ],
64
  title="iPS Cell Quality Classifier",
65
+ description="Upload a microscopy image to classify cell quality based on CPM score."
66
  )
67
 
 
68
  if __name__ == "__main__":
69
  demo.launch()
70
+