MissingBreath commited on
Commit
4452a89
·
verified ·
1 Parent(s): 59b41d8

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +1 -32
api.py CHANGED
@@ -304,30 +304,6 @@ async def classify(image: UploadFile = File(...)):
304
  return {"error": "No image provided"}
305
 
306
 
307
- def make_gradcam_heatmap_Palm(img_array, model, base_model_name, last_conv_layer_name, pred_index=None):
308
- base_model = model.get_layer(base_model_name)
309
- last_conv_layer = base_model.get_layer(last_conv_layer_name)
310
-
311
- grad_model = tf.keras.models.Model(
312
- inputs=[model.inputs],
313
- outputs=[last_conv_layer.output, model.output]
314
- )
315
-
316
- with tf.GradientTape() as tape:
317
- conv_outputs, predictions = grad_model(img_array)
318
- if pred_index is None:
319
- pred_index = tf.argmax(predictions[0])
320
- loss = predictions[:, pred_index]
321
-
322
- grads = tape.gradient(loss, conv_outputs)
323
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
324
- conv_outputs = conv_outputs[0]
325
- heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
326
- heatmap = tf.squeeze(heatmap)
327
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
328
- return heatmap.numpy()
329
-
330
-
331
  @app.post("/palmclassify")
332
  async def palmclassify(image: UploadFile = File(...)):
333
  if image is not None:
@@ -340,14 +316,7 @@ async def palmclassify(image: UploadFile = File(...)):
340
  predicted_class_idx = int(predicted_class_idx)
341
 
342
  last_mb = "Conv_1"
343
- img_array = tf.convert_to_tensor(img_array, dtype=tf.float32)
344
- # heatmap = make_gradcam_heatmap(img_array, modelPalm, last_mb)
345
- heatmap = make_gradcam_heatmap_Palm(
346
- img_array,
347
- modelPalm,
348
- base_model_name='mobilenetv2_1.00_224',
349
- last_conv_layer_name='Conv_1'
350
- )
351
 
352
  base64_image = display_gradcam( np.array(img), heatmap)
353
 
 
304
  return {"error": "No image provided"}
305
 
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  @app.post("/palmclassify")
308
  async def palmclassify(image: UploadFile = File(...)):
309
  if image is not None:
 
316
  predicted_class_idx = int(predicted_class_idx)
317
 
318
  last_mb = "Conv_1"
319
+ heatmap = make_gradcam_heatmap(img_array, modelPalm, last_mb)
 
 
 
 
 
 
 
320
 
321
  base64_image = display_gradcam( np.array(img), heatmap)
322