khan994 commited on
Commit
cea103b
·
1 Parent(s): 1733b83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -41
app.py CHANGED
@@ -17,47 +17,7 @@ examples=["filibe-1-1.jpg",
17
  "ohrid-3-1.jpg",
18
  "varna-1-1.jpg"]
19
 
20
- #@title DataLoader
21
- path = "train_val_cropped"
22
- dblock = DataBlock(blocks = (ImageBlock, CategoryBlock),
23
- get_items = get_image_files,
24
- splitter = GrandparentSplitter(train_name="train",
25
- valid_name="valid"),
26
- get_y=parent_label,
27
- item_tfms=RandomResizedCrop(128, min_scale=0.7),
28
- batch_tfms=[*aug_transforms(),
29
- Normalize.from_stats(*imagenet_stats)])
30
- dls_augmented = dblock.dataloaders(path, shuffle=True)
31
 
32
- def gradcam(img):
33
- x,= first(dls_augmented.test_dl([img]))
34
- hook_output = Hook()
35
- hook = learn.model[0].register_forward_hook(hook_output.hook_func)
36
- with torch.no_grad(): output = learn.model.eval()(x)
37
- act = hook_output.stored[0]
38
- hook.remove()
39
-
40
- input_size=act.shape[0]
41
- out_size=learn.model[1][-1].in_features
42
- kernel_size=act.shape[1]
43
- new_act = tensor(np.zeros((out_size, kernel_size, kernel_size)))
44
- sum = tensor(np.zeros((1, kernel_size, kernel_size)))
45
- for i in range(0,input_size,4):
46
- sum=tensor(np.zeros((1, kernel_size, kernel_size)))
47
- for j in range(i, i+4):
48
- sum=sum+act[j, :, :]
49
- new_act[int(i/4), :, :]=sum/4
50
- cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, new_act)
51
- gcam=cam_map[1].detach().cpu()
52
- x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
53
- _,ax = plt.subplots()
54
- x_dec.show(ctx=ax)
55
- ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,128,128,0), interpolation='bilinear', cmap='magma');
56
- return gcam
57
-
58
-
59
-
60
- demo = gr.Interface(fn=gradcam, inputs=image, outputs="image", examples=examples)
61
- #demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
62
 
63
  demo.launch(inline=False)
 
17
  "ohrid-3-1.jpg",
18
  "varna-1-1.jpg"]
19
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  demo.launch(inline=False)