Shilpaj commited on
Commit
baf2a78
·
verified ·
1 Parent(s): bb98138

Fix: Inference issue

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -58,9 +58,6 @@ def main():
58
  """
59
  )
60
 
61
- # #############################################################################
62
- # ################################ GradCam Tab ################################
63
- # #############################################################################
64
  with gr.Tab("GradCam"):
65
  gr.Markdown(
66
  """
@@ -69,27 +66,32 @@ def main():
69
  """
70
  )
71
  with gr.Row():
72
- img_input = [gr.Image(label="Input Image", type="numpy", height=224)]
73
- gradcam_outputs = [
74
- gr.Label(label="Predictions"),
75
- gr.Image(label="GradCAM Output", height=224)
76
- ]
77
 
78
  with gr.Row():
79
- gradcam_inputs = [
80
- gr.Slider(0, 1, value=0.5, label="Activation Map Transparency"),
81
- gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions"),
82
- gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")
83
- ]
84
 
85
  gradcam_button = gr.Button("Generate GradCAM")
86
 
87
- # Pass model to inference function using partial
88
- from functools import partial
89
- inference_fn = partial(inference, model=model, classes=classes)
90
- gradcam_button.click(inference_fn, inputs=img_input + gradcam_inputs, outputs=gradcam_outputs)
 
 
 
 
 
 
 
 
 
91
 
92
- gr.Markdown("## Examples")
93
  gr.Examples(
94
  examples=[
95
  ["./assets/examples/dog.jpg", 0.5, 3, 4],
@@ -103,13 +105,14 @@ def main():
103
  ["./assets/examples/plane.jpg", 0.5, 3, 4],
104
  ["./assets/examples/ship.png", 0.5, 3, 4]
105
  ],
106
- inputs=img_input + gradcam_inputs,
107
- fn=inference_fn,
108
- outputs=gradcam_outputs
 
109
  )
110
 
111
- # Launch the demo (moved inside the Blocks context)
112
- demo.launch(debug=True)
113
 
114
 
115
  if __name__ == "__main__":
 
58
  """
59
  )
60
 
 
 
 
61
  with gr.Tab("GradCam"):
62
  gr.Markdown(
63
  """
 
66
  """
67
  )
68
  with gr.Row():
69
+ img_input = gr.Image(label="Input Image", type="numpy", height=224)
70
+ with gr.Column():
71
+ label_output = gr.Label(label="Predictions")
72
+ gradcam_output = gr.Image(label="GradCAM Output", height=224)
 
73
 
74
  with gr.Row():
75
+ alpha_slider = gr.Slider(0, 1, value=0.5, label="Activation Map Transparency")
76
+ top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions")
77
+ target_layer_slider = gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")
 
 
78
 
79
  gradcam_button = gr.Button("Generate GradCAM")
80
 
81
+ def inference_wrapper(image, alpha, top_k, target_layer):
82
+ return inference(image, alpha, top_k, target_layer, model=model, classes=classes)
83
+
84
+ gradcam_button.click(
85
+ fn=inference_wrapper,
86
+ inputs=[
87
+ img_input,
88
+ alpha_slider,
89
+ top_k_slider,
90
+ target_layer_slider
91
+ ],
92
+ outputs=[label_output, gradcam_output]
93
+ )
94
 
 
95
  gr.Examples(
96
  examples=[
97
  ["./assets/examples/dog.jpg", 0.5, 3, 4],
 
105
  ["./assets/examples/plane.jpg", 0.5, 3, 4],
106
  ["./assets/examples/ship.png", 0.5, 3, 4]
107
  ],
108
+ inputs=[img_input, alpha_slider, top_k_slider, target_layer_slider],
109
+ outputs=[label_output, gradcam_output],
110
+ fn=inference_wrapper,
111
+ cache_examples=True
112
  )
113
 
114
+ # Launch the demo
115
+ demo.launch(server_name="0.0.0.0", debug=True)
116
 
117
 
118
  if __name__ == "__main__":