PrarthanaTS commited on
Commit
222f1fd
·
1 Parent(s): c79dd40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -20,12 +20,15 @@ config = {
20
  'max_lr': 0.1,
21
  'max_lr_epoch': 5,
22
  'dropout' : 0.01,
 
 
 
23
  }
24
 
25
  train_transforms = get_train_transforms()
26
  test_transforms = get_test_transforms()
27
  model = CustomResNet(config, config['dropout'], train_transforms, test_transforms)
28
- model.load_state_dict(torch.load("resnet_model_v2.pth", map_location=torch.device('cpu')), strict=False)
29
  model.setup(stage="test")
30
 
31
  inv_normalize = transforms.Normalize(
@@ -98,23 +101,32 @@ examples = [[os.path.join(images_folder, "plane.jpg"), 0.5, -1,10],
98
  [os.path.join(images_folder, "frog.jpg"), 0.5, -1,2],
99
  [os.path.join(images_folder, "horse.jpg"), 0.5, -1,10],
100
  [os.path.join(images_folder, "ship.jpg"), 0.5, -1,10],
101
- [os.path.join(images_folder, "truck.jpg"), 0.5, -1,10]]
102
-
103
 
 
104
  input_interface = gr.Interface(
105
  inference,
106
- inputs = [gr.Image(shape=(32, 32), label="Input Image"),
107
- gr.Slider(0, 1, value = 0.5, label="Transparency",info = "Set the Opacity of CAM"),
108
- gr.Slider(-2, -1, value = -2, step=1, label="Network Layer", info = "GradCAM Network Layer"),
109
- gr.Slider(1, 10, step=1, value=10 , label="Top Classes",info = "How many top classes do you want to view")],
110
- outputs = [gr.Label(num_top_classes=10), gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300)],
111
- description = description,
112
- examples=[[f'examples/{k}.jpg'] for k in classes_for_categorize.values()],
113
- )
 
 
 
 
114
 
115
  mislclassified_description = "Misclassified Image for Custom Resnet"
116
-
117
- # Create a separate interface for the "Misclassified Images"
 
 
 
 
 
118
  misclassified_interface = gr.Interface(show_misclassified_images_wrap,
119
  inputs=[gr.Number(value=10, label="Misclassified Inputs",info = "Set the Number of Misclassifed Outputs to be Shown"),
120
  gr.Radio(["Yes", "No"], value="No" , label="Enable GradCAM",info = "Do you want to see GradCAM"),
@@ -123,7 +135,6 @@ misclassified_interface = gr.Interface(show_misclassified_images_wrap,
123
  outputs=gr.Plot(), description=mislclassified_description)
124
 
125
  demo = gr.TabbedInterface([input_interface, misclassified_interface], tab_names=["Top Classes and Prediction", "Misclassified Images"],
126
- title="Custom Resnet on CIFAR10 using pytorch Lightening and GradCAM")
127
-
128
  demo.launch()
129
 
 
20
  'max_lr': 0.1,
21
  'max_lr_epoch': 5,
22
  'dropout' : 0.01,
23
+ 'LEARNING_RATE' : 1e-5,
24
+ 'WEIGHT_DECAY' : 1e-4,
25
+ 'NUM_EPOCHS' : 100
26
  }
27
 
28
  train_transforms = get_train_transforms()
29
  test_transforms = get_test_transforms()
30
  model = CustomResNet(config, config['dropout'], train_transforms, test_transforms)
31
+ model.load_state_dict(torch.load("resnet_model_v7.pth", map_location=torch.device('cpu')), strict=False)
32
  model.setup(stage="test")
33
 
34
  inv_normalize = transforms.Normalize(
 
101
  [os.path.join(images_folder, "frog.jpg"), 0.5, -1,2],
102
  [os.path.join(images_folder, "horse.jpg"), 0.5, -1,10],
103
  [os.path.join(images_folder, "ship.jpg"), 0.5, -1,10],
104
+ [os.path.join(images_folder, "truck.jpeg"), 0.5, -1,10]]
 
105
 
106
+ # Create the input interface with the modified template
107
  input_interface = gr.Interface(
108
  inference,
109
+ inputs=[
110
+ gr.Image(shape=(32, 32), label="Input Image"),
111
+ gr.Slider(0, 1, value=0.5, label="Transparency", info="Set the Opacity of CAM"),
112
+ gr.Slider(-2, -1, value=-2, step=1, label="Network Layer", info="GradCAM Network Layer"),
113
+ gr.Slider(1, 10, step=1, value=10, label="Top Classes", info="How many top classes do you want to view")
114
+ ],
115
+ outputs=[
116
+ gr.Label(num_top_classes=10),
117
+ gr.Image(shape=(32, 32), label="Model Prediction").style(width=300, height=300)
118
+ ],
119
+ description=description,
120
+ examples=[[f'examples/{k}.jpg'] for k in classes_for_categorize.values()],)
121
 
122
  mislclassified_description = "Misclassified Image for Custom Resnet"
123
+ icon_html = '<i class="fas fa-chart-bar"></i>'
124
+ title_with_icon = f"""
125
+ <div style="background-color: #f1f4f0; padding: 10px; display: flex; align-items: center;">
126
+ {icon_html} <span style="margin-left: 10px;">Custom Resnet on CIFAR10 using PyTorch Lightning and GradCAM</span>
127
+ </div>
128
+ """
129
+ # Create a separate interface for the "Misclassified Images"
130
  misclassified_interface = gr.Interface(show_misclassified_images_wrap,
131
  inputs=[gr.Number(value=10, label="Misclassified Inputs",info = "Set the Number of Misclassifed Outputs to be Shown"),
132
  gr.Radio(["Yes", "No"], value="No" , label="Enable GradCAM",info = "Do you want to see GradCAM"),
 
135
  outputs=gr.Plot(), description=mislclassified_description)
136
 
137
  demo = gr.TabbedInterface([input_interface, misclassified_interface], tab_names=["Top Classes and Prediction", "Misclassified Images"],
138
+ title=title_with_icon,)
 
139
  demo.launch()
140