Spaces:
Sleeping
Sleeping
Commit
·
222f1fd
1
Parent(s):
c79dd40
Update app.py
Browse files
app.py
CHANGED
|
@@ -20,12 +20,15 @@ config = {
|
|
| 20 |
'max_lr': 0.1,
|
| 21 |
'max_lr_epoch': 5,
|
| 22 |
'dropout' : 0.01,
|
|
|
|
|
|
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
train_transforms = get_train_transforms()
|
| 26 |
test_transforms = get_test_transforms()
|
| 27 |
model = CustomResNet(config, config['dropout'], train_transforms, test_transforms)
|
| 28 |
-
model.load_state_dict(torch.load("
|
| 29 |
model.setup(stage="test")
|
| 30 |
|
| 31 |
inv_normalize = transforms.Normalize(
|
|
@@ -98,23 +101,32 @@ examples = [[os.path.join(images_folder, "plane.jpg"), 0.5, -1,10],
|
|
| 98 |
[os.path.join(images_folder, "frog.jpg"), 0.5, -1,2],
|
| 99 |
[os.path.join(images_folder, "horse.jpg"), 0.5, -1,10],
|
| 100 |
[os.path.join(images_folder, "ship.jpg"), 0.5, -1,10],
|
| 101 |
-
[os.path.join(images_folder, "truck.
|
| 102 |
-
|
| 103 |
|
|
|
|
| 104 |
input_interface = gr.Interface(
|
| 105 |
inference,
|
| 106 |
-
inputs
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
mislclassified_description = "Misclassified Image for Custom Resnet"
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
misclassified_interface = gr.Interface(show_misclassified_images_wrap,
|
| 119 |
inputs=[gr.Number(value=10, label="Misclassified Inputs",info = "Set the Number of Misclassifed Outputs to be Shown"),
|
| 120 |
gr.Radio(["Yes", "No"], value="No" , label="Enable GradCAM",info = "Do you want to see GradCAM"),
|
|
@@ -123,7 +135,6 @@ misclassified_interface = gr.Interface(show_misclassified_images_wrap,
|
|
| 123 |
outputs=gr.Plot(), description=mislclassified_description)
|
| 124 |
|
| 125 |
demo = gr.TabbedInterface([input_interface, misclassified_interface], tab_names=["Top Classes and Prediction", "Misclassified Images"],
|
| 126 |
-
title=
|
| 127 |
-
|
| 128 |
demo.launch()
|
| 129 |
|
|
|
|
| 20 |
'max_lr': 0.1,
|
| 21 |
'max_lr_epoch': 5,
|
| 22 |
'dropout' : 0.01,
|
| 23 |
+
'LEARNING_RATE' : 1e-5,
|
| 24 |
+
'WEIGHT_DECAY' : 1e-4,
|
| 25 |
+
'NUM_EPOCHS' : 100
|
| 26 |
}
|
| 27 |
|
| 28 |
train_transforms = get_train_transforms()
|
| 29 |
test_transforms = get_test_transforms()
|
| 30 |
model = CustomResNet(config, config['dropout'], train_transforms, test_transforms)
|
| 31 |
+
model.load_state_dict(torch.load("resnet_model_v7.pth", map_location=torch.device('cpu')), strict=False)
|
| 32 |
model.setup(stage="test")
|
| 33 |
|
| 34 |
inv_normalize = transforms.Normalize(
|
|
|
|
| 101 |
[os.path.join(images_folder, "frog.jpg"), 0.5, -1,2],
|
| 102 |
[os.path.join(images_folder, "horse.jpg"), 0.5, -1,10],
|
| 103 |
[os.path.join(images_folder, "ship.jpg"), 0.5, -1,10],
|
| 104 |
+
[os.path.join(images_folder, "truck.jpeg"), 0.5, -1,10]]
|
|
|
|
| 105 |
|
| 106 |
+
# Create the input interface with the modified template
|
| 107 |
input_interface = gr.Interface(
|
| 108 |
inference,
|
| 109 |
+
inputs=[
|
| 110 |
+
gr.Image(shape=(32, 32), label="Input Image"),
|
| 111 |
+
gr.Slider(0, 1, value=0.5, label="Transparency", info="Set the Opacity of CAM"),
|
| 112 |
+
gr.Slider(-2, -1, value=-2, step=1, label="Network Layer", info="GradCAM Network Layer"),
|
| 113 |
+
gr.Slider(1, 10, step=1, value=10, label="Top Classes", info="How many top classes do you want to view")
|
| 114 |
+
],
|
| 115 |
+
outputs=[
|
| 116 |
+
gr.Label(num_top_classes=10),
|
| 117 |
+
gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300)
|
| 118 |
+
],
|
| 119 |
+
description=description,
|
| 120 |
+
examples=[[f'examples/{k}.jpg'] for k in classes_for_categorize.values()],)
|
| 121 |
|
| 122 |
mislclassified_description = "Misclassified Image for Custom Resnet"
|
| 123 |
+
icon_html = '<i class="fas fa-chart-bar"></i>'
|
| 124 |
+
title_with_icon = f"""
|
| 125 |
+
<div style="background-color: #f1f4f0; padding: 10px; display: flex; align-items: center;">
|
| 126 |
+
{icon_html} <span style="margin-left: 10px;">Custom Resnet on CIFAR10 using PyTorch Lightning and GradCAM</span>
|
| 127 |
+
</div>
|
| 128 |
+
"""
|
| 129 |
+
# Create a separate interface for the "Misclassified Images"
|
| 130 |
misclassified_interface = gr.Interface(show_misclassified_images_wrap,
|
| 131 |
inputs=[gr.Number(value=10, label="Misclassified Inputs",info = "Set the Number of Misclassifed Outputs to be Shown"),
|
| 132 |
gr.Radio(["Yes", "No"], value="No" , label="Enable GradCAM",info = "Do you want to see GradCAM"),
|
|
|
|
| 135 |
outputs=gr.Plot(), description=mislclassified_description)
|
| 136 |
|
| 137 |
demo = gr.TabbedInterface([input_interface, misclassified_interface], tab_names=["Top Classes and Prediction", "Misclassified Images"],
|
| 138 |
+
title=title_with_icon,)
|
|
|
|
| 139 |
demo.launch()
|
| 140 |
|