GunaKoppula commited on
Commit
b610a35
·
1 Parent(s): 2c165d8

Adding files

Browse files
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: ERAV1 Session 12
3
- emoji: 😻
4
  colorFrom: green
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.39.0
8
  app_file: app.py
@@ -10,4 +10,26 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: ERA Session12
3
+ emoji: 🔥
4
  colorFrom: green
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.39.0
8
  app_file: app.py
 
10
  license: mit
11
  ---
12
 
13
+ ### Gradio UI for CIFAR10 classification with ResNet
14
+
15
+ ## How to use?
16
+ 1. Select whether you want to visualize the misclassified images, and select the count of misclassified images.
17
+ 2. Select whether you want to visualize the GradCAM images, and also select the count of GradCAM images, the model layer, and the opacity of the resulting image.
18
+ 3. Click on the upload button to upload the local image to be used for prediction and select the image for prediction.
19
+ 4. If you want to use one of the sample images, please pick one from the list of 10 sample images.
20
+ 5. Select the top n classes for which you want to see the model performance.
21
+ 6. Click on the Run button
22
+ 7. On the right side of the interface, the top view displays the selected number of misclassified images.
23
+ 8. The second view displays the GradCAM output.
24
+ 9. And the final view displays the top n predictions for the given image.
25
+
26
+ ## Components Used:
27
+ 1. `gr.Dropdown` : Used for selecting the number of images for Misclassified & GradCAM output and also for the top n classes to be displayed.
28
+ 2. `gr.Checkbox` : Used for boolean inputs like if user wants to visualize Misclassified or if they want to visualize gradCAM images.
29
+ 3. `gr.Slider` : Used to select the opacity parameter to be used with GradCAM visualization.
30
+ 4. `gr.Gallery`: Used to display a number of images; used for displaying input images and output images.
31
+ 5. `gr.UploadButton`: A generic file upload button, used for picking and uploading a local image file for prediction.
32
+ 6. `gr.Button`: Used for calling the main prediction module.
33
+ 7. `gr.Label`: Used for displaying the top n classification results.
34
+
35
+ https://user-images.githubusercontent.com/23289802/258841585-4d2a75fa-3902-4839-a32a-bbfec4ef72ba.png
app.py CHANGED
@@ -171,7 +171,7 @@ with gr.Blocks() as app:
171
  label="Misclassified Images", info="Display misclassified images?"
172
  )
173
  misclassified_count = gr.Dropdown(
174
- choices=["10", "20"],
175
  label="Select Number of Images",
176
  info="Number of Misclassified images",
177
  visible=False,
@@ -188,7 +188,7 @@ with gr.Blocks() as app:
188
  info="Display GradCAM images?",
189
  )
190
  gradcam_count = gr.Dropdown(
191
- choices=["10", "20"],
192
  label="Select Number of Images",
193
  info="Number of GradCAM images",
194
  interactive=True,
@@ -285,4 +285,4 @@ with gr.Blocks() as app:
285
  )
286
 
287
 
288
- app.launch(server_name="0.0.0.0", server_port=9998)
 
171
  label="Misclassified Images", info="Display misclassified images?"
172
  )
173
  misclassified_count = gr.Dropdown(
174
+ choices=[str(i + 1) for i in range(20)],
175
  label="Select Number of Images",
176
  info="Number of Misclassified images",
177
  visible=False,
 
188
  info="Display GradCAM images?",
189
  )
190
  gradcam_count = gr.Dropdown(
191
+ choices=[str(i + 1) for i in range(20)],
192
  label="Select Number of Images",
193
  info="Number of GradCAM images",
194
  interactive=True,
 
285
  )
286
 
287
 
