Picha Jetsadapattarakul commited on
Commit
29dbe3c
·
0 Parent(s):

Initial Streamlit DFU ViT app with LFS model

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+
6
+ from transformers import (
7
+ pipeline,
8
+ ViTImageProcessor,
9
+ ViTForImageClassification
10
+ )
11
+ from PIL import Image
12
+
13
+ from pytorch_grad_cam import GradCAM
14
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
15
+ from pytorch_grad_cam.utils.image import show_cam_on_image
16
+
17
# Directory holding the fine-tuned ViT checkpoint (safetensors) and its
# processor/config files, loaded by both the pipeline and Grad-CAM paths.
MODEL_FOLDER_PATH = "final_vit_model"
18
+
19
+
20
# Bundled sample thermal images shown in the gallery expander.
# Key encodes foot side (Right/Left), class (Normal/DFU) and sample number;
# "<None>" is the selectbox placeholder meaning "no selection".
sample_img = {
    "<None>": None,
    "Right_Normal_1": "img/[0] Normal/Normal_R_1.png",
    "Left_Normal_1": "img/[0] Normal/Normal_L_1.png",
    "Right_Normal_2": "img/[0] Normal/Normal_R_2.png",
    "Left_Normal_2": "img/[0] Normal/Normal_L_2.png",

    "Right_DFU_1": "img/[1] DFU/DFU_R_1.png",
    "Left_DFU_1": "img/[1] DFU/DFU_L_1.png",
    "Right_DFU_2": "img/[1] DFU/DFU_R_2.png",
    "Left_DFU_2": "img/[1] DFU/DFU_L_2.png",
}
32
+
33
# Sample pairs offered in "Use sample pair" mode.
# Each value is a (right_foot_path, left_foot_path) tuple; "<None>" is the
# placeholder entry that selects no images.
sample_pairs = {
    "<None>": (None, None),

    "Normal Pair 1": (
        "img/[0] Normal/Normal_R_1.png",
        "img/[0] Normal/Normal_L_1.png",
    ),
    "Normal Pair 2": (
        "img/[0] Normal/Normal_R_2.png",
        "img/[0] Normal/Normal_L_2.png",
    ),

    "DFU Pair 1": (
        "img/[1] DFU/DFU_R_1.png",
        "img/[1] DFU/DFU_L_1.png",
    ),
    "DFU Pair 2": (
        "img/[1] DFU/DFU_R_2.png",
        "img/[1] DFU/DFU_L_2.png",
    ),
}
54
+
55
def reshape_transform(tensor):
    """
    Convert a ViT token sequence into a 2-D feature map for Grad-CAM.

    Drops the leading CLS token, then folds the remaining patch tokens
    (assumed to form a square grid) from (B, N, C) into (B, C, H, W).
    """
    patches = tensor[:, 1:, :]
    batch, seq_len, channels = patches.shape
    side = int(seq_len ** 0.5)
    grid = patches.reshape(batch, side, side, channels)
    # Channels-first layout, as expected by pytorch-grad-cam.
    return grid.permute(0, 3, 1, 2)
65
+
66
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """
    Adapter that makes a Hugging Face classifier look like a plain
    tensor-in / tensor-out module, as required by pytorch-grad-cam.
    """

    def __init__(self, model):
        super().__init__()
        # Wrapped HF model; its output object carries a `.logits` field.
        self.model = model

    def forward(self, x):
        # Unwrap the HF output container and expose the raw logits only.
        outputs = self.model(x)
        return outputs.logits
73
+
74
+
75
@st.cache_resource
def load_classifier():
    """
    Build the Hugging Face image-classification pipeline once per session.

    Cached via st.cache_resource so repeated Streamlit reruns reuse the
    same loaded model.
    """
    classifier = pipeline(task="image-classification", model=MODEL_FOLDER_PATH)
    return classifier
81
@st.cache_resource
def load_gradcam():
    """
    Load the ViT model, its image processor, and a GradCAM instance.

    Cached with st.cache_resource (matching load_classifier) — without the
    cache the ~343 MB checkpoint was re-read from disk and the Grad-CAM
    hooks re-registered on every single prediction.

    Returns:
        (cam, processor, device): the GradCAM object, the ViTImageProcessor
        for input preprocessing, and the torch device the model lives on.
    """
    # Prefer Apple-Silicon MPS, then CUDA, then CPU.
    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    processor = ViTImageProcessor.from_pretrained(MODEL_FOLDER_PATH)
    hf_model = ViTForImageClassification.from_pretrained(MODEL_FOLDER_PATH)

    model = HuggingfaceToTensorModelWrapper(hf_model).to(device).eval()
    # Grad-CAM hooks the last encoder block's pre-attention LayerNorm,
    # a common target choice for ViT feature attribution.
    target_layers = [model.model.vit.encoder.layer[-1].layernorm_before]

    cam = GradCAM(
        model=model,
        target_layers=target_layers,
        reshape_transform=reshape_transform
    )

    return cam, processor, device
