cfoli commited on
Commit
af8ad28
·
1 Parent(s): 91a39a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -42,7 +42,8 @@ configs = {
42
  "MEAN": (0.485, 0.456, 0.406),
43
  "STD": (0.229, 0.224, 0.225),
44
 
45
- "DEFAULT_BACKBONE": "EfficientNet(b3)"
 
46
  }
47
 
48
  """### Define helper functions"""
@@ -210,7 +211,7 @@ class modelModule(torch_light.LightningModule):
210
 
211
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
212
 
213
- def run_diagnosis(backbone_name, input_image, preprocess_fn = None, Idx2labels=None):
214
 
215
  input_tensor = preprocess_fn(input_image)
216
  input_tensor = input_tensor.unsqueeze(dim = 0)
@@ -263,7 +264,7 @@ example_list = [
263
  # example_list = [['/content/new_labels.csv',"ResNet50"]]
264
 
265
  gradio_app = gradio.Interface(
266
- fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict),
267
 
268
  inputs = [gradio.Dropdown(["ConvNeXt(small)", "ConvNeXt(tiny)", "EfficientNet(v2_small)", "EfficientNet(b3)", "RegNet(x3_2GF)","ResNet50"], value="EfficientNet(b3)", label="Select Backbone Model"),
269
  gradio.Image(type="pil", label="Load chest-X-ray image here")],
 
42
  "MEAN": (0.485, 0.456, 0.406),
43
  "STD": (0.229, 0.224, 0.225),
44
 
45
+ "DEFAULT_BACKBONE": "EfficientNet(b3)",
46
+ "THRESHOLD": 0.5
47
  }
48
 
49
  """### Define helper functions"""
 
211
 
212
  """### Create function for running inference (i.e., assistive medical diagnosis)"""
213
 
214
+ def run_diagnosis(backbone_name, input_image, preprocess_fn = None, Idx2labels=None, threshold = configs["THRESHOLD"]):
215
 
216
  input_tensor = preprocess_fn(input_image)
217
  input_tensor = input_tensor.unsqueeze(dim = 0)
 
264
  # example_list = [['/content/new_labels.csv',"ResNet50"]]
265
 
266
  gradio_app = gradio.Interface(
267
+ fn = partial(run_diagnosis, preprocess_fn = preprocess_fxn, Idx2labels = labels_dict, threshold = configs["THRESHOLD"]),
268
 
269
  inputs = [gradio.Dropdown(["ConvNeXt(small)", "ConvNeXt(tiny)", "EfficientNet(v2_small)", "EfficientNet(b3)", "RegNet(x3_2GF)","ResNet50"], value="EfficientNet(b3)", label="Select Backbone Model"),
270
  gradio.Image(type="pil", label="Load chest-X-ray image here")],