ImAMJayKIM commited on
Commit
a9b808b
·
verified ·
1 Parent(s): b95f88c

feat : decoder 의 마지막 layer 에서만 heatmap 생성

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -345,25 +345,22 @@ def predict_captioning(image):
345
  )
346
 
347
  tmp_dir = tempfile.mkdtemp(prefix="combined_captioning_gradio_")
348
- heatmap_images = []
349
- n_layers = len(runtime["decoder"].layers)
350
-
351
- # 각 decoder layer별 cross-attention heatmap 이미지를 만들어 Gallery에 표시한다.
352
- for layer in range(1, n_layers + 1):
353
- cross_atten_path = Path(tmp_dir) / f"cross_attention_layer_{layer}.jpg"
354
- runtime["decoder"].show_cross_atten(
355
- enc_dec_atten[0],
356
- caption_tokens,
357
- layer,
358
- image_tensor.squeeze(0).detach().cpu(),
 
359
  str(cross_atten_path),
 
360
  )
361
- heatmap_images.append(
362
- (
363
- str(cross_atten_path),
364
- f"Layer {layer}",
365
- )
366
- )
367
 
368
  return caption, heatmap_images
369
 
 
345
  )
346
 
347
  tmp_dir = tempfile.mkdtemp(prefix="combined_captioning_gradio_")
348
+ last_layer = len(runtime["decoder"].layers)
349
+ cross_atten_path = Path(tmp_dir) / "cross_attention_last_layer.jpg"
350
+
351
+ runtime["decoder"].show_cross_atten(
352
+ enc_dec_atten[0],
353
+ caption_tokens,
354
+ last_layer,
355
+ image_tensor.squeeze(0).detach().cpu(),
356
+ str(cross_atten_path),
357
+ )
358
+ heatmap_images = [
359
+ (
360
  str(cross_atten_path),
361
+ f"Last Layer ({last_layer})",
362
  )
363
+ ]
 
 
 
 
 
364
 
365
  return caption, heatmap_images
366