VikramR commited on
Commit
a9d56ef
·
1 Parent(s): 51eab78

Uploaded app

Browse files
Files changed (6) hide show
  1. .gitignore +39 -0
  2. README.md +3 -2
  3. app.py +396 -0
  4. external_models.py +154 -0
  5. requirements.txt +6 -0
  6. utils.py +42 -0
.gitignore ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
30
+ __pypackages__/
31
+
32
+ # Environments
33
+ .env
34
+ .venv
35
+ env/
36
+ venv/
37
+ ENV/
38
+ env.bak/
39
+ venv.bak/
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: NematodeClassifier
3
- emoji: 🏆
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.3.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: NematodeClassifier
3
+ emoji: 🪱
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms.functional import pil_to_tensor
3
+
4
+ import gradio as gr
5
+ from gradio.utils import get_upload_folder
6
+
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from external_models import EfficientNet, MobileNet, ResNet, Swin
10
+ from utils import get_preprocessing
11
+
12
+ from pathlib import Path
13
+ from PIL import Image
14
+ from tempfile import NamedTemporaryFile
15
+ import json
16
+ import os
17
+
18
+ import cv2
19
+
20
+ import pandas as pd
21
+ import numpy as np
22
+
23
+ device = "cpu"
24
+
25
+ models = {
26
+ "mbnet": MobileNet,
27
+ "effnet": EfficientNet,
28
+ "resnet": ResNet,
29
+ "swin": Swin,
30
+ }
31
+ model_filenames = {
32
+ "EfficientNetV2-S": "efficientnetv2s.pth",
33
+ "MobileNetV3-L": "mobilenetv3l.pth",
34
+ "ResNet101": "resnet101.pth",
35
+ "Swin V2-B": "swinv2b.pth",
36
+ }
37
+ model_names = {
38
+ "effnet": "EfficientNetV2-S",
39
+ "mbnet": "MobileNetV3-L",
40
+ "resnet": "ResNet101",
41
+ "swin": "Swin V2-B",
42
+ }
43
+
44
+
45
+ def cropped_img(img: np.ndarray | Image.Image | str):
46
+ """
47
+ Takes an image and automatically crops the nematode. Returns the cropped image
48
+ and the binary mask of the original image that outlines the nematode
49
+
50
+ Parameters
51
+ ----------
52
+ img : np.ndarray
53
+ Image
54
+
55
+ Returns
56
+ -------
57
+ tuple[float, float, float, float]
58
+ Cropped image bounding box
59
+ """
60
+ if isinstance(img, str):
61
+ img = Image.open(img).convert("RGB")
62
+ if isinstance(img, Image.Image):
63
+ img = np.array(img)
64
+ rgb = img
65
+ gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
66
+
67
+ # EDGE DETECTION
68
+ edges = cv2.Canny(gray, 25, 25, apertureSize=3, L2gradient=True)
69
+
70
+ # FILLS IN NEMATODE EDGES BY "PUFFING" IT UP, ALSO REMOVES OTHER DEBRIS
71
+ kernel = np.ones((11, 11), np.uint8)
72
+ edges_dilate = cv2.dilate(edges, kernel, iterations=3)
73
+ edges_erode = cv2.erode(edges_dilate, kernel, iterations=3)
74
+ cnts, _ = cv2.findContours(edges_erode, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
75
+ cnt = max(cnts, key=cv2.contourArea)
76
+ fill = np.zeros(edges.shape, np.uint8)
77
+ cv2.drawContours(fill, [cnt], -1, 255, cv2.FILLED)
78
+
79
+ # CROPS THE BINARY IMAGE DEPENDING ON WHERE THE WHITE PIXELS ARE
80
+ x1, y1 = (
81
+ max(np.argmax(fill.max(0)), 0),
82
+ max(np.argmax(fill.max(1)), 0),
83
+ )
84
+ x2, y2 = (
85
+ min(
86
+ fill.shape[1] - np.argmax(np.flip(fill.max(0))),
87
+ fill.shape[1],
88
+ ),
89
+ min(
90
+ fill.shape[0] - np.argmax(np.flip(fill.max(1))),
91
+ fill.shape[0],
92
+ ),
93
+ )
94
+ if y2 - y1 < x2 - x1:
95
+ delta = ((x2 - x1) - (y2 - y1)) // 2
96
+ if y1 < delta:
97
+ y2 += 2 * delta - y1
98
+ y1 = 0
99
+ else:
100
+ y1 -= delta
101
+ y2 += delta
102
+ else:
103
+ delta = ((y2 - y1) - (x2 - x1)) // 2
104
+ if x1 < delta:
105
+ x2 += 2 * delta - x1
106
+ x1 = 0
107
+ else:
108
+ x1 -= delta
109
+ x2 += delta
110
+ y, x = rgb.shape[:2]
111
+ x2 = min(x2, x)
112
+ y2 = min(y2, y)
113
+ x1 = max(0, x1)
114
+ y1 = max(0, y1)
115
+ # CROPS AND RESIZES IMAGE
116
+ return x1, y1, x2, y2
117
+
118
+
119
+ model, preprocessing, class_to_idx, idx_to_class = None, None, None, None
120
+
121
+ current_model_type = None
122
+
123
+ results_cache: dict[str, str] = {}
124
+ current_image = None
125
+ autocrop = True
126
+
127
+ temp_files: list[str] = []
128
+ all_images: list[str] = []
129
+
130
+
131
+ def load_model(model_name: str = "EfficientNetV2-S"):
132
+ """
133
+ Loads model and modifies global state
134
+ """
135
+ global model, preprocessing, class_to_idx, idx_to_class, current_model_type
136
+ if model_name is not None:
137
+ filename = model_filenames[model_name]
138
+ filepath = hf_hub_download(
139
+ repo_id="VikramR/NematodeClassification",
140
+ filename=filename,
141
+ )
142
+ (model_state, _, _, _, _, _, config, class_to_idx, _) = torch.load(
143
+ filepath, map_location=device
144
+ )
145
+
146
+ current_model_type = config["model_type"]
147
+ model = models[config["model_type"]](config).to(device)
148
+ model.load_state_dict(model_state)
149
+ model = model.eval()
150
+ idx_to_class = {idx: img_cls for img_cls, idx in class_to_idx.items()}
151
+ preprocessing = get_preprocessing(current_model_type)
152
+
153
+
154
+ def display_model():
155
+ """
156
+ Displays the current selected model in the textbox
157
+ """
158
+ global current_model_type
159
+ model_name = model_names[current_model_type]
160
+ return f"Current Model Type: {model_name}. Reupload model to change it."
161
+
162
+
163
+ def clear():
164
+ """
165
+ Resets global state
166
+ """
167
+ global results_cache, current_image
168
+ results_cache = {}
169
+ current_image = None
170
+ for file in all_images:
171
+ os.remove(file)
172
+ for file in temp_files:
173
+ os.remove(file)
174
+
175
+
176
+ @torch.no_grad()
177
+ def run_image(img: Image.Image):
178
+ global preprocessing, device, model, class_to_idx, idx_to_class
179
+ img = pil_to_tensor(img)[None].to(device)
180
+ img = preprocessing(img)
181
+ logits = model(img)
182
+ probs = torch.nn.functional.softmax(logits, dim=1)[0]
183
+ prob, label = torch.max(probs, dim=0)
184
+ n_classes = len(class_to_idx)
185
+ results = {
186
+ "Probability": list(range(n_classes)),
187
+ "Class": [idx_to_class[i] for i in range(n_classes)],
188
+ }
189
+ for i in range(n_classes):
190
+ results["Probability"][i] = float(probs[i].item())
191
+ label = idx_to_class[label.item()]
192
+ prob = prob.item()
193
+ return results, (prob, label)
194
+
195
+
196
+ def prev_crop_preview() -> str:
197
+ """
198
+ Preview for the current cropped image
199
+ """
200
+ global autocrop, current_image, temp_files
201
+ img = Image.open(current_image).convert("RGB")
202
+ if autocrop:
203
+ box = cropped_img(img)
204
+ img = img.crop(box)
205
+ with NamedTemporaryFile(
206
+ mode="wb", dir=get_upload_folder(), suffix=".png", delete=False
207
+ ) as f:
208
+ pth = f.name
209
+ img.save(f)
210
+ temp_files.append(f.name)
211
+ return pth
212
+
213
+
214
+ def predict(img: str) -> gr.BarPlot:
215
+ global results_cache
216
+ img = Image.open(img).convert("RGB")
217
+ result, (prob, label) = run_image(img)
218
+ df = pd.DataFrame(result)
219
+ current_image_name = Path(current_image).name
220
+ result = dict(zip(result["Class"], result["Probability"]))
221
+ results_cache[current_image_name] = {
222
+ "Distribution": result,
223
+ "Classification": {"Probability": prob, "Label": label},
224
+ }
225
+ return gr.BarPlot(
226
+ df, x="Class", y="Probability", tooltip=class_to_idx.keys(), y_lim=(0, 1)
227
+ )
228
+
229
+
230
+ def predict_all():
231
+ global all_images, results_cache
232
+ for img in all_images:
233
+ current_image_name = Path(img).name
234
+ img = Image.open(img).convert("RGB")
235
+ if autocrop:
236
+ box = cropped_img(img)
237
+ img = img.crop(box)
238
+ result, (prob, label) = run_image(img)
239
+ result = dict(zip(result["Class"], result["Probability"]))
240
+ results_cache[current_image_name] = {
241
+ "Distribution": result,
242
+ "Classification": {"Probability": prob, "Label": label},
243
+ }
244
+
245
+
246
+ def get_results_cache():
247
+ global results_cache
248
+ return results_cache
249
+
250
+
251
+ def save_results():
252
+ global results_cache
253
+ with NamedTemporaryFile(
254
+ "w",
255
+ delete=False,
256
+ prefix="model_predictions_",
257
+ suffix=".json",
258
+ ) as f:
259
+ json.dump(results_cache, f, indent=4)
260
+ temp_files.append(f.name)
261
+ return f.name
262
+
263
+
264
+ def select_image(files, sd: gr.SelectData):
265
+ # Returns the name of the image which you click on in the file upload
266
+ global current_image
267
+ current_image = files[sd.index].name
268
+ return files[sd.index].name
269
+
270
+
271
+ def show_crop_panel():
272
+ global current_image
273
+ return current_image
274
+
275
+
276
+ def upload_files(files):
277
+ global all_images
278
+ all_images = files
279
+
280
+
281
+ def toggle_autocrop(res):
282
+ global autocrop
283
+ autocrop = res
284
+
285
+
286
+ def show_preview(x):
287
+ # When you click the crop button, the preview is updated and cached
288
+ return x["composite"]
289
+
290
+
291
+ def show_current_filename():
292
+ orig_msg = "Crop Image Here (Optional), then click Run to Predict"
293
+ current_img_name = Path(current_image).name
294
+ return f"{orig_msg}\n\nCurrent File: {current_img_name}"
295
+
296
+
297
+ with gr.Blocks() as demo:
298
+ demo.load(load_model)
299
+ with gr.Row():
300
+ gr.Textbox(
301
+ "Only use this application on the following classes of nematodes: "
302
+ + "Helicotylenchus, Hoplolaimus, Meloidogyne, Mesocriconema, "
303
+ "Pratylenchus, Trichodorus, and Tylenchorhynchus",
304
+ text_align="center",
305
+ label="DISCLAIMER",
306
+ )
307
+ with gr.Row():
308
+ model_text = gr.Textbox(
309
+ "Default model: EfficientNetV2-S. To choose a different model, choose one from the dropdown on the right",
310
+ label="Current Model",
311
+ )
312
+ model_select = gr.Dropdown(
313
+ choices=["EfficientNetV2-S", "MobileNetV3-L", "ResNet101", "Swin V2-B"],
314
+ value="EfficientNetV2-S",
315
+ label="Select Model Architecture",
316
+ )
317
+ with gr.Row():
318
+ with gr.Column():
319
+ gr.Textbox(
320
+ "Upload Images, then Select Each One to Crop & Run",
321
+ show_label=False,
322
+ )
323
+ files = gr.File(file_types=["image"], file_count="multiple")
324
+ batch_predict = gr.Button("Predict All")
325
+
326
+ with gr.Column():
327
+ mid_col_text = gr.Textbox(
328
+ "Crop Image Here (Optional), then Click Run to Predict",
329
+ show_label=False,
330
+ )
331
+ autocrop_toggle = gr.Checkbox(value=True, label="Automatic Cropping")
332
+ cropper = gr.ImageEditor(
333
+ type="filepath",
334
+ sources=None,
335
+ layers=False,
336
+ brush=False,
337
+ )
338
+ crop = gr.Button("Crop")
339
+ with gr.Column():
340
+ gr.Textbox(
341
+ "Image Preview (What will be run through network)",
342
+ show_label=False,
343
+ )
344
+ preview = gr.Image(
345
+ sources=None,
346
+ type="filepath",
347
+ height=250,
348
+ )
349
+ classify = gr.Button("Classify")
350
+ plot = gr.BarPlot()
351
+
352
+ with gr.Row():
353
+ gr.Textbox(
354
+ "Here are the predicted labels for your images in JSON format",
355
+ label="Predictions",
356
+ )
357
+ with gr.Row():
358
+ json_results = gr.JSON()
359
+ download = gr.DownloadButton("Download Predictions")
360
+
361
+ download.click(save_results, outputs=download)
362
+ model_select.change(load_model, inputs=model_select).then(
363
+ display_model, outputs=model_text
364
+ )
365
+ model_select
366
+
367
+ files.upload(upload_files, inputs=files)
368
+ files.select(select_image, inputs=files, outputs=cropper).then(
369
+ show_current_filename,
370
+ outputs=mid_col_text,
371
+ ).then(
372
+ prev_crop_preview,
373
+ outputs=preview,
374
+ )
375
+
376
+ autocrop_toggle.change(toggle_autocrop, inputs=autocrop_toggle).then(
377
+ show_crop_panel, outputs=cropper
378
+ ).then(
379
+ prev_crop_preview,
380
+ outputs=preview,
381
+ )
382
+
383
+ batch_predict.click(predict_all).then(get_results_cache, outputs=json_results)
384
+
385
+ files.clear(clear).then(get_results_cache, outputs=json_results)
386
+
387
+ crop.click(show_preview, inputs=cropper, outputs=preview)
388
+
389
+ classify.click(predict, inputs=preview, outputs=plot).then(
390
+ get_results_cache, outputs=json_results
391
+ )
392
+ demo.unload(clear)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ demo.launch()
external_models.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.models import (
4
+ efficientnet_v2_s,
5
+ mobilenet_v3_large,
6
+ resnet101,
7
+ swin_v2_b,
8
+ )
9
+
10
+ import math
11
+
12
+ NUM_GRADUAL_UNFREEZING_STAGES = 5
13
+ SEED = 123
14
+
15
+
16
+ ACT_FUNCS = {
17
+ "relu": nn.ReLU,
18
+ "tanh": nn.Tanh, # Tanh is not used
19
+ }
20
+
21
+
22
+ def classification_head(in_features: int, config: dict, flatten=False) -> nn.Sequential:
23
+ torch.manual_seed(SEED)
24
+ first_linear = nn.Linear(in_features, config["units"], bias=False)
25
+ nn.init.kaiming_uniform_(first_linear.weight, nonlinearity=config["activation"])
26
+ head = nn.Sequential(
27
+ first_linear,
28
+ nn.LayerNorm(config["units"]),
29
+ ACT_FUNCS[config["activation"]](),
30
+ nn.Dropout(config["dropout"]),
31
+ nn.Linear(config["units"], 7),
32
+ )
33
+ if flatten:
34
+ head.insert(0, nn.Flatten())
35
+
36
+ return head
37
+
38
+
39
+ class PretrainedModel(nn.Module):
40
+ def __init__(self, config):
41
+ super().__init__()
42
+ self.unfreezing_stage = 0
43
+ # The layers in forwarding order
44
+ self.layers_to_unfreeze: list[nn.Module] = []
45
+ self.model_type: str = config["model_type"]
46
+ self.grad_cam_layer: list[nn.Module] = []
47
+
48
+ def set_head_trainable(self):
49
+ """
50
+ Requires overriding if the classification head is not called
51
+ "model.classifier"
52
+ """
53
+ self.model.classifier.requires_grad_(True)
54
+
55
+ def inc_grad_unfreezing(self):
56
+ """
57
+ Increments the gradual unfreezing process by unfreezing
58
+ the next 100% / NUM_GRADUAL_UNFREEZING_STAGES layers
59
+ """
60
+ if self.unfreezing_stage <= NUM_GRADUAL_UNFREEZING_STAGES:
61
+ self.unfreezing_stage += 1
62
+ self.set_unfreezing_stage(self.unfreezing_stage)
63
+
64
+ def set_unfreezing_stage(self, unfreezing_stage: int):
65
+ self.unfreezing_stage = unfreezing_stage
66
+ if self.unfreezing_stage > NUM_GRADUAL_UNFREEZING_STAGES:
67
+ self.unfreezing_stage = NUM_GRADUAL_UNFREEZING_STAGES
68
+ self.requires_grad_(True)
69
+ return
70
+ else:
71
+ # Make sure all layers are untrainable before
72
+ # setting the trainable layers to be trainable
73
+ self.requires_grad_(False)
74
+ layer_index = math.ceil(
75
+ self.unfreezing_stage
76
+ * len(self.layers_to_unfreeze)
77
+ / NUM_GRADUAL_UNFREEZING_STAGES
78
+ )
79
+ for module in self.layers_to_unfreeze[-layer_index:]:
80
+ module.requires_grad_(True)
81
+
82
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
83
+ return self.model(x)
84
+
85
+
86
+ class EfficientNet(PretrainedModel):
87
+ def __init__(self, config: dict):
88
+ super().__init__(config)
89
+ self.model = efficientnet_v2_s()
90
+ in_features = self.model.classifier[1].in_features
91
+ self.model.classifier = classification_head(in_features, config)
92
+ self.layers_to_unfreeze = [
93
+ self.model.features[i] for i in range(len(self.model.features))
94
+ ]
95
+ self.grad_cam_layer = [self.model.features[-1][-1]]
96
+
97
+
98
+ class MobileNet(PretrainedModel):
99
+ """
100
+ MobileNet V3 or V4, customized for our transfer learning
101
+
102
+ V4 paper:
103
+ https://arxiv.org/abs/2404.10518
104
+ """
105
+
106
+ def __init__(self, config: dict, version: str = "v3"):
107
+ super().__init__(config)
108
+ # MBNetV4 is in a MBNetV3 object for some reason
109
+ if version == "v3":
110
+ self.model = mobilenet_v3_large()
111
+ in_features = self.model.classifier[0].in_features
112
+
113
+ self.layers_to_unfreeze = [
114
+ self.model.features[i] for i in range(len(self.model.features))
115
+ ]
116
+ self.grad_cam_layer = [self.model.features[-1][-1]]
117
+ else:
118
+ raise NotImplementedError()
119
+ self.model.classifier = classification_head(in_features, config)
120
+
121
+
122
+ class ResNet(PretrainedModel):
123
+ def __init__(self, config: dict):
124
+ super().__init__(config)
125
+ self.model = resnet101()
126
+ in_features = self.model.fc.in_features
127
+ self.model.fc = classification_head(in_features, config)
128
+ self.layers_to_unfreeze = [
129
+ self.model.conv1,
130
+ self.model.bn1,
131
+ self.model.layer1,
132
+ self.model.layer2,
133
+ self.model.layer3,
134
+ self.model.layer4,
135
+ ]
136
+ self.grad_cam_layer = [self.model.layer4[-1]]
137
+
138
+ def set_head_trainable(self):
139
+ self.model.fc.requires_grad_(True)
140
+
141
+
142
+ class Swin(PretrainedModel):
143
+ def __init__(self, config: dict):
144
+ super().__init__(config)
145
+ self.model = swin_v2_b()
146
+ in_features = self.model.head.in_features
147
+ self.model.head = classification_head(in_features, config)
148
+ self.layers_to_unfreeze = [
149
+ self.model.features[i] for i in range(len(self.model.features))
150
+ ] + [self.model.norm]
151
+ self.grad_cam_layer = [self.model.permute]
152
+
153
+ def set_head_trainable(self):
154
+ self.model.head.requires_grad_(True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ huggingface_hub==0.23.4
4
+ numpy==1.26.4
5
+ opencv-python==4.10.0.82
6
+ gradio==4.40.0
utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import v2
3
+
4
+ RESIZE = {
5
+ "effnet": 384,
6
+ "resnet": 224,
7
+ "mbnet": 224,
8
+ "swin": 256,
9
+ }
10
+
11
+
12
+ def get_preprocessing(model_type: str) -> v2.Compose:
13
+ """
14
+ Gets the right image preprocessing transform for each model
15
+
16
+ Parameters
17
+ ----------
18
+ model_type : str
19
+ Model nickname
20
+
21
+ Returns
22
+ -------
23
+ v2.Compose
24
+ Preprocessing transform
25
+
26
+ Raises
27
+ ------
28
+ NotImplementedError
29
+ If it's an invalid model_type
30
+ """
31
+ resize = RESIZE[model_type]
32
+ transform = v2.Compose(
33
+ [
34
+ v2.ToImage(),
35
+ v2.Resize((resize, resize)),
36
+ v2.ToDtype(torch.float, True),
37
+ v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
38
+ v2.Grayscale(3),
39
+ ]
40
+ )
41
+
42
+ return transform