100
+
101
def compute_gradcam_for_pil(pil_img, target_index: int):
    """
    Run Grad-CAM for one PIL image against the given class index.

    Returns:
        (img_np, cam_vis): the input as a float RGB array in [0, 1], and
        the Grad-CAM heatmap blended onto that image (uint8 RGB).
    """
    cam, processor, device = load_gradcam()

    # Float image in [0, 1], required by show_cam_on_image.
    rgb_float = np.asarray(pil_img, dtype=np.float32) / 255.0

    pixel_values = processor(images=pil_img, return_tensors="pt")["pixel_values"]
    heatmap = cam(
        input_tensor=pixel_values.to(device),
        targets=[ClassifierOutputTarget(target_index)],
    )[0]

    # Upscale the model-resolution CAM back to the original image size.
    height, width, _ = rgb_float.shape
    heatmap_full = cv2.resize(heatmap, (width, height))

    overlay = show_cam_on_image(rgb_float, heatmap_full, use_rgb=True)
    return rgb_float, overlay
117
+
118
+
119
def pretty_label(raw_label: str) -> str:
    """
    Translate the model's numeric class label into a display name.

    "0" -> "Normal", "1" -> "Diabetic Foot Ulcers"; any other label is
    returned unchanged.
    """
    if raw_label == "0":
        return "Normal"
    if raw_label == "1":
        return "Diabetic Foot Ulcers"
    return raw_label
126
+
127
+
128
def get_target_index(raw_label: str) -> int:
    """
    Map a raw model label to the Grad-CAM target class index.

    "0" (or an already-pretty "Normal") maps to 0; everything else —
    including "1" / "Diabetic Foot Ulcers" and unknown labels — falls
    through to 1, preserving the original default.
    """
    return 0 if raw_label in ("0", "Normal") else 1