288
+ app.launch()
models/custom_resnet.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class ResBlock(nn.Module):
    """Two 3x3 conv-BN-ReLU stages whose output is added back to the input.

    Channel count and spatial size are preserved, so the skip connection
    is a plain element-wise addition.
    """

    def __init__(self, channels):
        super(ResBlock, self).__init__()
        conv_kwargs = dict(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.resblock = nn.Sequential(
            nn.Conv2d(**conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(**conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # Identity shortcut: the stack learns a residual on top of x.
        residual = self.resblock(x)
        return x + residual
34
+
35
+
36
+ class CustomResnet(nn.Module):
37
+ def __init__(self):
38
+ super(CustomResnet, self).__init__()
39
+
40
+ self.prep = nn.Sequential(
41
+ nn.Conv2d(
42
+ in_channels=3,
43
+ out_channels=64,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ bias=False,
48
+ ),
49
+ nn.BatchNorm2d(64),
50
+ nn.ReLU(),
51
+ )
52
+
53
+ self.layer1 = nn.Sequential(
54
+ nn.Conv2d(
55
+ in_channels=64,
56
+ out_channels=128,
57
+ kernel_size=3,
58
+ padding=1,
59
+ stride=1,
60
+ bias=False,
61
+ ),
62
+ nn.MaxPool2d(kernel_size=2),
63
+ nn.BatchNorm2d(128),
64
+ nn.ReLU(),
65
+ ResBlock(channels=128),
66
+ )
67
+
68
+ self.layer2 = nn.Sequential(
69
+ nn.Conv2d(
70
+ in_channels=128,
71
+ out_channels=256,
72
+ kernel_size=3,
73
+ padding=1,
74
+ stride=1,
75
+ bias=False,
76
+ ),
77
+ nn.MaxPool2d(kernel_size=2),
78
+ nn.BatchNorm2d(256),
79
+ nn.ReLU(),
80
+ )
81
+
82
+ self.layer3 = nn.Sequential(
83
+ nn.Conv2d(
84
+ in_channels=256,
85
+ out_channels=512,
86
+ kernel_size=3,
87
+ padding=1,
88
+ stride=1,
89
+ bias=False,
90
+ ),
91
+ nn.MaxPool2d(kernel_size=2),
92
+ nn.BatchNorm2d(512),
93
+ nn.ReLU(),
94
+ ResBlock(channels=512),
95
+ )
96
+
97
+ self.pool = nn.MaxPool2d(kernel_size=4)
98
+
99
+ self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
100
+
101
+ self.softmax = nn.Softmax(dim=-1)
102
+
103
+ def forward(self, x):
104
+ x = self.prep(x)
105
+ x = self.layer1(x)
106
+ x = self.layer2(x)
107
+ x = self.layer3(x)
108
+ x = self.pool(x)
109
+ x = x.view(-1, 512)
110
+ x = self.fc(x)
111
+ # x = self.softmax(x)
112
+ return x
models/resnet_lightning.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import lightning as L
4
+ from torchmetrics import Accuracy
5
+ from typing import Any
6
+
7
+ from utils.common import one_cycle_lr
8
+
9
class ResidualBlock(L.LightningModule):
    """Residual unit: two conv(3x3)-BN-ReLU stages plus an identity skip."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        layers = []
        for _ in range(2):
            layers += [
                nn.Conv2d(in_channels=channels, out_channels=channels,
                          kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
            ]
        self.residual_block = nn.Sequential(*layers)

    def forward(self, x):
        # Add the transformed features back onto the input (identity skip).
        return x + self.residual_block(x)
38
+
39
class ResNet(L.LightningModule):
    """Custom ResNet LightningModule for CIFAR-10 (returns 10 raw logits).

    Constructor arguments capture the training configuration:
    `learning_rate` is the Adam base LR; `maxlr`, `scheduler_steps` and
    `epochs` drive the OneCycle schedule built in `configure_optimizers`;
    `batch_size`, `shuffle` and `num_workers` are stored for reference by
    the surrounding training setup.
    """

    def __init__(
        self, batch_size=512, shuffle=True, num_workers=4, learning_rate=0.003, scheduler_steps=None, maxlr=None, epochs=None
    ):
        super(ResNet, self).__init__()
        self.data_dir = "./data"
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.learning_rate = learning_rate
        self.scheduler_steps = scheduler_steps
        # Default the OneCycle peak LR to the base LR when not provided.
        self.maxlr = maxlr if maxlr is not None else learning_rate
        self.epochs = epochs

        # Stem: 3 -> 64 channels at full resolution.
        self.prep = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Stage 1: 64 -> 128, downsample, residual block.
        self.layer1 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(channels=128),
        )

        # Stage 2: 128 -> 256, downsample (no residual block).
        self.layer2 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # Stage 3: 256 -> 512, downsample, residual block.
        self.layer3 = nn.Sequential(
            nn.Conv2d(
                in_channels=256,
                out_channels=512,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(channels=512),
        )

        self.pool = nn.MaxPool2d(kernel_size=4)
        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
        # Kept for interface parity; forward() returns raw logits.
        self.softmax = nn.Softmax(dim=-1)

        # Single shared criterion instead of re-instantiating every step.
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = x.view(-1, 512)
        x = self.fc(x)
        # Softmax intentionally skipped: CrossEntropyLoss expects logits.
        return x

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-4
        )
        scheduler = one_cycle_lr(
            optimizer=optimizer, maxlr=self.maxlr, steps=self.scheduler_steps, epochs=self.epochs
        )
        # OneCycle is a per-batch schedule, hence interval="step".
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler,
                                 "interval": "step"}}

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        loss = self.criterion(y_pred, y)

        preds = torch.argmax(y_pred, dim=1)
        accuracy = self.accuracy(preds, y)

        self.log_dict({"train_loss": loss, "train_acc": accuracy}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        # Mean (not summed) loss so the logged value is comparable with
        # train_loss regardless of batch size — the original logged a
        # reduction="sum" loss here, which scales with batch size.
        loss = self.criterion(y_pred, y)

        preds = torch.argmax(y_pred, dim=1)
        accuracy = self.accuracy(preds, y)

        self.log_dict({"val_loss": loss, "val_acc": accuracy}, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        # Mean loss for the same batch-size-independence reason as above.
        loss = self.criterion(y_pred, y)
        preds = torch.argmax(y_pred, dim=1)

        accuracy = self.accuracy(preds, y)

        self.log_dict({"test_loss": loss, "test_acc": accuracy}, prog_bar=True)
requirements.txt CHANGED
@@ -1,228 +1,13 @@
1
- absl-py==1.4.0
2
- adbc-driver-manager==0.5.1
3
- adbc-driver-sqlite==0.5.1
4
- aiofiles==23.1.0
5
- aiohttp==3.8.5
6
- aiosignal==1.3.1
7
- albumentations==1.3.1
8
- altair==5.0.1
9
- annotated-types==0.5.0
10
- anyio==3.7.1
11
- argon2-cffi==21.3.0
12
- argon2-cffi-bindings==21.2.0
13
- arrow==1.2.3
14
- asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
15
- async-lru==2.0.4
16
- async-timeout==4.0.2
17
- attrs==23.1.0
18
- Babel==2.12.1
19
- backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
20
- backoff==2.2.1
21
- backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
22
- beautifulsoup4==4.12.2
23
- black==23.7.0
24
- bleach==6.0.0
25
- blessed==1.20.0
26
- cachetools==5.3.1
27
- certifi==2022.12.7
28
- cffi==1.15.1
29
- charset-normalizer==2.1.1
30
- click==8.1.6
31
- cloudpickle==2.2.1
32
- cmake==3.25.0
33
- connectorx==0.3.1
34
- contourpy==1.1.0
35
- croniter==1.4.1
36
- cycler==0.11.0
37
- dateutils==0.6.12
38
- debugpy @ file:///home/builder/ci_310/debugpy_1640789504635/work
39
- decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
40
- deepdiff==6.3.1
41
- defusedxml==0.7.1
42
- deltalake==0.10.0
43
- entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
44
- exceptiongroup==1.1.2
45
- executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
46
- fastapi==0.100.1
47
- fastjsonschema==2.18.0
48
- ffmpy==0.3.1
49
- filelock==3.12.2
50
- fonttools==4.41.0
51
- fqdn==1.5.1
52
- frozenlist==1.4.0
53
- fsspec==2023.6.0
54
- google-auth==2.22.0
55
- google-auth-oauthlib==1.0.0
56
- grad-cam==1.4.8
57
- gradio==3.39.0
58
- gradio_client==0.3.0
59
- greenlet==2.0.2
60
- grpcio==1.56.2
61
- h11==0.14.0
62
- httpcore==0.17.3
63
- httpx==0.24.1
64
- huggingface-hub==0.16.4
65
- idna==3.4
66
- imageio==2.31.1
67
- inquirer==3.1.3
68
- ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work
69
- ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1685727741709/work
70
- ipywidgets==8.0.7
71
- isoduration==20.11.0
72
- itsdangerous==2.1.2
73
- jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work
74
- Jinja2==3.1.2
75
- joblib==1.3.1
76
- json5==0.9.14
77
- jsonpointer==2.4
78
- jsonschema==4.18.6
79
- jsonschema-specifications==2023.7.1
80
- jupyter-events==0.7.0
81
- jupyter-lsp==2.2.0
82
- jupyter_client==8.3.0
83
- jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1686775611663/work
84
- jupyter_server==2.7.0
85
- jupyter_server_terminals==0.4.4
86
- jupyterlab==4.0.4
87
- jupyterlab-pygments==0.2.2
88
- jupyterlab-widgets==3.0.8
89
- jupyterlab_server==2.24.0
90
- kiwisolver==1.4.4
91
- lazy_loader==0.3
92
- lightning==2.0.6
93
- lightning-cloud==0.5.37
94
- lightning-utilities==0.9.0
95
- linkify-it-py==2.0.2
96
- lit==15.0.7
97
- Markdown==3.4.3
98
- markdown-it-py==2.2.0
99
- MarkupSafe==2.1.2
100
- matplotlib==3.7.2
101
- matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
102
- mdit-py-plugins==0.3.3
103
- mdurl==0.1.2
104
- mistune==3.0.1
105
- mpmath==1.2.1
106
- multidict==6.0.4
107
- mypy-extensions==1.0.0
108
- nbclient==0.8.0
109
- nbconvert==7.7.3
110
- nbformat==5.9.2
111
- nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work
112
- netron==7.0.6
113
- networkx==3.0
114
- notebook_shim==0.2.3
115
- numpy==1.24.1
116
- nvidia-cublas-cu11==11.10.3.66
117
- nvidia-cuda-cupti-cu11==11.7.101
118
- nvidia-cuda-nvrtc-cu11==11.7.99
119
- nvidia-cuda-runtime-cu11==11.7.99
120
- nvidia-cudnn-cu11==8.5.0.96
121
- nvidia-cufft-cu11==10.9.0.58
122
- nvidia-curand-cu11==10.2.10.91
123
- nvidia-cusolver-cu11==11.4.0.1
124
- nvidia-cusparse-cu11==11.7.4.91
125
- nvidia-nccl-cu11==2.14.3
126
- nvidia-nvtx-cu11==11.7.91
127
- oauthlib==3.2.2
128
- opencv-python==4.8.0.74
129
- opencv-python-headless==4.8.0.74
130
- ordered-set==4.1.0
131
- orjson==3.9.3
132
- overrides==7.3.1
133
- packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work
134
- pandas==2.0.3
135
- pandocfilters==1.5.0
136
- parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
137
- pathspec==0.11.2
138
- pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
139
- pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
140
- Pillow==10.0.0
141
- platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1689538620473/work
142
- polars==0.18.8
143
- prometheus-client==0.17.1
144
- prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work
145
- protobuf==4.23.4
146
- psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
147
- ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
148
- pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
149
- pyarrow==12.0.1
150
- pyasn1==0.5.0
151
- pyasn1-modules==0.3.0
152
- pycparser==2.21
153
- pydantic==2.0.3
154
- pydantic_core==2.3.0
155
- pydub==0.25.1
156
- Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work
157
- PyJWT==2.8.0
158
- pyparsing==3.0.9
159
- python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
160
- python-editor==1.0.4
161
- python-json-logger==2.0.7
162
- python-multipart==0.0.6
163
- pytorch-lightning==2.0.6
164
- pytz==2023.3
165
- PyWavelets==1.4.1
166
- PyYAML==6.0.1
167
- pyzmq @ file:///croot/pyzmq_1686601365461/work
168
- qudida==0.0.4
169
- readchar==4.0.5
170
- referencing==0.30.2
171
- requests==2.28.1
172
- requests-oauthlib==1.3.1
173
- rfc3339-validator==0.1.4
174
- rfc3986-validator==0.1.1
175
- rich==13.5.0
176
- rpds-py==0.9.2
177
- rsa==4.9
178
- ruff==0.0.280
179
- scikit-image==0.21.0
180
- scikit-learn==1.3.0
181
- scipy==1.11.1
182
- semantic-version==2.10.0
183
- Send2Trash==1.8.2
184
- six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
185
- sniffio==1.3.0
186
- soupsieve==2.4.1
187
- SQLAlchemy==2.0.19
188
- stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
189
- starlette==0.27.0
190
- starsessions==1.3.0
191
- sympy==1.11.1
192
- tensorboard==2.13.0
193
- tensorboard-data-server==0.7.1
194
- terminado==0.17.1
195
- threadpoolctl==3.2.0
196
- tifffile==2023.7.18
197
- tinycss2==1.2.1
198
- toml==0.10.2
199
- tomli==2.0.1
200
- toolz==0.12.0
201
- torch==2.0.1+cu118
202
- torch-lr-finder==0.2.1
203
- torch-tb-profiler==0.4.1
204
- torchaudio==2.0.2+cu118
205
- torchinfo==1.8.0
206
- torchmetrics==1.0.1
207
- torchvision==0.15.2+cu118
208
- tornado==6.3.2
209
- tqdm==4.65.0
210
- traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work
211
- triton==2.0.0
212
- ttach==0.0.3
213
- typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1688315532570/work
214
- tzdata==2023.3
215
- uc-micro-py==1.0.2
216
- uri-template==1.3.0
217
- urllib3==1.26.13
218
- uvicorn==0.23.1
219
- wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work
220
- webcolors==1.13
221
- webencodings==0.5.1
222
- websocket-client==1.6.1
223
- websockets==11.0.3
224
- Werkzeug==2.3.6
225
- widgetsnbextension==4.0.8
226
- xlsx2csv==0.8.1
227
- XlsxWriter==3.1.2
228
- yarl==1.9.2
 
1
+ numpy
2
+ pandas
3
+ matplotlib
4
+ torch
5
+ torchvision
6
+ lightning
7
+ gradio
8
+ grad-cam
9
+ torchinfo
10
+ torch_lr_finder
11
+ pydantic
12
+ tqdm
13
+ albumentations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/common.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import matplotlib.pyplot as plt
4
+
5
+ import torch
6
+ import torchvision
7
+ from torchinfo import summary
8
+ from torch_lr_finder import LRFinder
9
+
10
+
11
def find_lr(model, optimizer, criterion, device, trainloader, numiter, startlr, endlr):
    """Run an exponential LR range test and plot loss vs learning rate.

    Model and optimizer state are restored afterwards via `reset()`.
    """
    finder = LRFinder(
        model=model, optimizer=optimizer, criterion=criterion, device=device
    )
    finder.range_test(
        train_loader=trainloader,
        start_lr=startlr,
        end_lr=endlr,
        num_iter=numiter,
        step_mode="exp",
    )
    finder.plot()
    # Undo the range test's weight/LR mutations.
    finder.reset()
27
+
28
+
29
def one_cycle_lr(optimizer, maxlr, steps, epochs):
    """Build a linear OneCycleLR schedule peaking at ``maxlr``.

    Warm-up spans 5 epochs (pct_start = 5 / epochs); the starting and
    final learning rates sit two orders of magnitude below the peak
    (div_factor = final_div_factor = 100).
    """
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=maxlr,
        steps_per_epoch=steps,
        epochs=epochs,
        pct_start=5 / epochs,
        div_factor=100,
        three_phase=False,
        final_div_factor=100,
        anneal_strategy="linear",
    )
42
+
43
+
44
def show_random_images_for_each_class(train_data, num_images_per_class=16):
    """Plot a grid of randomly sampled training images for every class."""
    for class_idx, class_name in enumerate(train_data.classes):
        matching = [i for i, t in enumerate(train_data.targets) if t == class_idx]
        chosen = random.sample(matching, k=num_images_per_class)
        # Dataset arrays are NHWC; make_grid wants NCHW.
        show_img_grid(np.transpose(train_data.data[chosen], axes=(0, 3, 1, 2)))
        plt.title(class_name)
52
+
53
+
54
def show_img_grid(data):
    """Display a batch of images (torch tensor or numpy array) as one grid."""
    try:
        grid_img = torchvision.utils.make_grid(data.cpu().detach())
    except AttributeError:
        # Plain numpy arrays have no .cpu()/.detach(); convert first.
        # (The original used a bare `except:`, which also swallowed
        # unrelated errors such as KeyboardInterrupt.)
        data = torch.from_numpy(data)
        grid_img = torchvision.utils.make_grid(data)

    plt.figure(figsize=(10, 10))
    # make_grid returns CHW; imshow wants HWC.
    plt.imshow(grid_img.permute(1, 2, 0))
63
+
64
+
65
def show_random_images(data_loader):
    """Show the first batch from ``data_loader`` as an image grid."""
    images, _labels = next(iter(data_loader))
    show_img_grid(images)
68
+
69
+
70
def show_model_summary(model, batch_size):
    """Print a torchinfo summary for a CIFAR-sized (3x32x32) input batch."""
    summary(
        model=model,
        input_size=(batch_size, 3, 32, 32),
        col_names=["input_size", "output_size", "num_params", "kernel_size"],
        verbose=1,
    )
77
+
78
+
79
def lossacc_plots(results):
    """Plot train/validation loss and accuracy curves against epochs."""
    panels = [
        ("trainloss", "testloss", "Loss", ["Train Loss", "Validation Loss"]),
        ("trainacc", "testacc", "Accuracy", ["Train Acc", "Validation Acc"]),
    ]
    for train_key, test_key, ylabel, legend in panels:
        plt.plot(results["epoch"], results[train_key])
        plt.plot(results["epoch"], results[test_key])
        plt.legend(legend)
        plt.xlabel("Epochs")
        plt.ylabel(ylabel)
        plt.title(f"{ylabel} vs Epochs")
        plt.show()
95
+
96
+
97
def lr_plots(results, length):
    """Plot the recorded learning-rate trace over ``length`` epochs."""
    plt.plot(range(length), results["lr"])
    plt.xlabel("Epochs")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate vs Epochs")
    plt.show()
103
+
104
+
105
def get_misclassified(model, testloader, device, mis_count=10):
    """Collect up to ``mis_count`` misclassified samples from ``testloader``.

    :param model: network producing per-class logits
    :param testloader: iterable of (data, target) batches
    :param device: device to move batches to
    :param mis_count: maximum number of samples to collect
    :return: parallel lists (images, true targets, predicted labels)
    """
    misimgs, mistgts, mispreds = [], [], []
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            # flatten() keeps the result iterable even when exactly one
            # sample in the batch is misclassified — the original's
            # squeeze() produced a 0-d tensor there and the loop raised.
            misclassified = torch.argwhere(pred.squeeze(1) != target).flatten()
            for idx in misclassified:
                misimgs.append(data[idx])
                mistgts.append(target[idx])
                mispreds.append(pred[idx].squeeze())
                if len(misimgs) >= mis_count:
                    # Stop scanning batches once we have enough; the
                    # original's `break` only exited the inner loop.
                    return misimgs, mistgts, mispreds
    return misimgs, mistgts, mispreds
120
+
121
+
122
+ # def plot_misclassified(misimgs, mistgts, mispreds, classes):
123
+ # fig, axes = plt.subplots(len(misimgs) // 2, 2)
124
+ # fig.tight_layout()
125
+ # for ax, img, tgt, pred in zip(axes.ravel(), misimgs, mistgts, mispreds):
126
+ # ax.imshow((img / img.max()).permute(1, 2, 0).cpu())
127
+ # ax.set_title(f"{classes[tgt]} | {classes[pred]}")
128
+ # ax.grid(False)
129
+ # ax.set_axis_off()
130
+ # plt.show()
131
+
132
def get_misclassified_data(model, device, test_loader, count):
    """
    Run the model on the test set and return up to ``count`` misclassified samples.
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set
    :param count: maximum number of (image, label, pred) tuples to return
    """
    # Prepare the model for evaluation i.e. drop the dropout layer
    model.eval()

    # List to store misclassified (image, label, pred) tuples
    misclassified_data = []

    with torch.no_grad():
        for data, target in test_loader:
            # Migrate the batch to the device
            data, target = data.to(device), target.to(device)

            # Score one image at a time so each stored image keeps its
            # batch dimension.
            for image, label in zip(data, target):
                image = image.unsqueeze(0)
                output = model(image)

                # Convert the logits to a predicted class index
                pred = output.argmax(dim=1, keepdim=True)

                if pred != label:
                    misclassified_data.append((image, label, pred))
                    if len(misclassified_data) >= count:
                        # Return immediately — the original's `break` only
                        # exited the inner loop and kept scanning the
                        # remaining batches before slicing the result.
                        return misclassified_data

    return misclassified_data
172
+
173
def plot_misclassified(data, classes, size=(10, 10), rows=2, cols=5, inv_normalize=None):
    """Plot (image, label, pred) triples in a rows x cols grid of subplots."""
    plt.figure(figsize=size)
    for idx, (image, label, pred) in enumerate(data):
        plt.subplot(rows, cols, idx + 1)
        img = image.squeeze().to('cpu')
        if inv_normalize is not None:
            # Undo normalisation so colours are displayable.
            img = inv_normalize(img)
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.title(f"Label: {classes[label.item()]} \n Prediction: {classes[pred.item()]}")
        plt.xticks([])
        plt.yticks([])
185
+
utils/config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import toml
from pydantic import BaseModel

# Location of the project configuration file, read once at import time.
TOML_PATH = "config.toml"


class Data(BaseModel):
    # DataLoader settings.
    batch_size: int = 512
    shuffle: bool = True
    num_workers: int = 4


class LRFinder(BaseModel):
    # LR range-test settings (see utils.common.find_lr).
    numiter: int = 600
    endlr: float = 10
    startlr: float = 1e-2


class Training(BaseModel):
    # Optimisation settings; `optimizer`/`criterion` are name strings
    # interpreted by the training setup, not callables.
    epochs: int = 20
    optimizer: str = "adam"
    criterion: str = "crossentropy"
    lr: float = 0.003
    weight_decay: float = 1e-4
    lrfinder: LRFinder


class Config(BaseModel):
    data: Data
    training: Training


# Parse and validate config.toml into a module-level Config singleton.
# NOTE(review): this runs at import time and raises if the file is missing.
with open(TOML_PATH) as f:
    toml_config = toml.load(f)

config = Config(**toml_config)
utils/data.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import lightning as L
3
+ from torch.utils.data import DataLoader
4
+ from utils.transforms import train_transform, test_transform
5
+
6
+
7
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 variant whose transform is an albumentations pipeline.

    Albumentations callables take an ``image=`` keyword and return a dict,
    unlike torchvision transforms, hence the overridden ``__getitem__``.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
18
+
19
+
20
class CIFARDataModule(L.LightningDataModule):
    """LightningDataModule serving CIFAR-10 train/val/test DataLoaders.

    :param data_dir: dataset root directory
    :param batch_size: batch size for every loader
    :param shuffle: whether to shuffle the *training* loader
    :param num_workers: worker processes per loader
    """

    def __init__(
        self, data_dir="data", batch_size=512, shuffle=True, num_workers=4
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        # Download is handled by the dataset itself in setup().
        pass

    def setup(self, stage=None):
        self.train_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=True, transform=train_transform
        )

        # val and test both use the held-out split with test-time transforms.
        self.val_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

        self.test_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        # Evaluation data must not be shuffled so metrics and sample order
        # are reproducible (the original passed self.shuffle here).
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
utils/gradcam.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pytorch_grad_cam import GradCAM
3
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
4
+ from pytorch_grad_cam.utils.image import show_cam_on_image
5
+
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
def generate_gradcam(model, target_layers, images, labels, rgb_imgs):
    """Return GradCAM overlays for each (image, label, rgb_image) triple."""
    overlays = []
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

    for image, label, np_image in zip(images, labels, rgb_imgs):
        targets = [ClassifierOutputTarget(label.item())]

        # aug_smooth averages maps over augmented copies for a cleaner CAM.
        grayscale_cam = cam(
            input_tensor=image.unsqueeze(0), targets=targets, aug_smooth=True
        )
        grayscale_cam = grayscale_cam[0, :]  # one image per batch here

        overlay = show_cam_on_image(
            np_image / np_image.max(), grayscale_cam, use_rgb=True
        )
        overlays.append(overlay)
    return overlays
28
+
29
+
30
def visualize_gradcam(misimgs, mistgts, mispreds, classes):
    """Show GradCAM overlays in a two-column grid titled "target | prediction"."""
    fig, axes = plt.subplots(len(misimgs) // 2, 2)
    fig.tight_layout()
    for axis, overlay, target, prediction in zip(axes.ravel(), misimgs, mistgts, mispreds):
        axis.imshow(overlay)
        axis.set_title(f"{classes[target]} | {classes[prediction]}")
        axis.grid(False)
        axis.set_axis_off()
    plt.show()
39
+
40
def plot_gradcam(model, data, classes, target_layers, number_of_samples, inv_normalize=None, targets=None, transparency = 0.60, figsize=(10,10), rows=2, cols=5):
    """Plot GradCAM overlays for `number_of_samples` (image, label, pred) triples."""
    plt.figure(figsize=figsize)
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

    for i in range(number_of_samples):
        plt.subplot(rows, cols, i + 1)
        input_tensor = data[i][0]

        # Activation map for this sample (batch of one).
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Recover a displayable RGB image from the normalised tensor.
        img = input_tensor.squeeze(0).to('cpu')
        if inv_normalize is not None:
            img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0)).numpy()

        # Blend the activation map onto the original image.
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
        plt.imshow(visualization)
        plt.title(f"Label: {classes[data[i][1].item()]} \n Prediction: {classes[data[i][2].item()]}")
        plt.xticks([])
        plt.yticks([])
utils/training.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
def train(
    model,
    device,
    train_loader,
    optimizer,
    criterion,
    scheduler,
    L1=False,
    l1_lambda=0.01,
):
    """Train ``model`` for one epoch, stepping the LR scheduler per batch.

    :param criterion: loss function applied to (logits, target)
    :param scheduler: per-step scheduler (e.g. OneCycleLR)
    :param L1: when True, add ``l1_lambda * sum(|w|)`` regularisation
    :return: (per-batch losses, per-batch running accuracy %, per-batch LRs)
    """
    model.train()
    pbar = tqdm(train_loader)

    train_losses = []
    train_acc = []
    lrs = []

    correct = 0
    processed = 0

    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)

        loss = criterion(y_pred, target)
        if L1:
            # L1 penalty over every parameter tensor.
            l1_loss = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_loss

        train_losses.append(loss.item())

        # Backpropagation; scheduler steps every batch (OneCycle-style).
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Running accuracy for the progress bar.
        pred = y_pred.argmax(
            dim=1, keepdim=True
        )  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(
            desc=f"Loss={loss.item():0.2f} Accuracy={100*correct/processed:0.2f}"
        )
        train_acc.append(100 * correct / processed)
        lrs.append(scheduler.get_last_lr())

    return train_losses, train_acc, lrs
64
+
65
+
66
+ def test(model, device, criterion, test_loader):
67
+ model.eval()
68
+ test_loss = 0
69
+ correct = 0
70
+ with torch.no_grad():
71
+ for data, target in test_loader:
72
+ data, target = data.to(device), target.to(device)
73
+ output = model(data)
74
+ test_loss += F.cross_entropy(output, target, reduction="sum").item()
75
+ pred = output.argmax(dim=1, keepdim=True)
76
+ correct += pred.eq(target.view_as(pred)).sum().item()
77
+
78
+ test_loss /= len(test_loader.dataset)
79
+
80
+ print(
81
+ "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
82
+ test_loss,
83
+ correct,
84
+ len(test_loader.dataset),
85
+ 100.0 * correct / len(test_loader.dataset),
86
+ )
87
+ )
88
+ test_acc = 100.0 * correct / len(test_loader.dataset)
89
+
90
+ return test_loss, test_acc
utils/transforms.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import albumentations as A
from albumentations.pytorch import ToTensorV2

# CIFAR-10 per-channel mean/std used for normalisation (and to fill the
# cutout hole with the dataset's mean colour).
_MEAN = (0.49139968, 0.48215827, 0.44653124)
_STD = (0.24703233, 0.24348505, 0.26158768)

# Train-time augmentation: pad + random 32x32 crop, horizontal flip and a
# single 8x8 cutout, then normalise and convert to a CHW tensor.
train_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=40, min_width=40, always_apply=True),
        A.RandomCrop(height=32, width=32, always_apply=True),
        A.HorizontalFlip(),
        A.CoarseDropout(
            min_holes=1,
            max_holes=1,
            min_height=8,
            min_width=8,
            max_height=8,
            max_width=8,
            fill_value=[c * 255 for c in _MEAN],  # type: ignore
            p=0.5,
        ),
        A.Normalize(_MEAN, _STD),
        ToTensorV2(),
    ]
)

# Test-time pipeline: normalisation and tensor conversion only.
test_transform = A.Compose(
    [
        A.Normalize(_MEAN, _STD),
        ToTensorV2(),
    ]
)