rohithb committed on
Commit
d95c976
·
1 Parent(s): 080a25a

Upload 6 files

Browse files
Files changed (6) hide show
  1. .gitignore +163 -0
  2. README.md +7 -7
  3. assets/cat_dog.jpg +0 -0
  4. gradcam/app.py +58 -0
  5. gradcam/utils.py +100 -0
  6. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ workspace.code-workspace
2
+ flagged/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: CLIP CamGrad
3
- emoji: 🔥
4
- colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.47.1
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Gradio OpenAI CLIP Grad-CAM
3
+ emoji: 🔭
4
+ colorFrom: yellow
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 2.9.4
8
+ app_file: gradcam/app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
assets/cat_dog.jpg ADDED
gradcam/app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import clip
3
+ import torch
4
+
5
+ import utils
6
+
7
+ #clip_model = "RN50x4"
8
+ clip_model = "RN50x64"
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model, preprocess = clip.load(clip_model, device=device, jit=False)
11
+ model.eval()
12
+
13
+
def grad_cam_fn(text, img, saliency_layer):
    """Return a Grad-CAM heatmap of *img* highlighting regions aligned with *text*.

    Args:
        text: natural-language prompt to localize in the image.
        img: PIL image (RGB).
        saliency_layer: attribute name of a layer on ``model.visual``
            (e.g. "layer4") whose activations drive the saliency map.

    Returns:
        PIL image: the input blended with the attention heatmap.
    """
    # CLIP's visual tower expects a fixed square input resolution.
    side = model.visual.input_resolution
    img = img.resize((side, side))

    # Embed the prompt; the text feature acts as the backward target.
    tokens = clip.tokenize([text]).to(device)
    txt_emb = model.encode_text(tokens).float()
    img_tensor = preprocess(img).unsqueeze(0).to(device)

    # Compute Grad-CAM at the requested layer of the visual backbone.
    target_layer = getattr(model.visual, saliency_layer)
    cam = utils.gradCAM(model.visual, img_tensor, txt_emb, target_layer)

    # Move to CPU numpy and render as a blended heatmap overlay.
    heat = cam.squeeze().detach().cpu().numpy()
    return utils.getAttMap(img, heat)
# Assemble the Gradio demo: prompt + image + layer choice -> attention overlay.
text_box = gr.inputs.Textbox(label="Target Text", lines=1)
image_box = gr.inputs.Image(
    label='Input Image',
    image_mode="RGB",
    type='pil',
    shape=(512, 512))
layer_choice = gr.inputs.Dropdown(
    ["layer4", "layer3", "layer2", "layer1"],
    default="layer4",
    label="Saliency Layer")
attention_out = gr.outputs.Image(type="pil", label="Attention Map")

interface = gr.Interface(
    fn=grad_cam_fn,
    inputs=[text_box, image_box, layer_choice],
    outputs=attention_out,
    examples=[
        ['a cat lying on the floor', 'assets/cat_dog.jpg', 'layer4'],
        ['a dog sitting', 'assets/cat_dog.jpg', 'layer4'],
    ],
    description="OpenAI CLIP Grad CAM")
interface.launch()
gradcam/utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import matplotlib.cm
6
+ from PIL import Image
7
+
8
+ # Adapted from: https://colab.research.google.com/github/kevinzakka/clip_playground/blob/main/CLIP_GradCAM_Visualization.ipynb
9
+ class Hook:
10
+ """Attaches to a module and records its activations and gradients."""
11
+
12
+ def __init__(self, module: nn.Module):
13
+ self.data = None
14
+ self.hook = module.register_forward_hook(self.save_grad)
15
+
16
+ def save_grad(self, module, input, output):
17
+ self.data = output
18
+ output.requires_grad_(True)
19
+ output.retain_grad()
20
+
21
+ def __enter__(self):
22
+ return self
23
+
24
+ def __exit__(self, exc_type, exc_value, exc_traceback):
25
+ self.hook.remove()
26
+
27
+ @property
28
+ def activation(self) -> torch.Tensor:
29
+ return self.data
30
+
31
+ @property
32
+ def gradient(self) -> torch.Tensor:
33
+ return self.data.grad
34
+
35
+
# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Compute a Grad-CAM saliency map for *input* at *layer* of *model*.

    Args:
        model: network to probe (here, CLIP's visual tower).
        input: image batch of shape (N, C, H, W).
        target: tensor fed to ``output.backward(target)`` as the upstream
            gradient (here, the text embedding to align with).
        layer: submodule of ``model`` whose activations/gradients form the map.

    Returns:
        Non-negative saliency map of shape (N, 1, H, W), upsampled to the
        input's spatial resolution (bicubic interpolation may introduce
        slight negative overshoot at edges).
    """
    # Zero out any stale gradients at the input.
    if input.grad is not None:
        input.grad.data.zero_()

    # Temporarily freeze parameters so backward only populates the
    # activation gradient captured by the hook, not parameter grads.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    assert isinstance(layer, nn.Module)
    try:
        # Attach a hook at the desired layer, then forward/backward.
        with Hook(layer) as hook:
            output = model(input)
            output.backward(target)

            grad = hook.gradient.float()
            act = hook.activation.float()

        # Channel importance weights: global average pool of the gradient
        # across the spatial dimensions.
        alpha = grad.mean(dim=(2, 3), keepdim=True)
        # Weighted combination of activation maps over the channel dimension.
        gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
        # Keep only neurons with positive influence on the target.
        gradcam = torch.clamp(gradcam, min=0)

        # Resize the map to the input resolution.
        gradcam = F.interpolate(
            gradcam,
            input.shape[2:],
            mode='bicubic',
            align_corners=False)
    finally:
        # BUGFIX: restore parameter gradient settings even if the forward or
        # backward pass raised, so a failed call cannot leave the shared
        # model permanently frozen (the original restored them only on the
        # success path).
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    return gradcam
85
+
86
+
# Modified from: https://github.com/salesforce/ALBEF/blob/main/visualization.ipynb
def getAttMap(img, attn_map):
    """Blend *img* with a jet-colormapped rendering of *attn_map*.

    Args:
        img: PIL image (RGB).
        attn_map: 2-D numpy array of saliency values.

    Returns:
        256x256 PIL image: 60% original image, 40% heatmap.
    """
    # Min-max normalize into [0, 1]; a flat map stays all zeros.
    shifted = attn_map - attn_map.min()
    peak = shifted.max()
    if peak > 0:
        shifted = shifted / peak

    # Colorize with the jet colormap and drop the alpha channel.
    rgba = matplotlib.cm.jet(shifted)
    rgb = (rgba * 255).astype(np.uint8)[:, :, :3]
    heatmap = Image.fromarray(rgb).resize((256, 256))

    return Image.blend(img.resize((256, 256)), heatmap, 0.4)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=2.9.0,<2.10.0
2
+ torch>=1.10.0,<1.11.0
3
+ git+https://github.com/openai/CLIP.git
4
+ Pillow
5
+ matplotlib
6
+ numpy