137
+
138
+
139
def app():
    """
    Streamlit entry point: choose sample/uploaded foot images, classify
    them with the ViT pipeline, and visualize Grad-CAM heatmaps.
    """
    st.title("Early Detection of Diabetic Foot Ulcers Using Thermal Imaging with Vision Transformer & Grad-CAM")

    mode = st.radio(
        "Choose input mode",
        ["Use sample pair (Right + Left)", "Upload your own Right & Left images"],
        index=0
    )

    upload_right = None
    upload_left = None
    if mode == "Upload your own Right & Left images":
        upload_right = st.file_uploader(
            "Upload Right Foot Image",
            type=["png", "jpg", "jpeg"],
            key="right_upl"
        )
        upload_left = st.file_uploader(
            "Upload Left Foot Image",
            type=["png", "jpg", "jpeg"],
            key="left_upl"
        )

    right_image = None
    left_image = None
    right_path = left_path = None

    if mode == "Use sample pair (Right + Left)":
        with st.expander("Choose a sample pair and view all sample images", expanded=False):
            pair_name = st.selectbox(
                "Select a sample pair (Right + Left):",
                list(sample_pairs.keys()),
                index=0
            )

            st.markdown("**Normal Group**")
            c1, c2, c3, c4 = st.columns(4)
            with c1:
                st.image(sample_img["Right_Normal_1"], caption="Right_Normal_1", width='stretch')
            with c2:
                st.image(sample_img["Left_Normal_1"], caption="Left_Normal_1", width='stretch')
            with c3:
                st.image(sample_img["Right_Normal_2"], caption="Right_Normal_2", width='stretch')
            with c4:
                st.image(sample_img["Left_Normal_2"], caption="Left_Normal_2", width='stretch')

            st.markdown("**Diabetic Foot Ulcers Group**")
            c1, c2, c3, c4 = st.columns(4)
            with c1:
                st.image(sample_img["Right_DFU_1"], caption="Right_DFU_1", width='stretch')
            with c2:
                st.image(sample_img["Left_DFU_1"], caption="Left_DFU_1", width='stretch')
            with c3:
                st.image(sample_img["Right_DFU_2"], caption="Right_DFU_2", width='stretch')
            with c4:
                st.image(sample_img["Left_DFU_2"], caption="Left_DFU_2", width='stretch')

        right_path, left_path = sample_pairs[pair_name]

    col_input, col_output = st.columns(2)
    col_input.header("Input Images")
    col_output.header("Predictions")

    right_col_in, left_col_in = col_input.columns(2)
    right_col_in.subheader("Right Foot")
    left_col_in.subheader("Left Foot")

    if mode == "Use sample pair (Right + Left)":
        if right_path is not None:
            right_image = Image.open(right_path).convert("RGB")
            right_col_in.image(right_image, caption="Sample Right Foot", width='stretch')

        if left_path is not None:
            left_image = Image.open(left_path).convert("RGB")
            left_col_in.image(left_image, caption="Sample Left Foot", width='stretch')

    else:
        if upload_right is not None:
            right_image = Image.open(upload_right).convert("RGB")
            right_col_in.image(right_image, caption="Uploaded Right Foot", width='stretch')

        if upload_left is not None:
            left_image = Image.open(upload_left).convert("RGB")
            left_col_in.image(left_image, caption="Uploaded Left Foot", width='stretch')

    run_pred = col_output.button("Run prediction")

    out_right_col, out_left_col = col_output.columns(2)
    out_right_col.subheader("Right Foot Prediction")
    out_left_col.subheader("Left Foot Prediction")

    right_cam_vis = None
    left_cam_vis = None
    # Initialize both target indices so the Grad-CAM section below can
    # safely check which feet were processed. (Previously only assigned
    # inside the prediction branches, so providing just a left image
    # crashed with NameError on right_target_index.)
    right_target_index = None
    left_target_index = None

    if run_pred:
        classifier = load_classifier()
        any_image = False

        if right_image is not None:
            any_image = True
            preds_right = classifier(right_image, top_k=2)

            # Show per-class confidence as progress bars.
            for pred in preds_right:
                label = pretty_label(pred["label"])
                score = float(pred["score"])
                out_right_col.progress(score, text=f"{label}: {score * 100:.2f}%")

            top_right_raw_label = preds_right[0]["label"]
            right_target_index = get_target_index(top_right_raw_label)
            _, right_cam_vis = compute_gradcam_for_pil(right_image, right_target_index)

        if left_image is not None:
            any_image = True
            preds_left = classifier(left_image, top_k=2)

            for pred in preds_left:
                label = pretty_label(pred["label"])
                score = float(pred["score"])
                out_left_col.progress(score, text=f"{label}: {score * 100:.2f}%")

            top_left_raw_label = preds_left[0]["label"]
            left_target_index = get_target_index(top_left_raw_label)
            _, left_cam_vis = compute_gradcam_for_pil(left_image, left_target_index)

        if any_image:
            col_output.success("Classification finished ✅")
        else:
            col_output.warning("Please provide images before running prediction.")

    if right_cam_vis is not None or left_cam_vis is not None:
        st.markdown("---")
        # Headline class: use the right foot's prediction when available,
        # otherwise fall back to the left foot's.
        shown_index = right_target_index if right_target_index is not None else left_target_index
        target_name = "DFU" if shown_index == 1 else "Normal"

        st.subheader(f"Grad-CAM Visualization (Target class: {target_name})")

        gcol_r, gcol_l = st.columns(2)

        if right_cam_vis is not None:
            gcol_r.markdown("**Right Foot**")
            gcol_r.image(right_cam_vis, width='stretch')

        if left_cam_vis is not None:
            gcol_l.markdown("**Left Foot**")
            gcol_l.image(left_cam_vis, width='stretch')

if __name__ == "__main__":
    app()
final_vit_model/config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTForImageClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "dtype": "float32",
7
+ "encoder_stride": 16,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.0,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "0",
13
+ "1": "1"
14
+ },
15
+ "image_size": 224,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 3072,
18
+ "label2id": {
19
+ "0": 0,
20
+ "1": 1
21
+ },
22
+ "layer_norm_eps": 1e-12,
23
+ "model_type": "vit",
24
+ "num_attention_heads": 12,
25
+ "num_channels": 3,
26
+ "num_hidden_layers": 12,
27
+ "patch_size": 16,
28
+ "pooler_act": "tanh",
29
+ "pooler_output_size": 768,
30
+ "problem_type": "single_label_classification",
31
+ "qkv_bias": true,
32
+ "transformers_version": "4.57.1"
33
+ }
final_vit_model/label_map.pt ADDED
Binary file (1.34 kB). View file
 
final_vit_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:525fef3f5c6834f6f31c960e51fd81cb2c25123465765c99393e92e475794215
3
+ size 343223968
final_vit_model/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "ViTImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 2,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 224,
21
+ "width": 224
22
+ }
23
+ }
final_vit_model/training_args.bin ADDED
Binary file (5.84 kB). View file
 
img/[0] Normal/Normal_L_1.png ADDED
img/[0] Normal/Normal_L_2.png ADDED
img/[0] Normal/Normal_R_1.png ADDED
img/[0] Normal/Normal_R_2.png ADDED
img/[1] DFU/DFU_L_1.png ADDED
img/[1] DFU/DFU_L_2.png ADDED
img/[1] DFU/DFU_R_1.png ADDED
img/[1] DFU/DFU_R_2.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch
4
+ pillow
5
+ opencv-python
6
+ numpy
7
+ pytorch-grad-cam