Spaces:
Sleeping
Sleeping
| from gradio_utils import * | |
def process_images_gradcam(show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity):
    """Render GradCAM overlays for misclassified test images.

    Args:
        show_gradcam: Checkbox state; when falsy nothing is computed.
        gradcam_count: Number of misclassified samples to visualise, forwarded
            as ``number_of_samples``.
        gradcam_layer: Which stage of ``modelfin.model`` to target: "1".."4"
            selects ``layer1[-1]``..``layer4[-1]``. Ints are accepted as well,
            since a Radio default may arrive as an int.
        gradcam_opacity: Overlay transparency forwarded to
            ``display_gradcam_output``.

    Returns:
        The result of ``display_gradcam_output``, or ``None`` when the checkbox
        is unticked (previously this path raised UnboundLocalError).

    Raises:
        ValueError: If ``gradcam_layer`` is not one of 1-4.
    """
    if not show_gradcam:
        return None

    # Validate before the (potentially slow) misclassification scan.
    layer = str(gradcam_layer).strip()
    if layer not in {"1", "2", "3", "4"}:
        raise ValueError(f"Unsupported GradCAM layer: {gradcam_layer!r}")

    # Maps normalised tensors back to displayable pixel values.
    # NOTE(review): these constants invert Normalize(mean≈(0.4914, 0.4822, 0.4465),
    # std≈(0.247, 0.243, 0.261)), whereas predict_classes normalises with
    # std=(0.2023, 0.1994, 0.2010) — confirm which matches the training pipeline.
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    # Last module of the chosen stage, e.g. modelfin.model.layer4[-1].
    target_layer = getattr(modelfin.model, f"layer{layer}")[-1]
    return display_gradcam_output(
        misclassified_data, classes, inv_normalize, modelfin,
        target_layers=[target_layer], targets=None,
        number_of_samples=int(gradcam_count), transparency=gradcam_opacity)
def process_images_misclass(show_misclassify, misclassify_count):
    """Render a grid of misclassified test images.

    Args:
        show_misclassify: Checkbox state; when falsy nothing is computed.
        misclassify_count: How many samples to show. Gradio ``Number`` values
            may arrive as float or ``None``; the value is coerced to ``int``.

    Returns:
        The result of ``display_cifar_misclassified_data``, or ``None`` when the
        checkbox is unticked or no count was given (previously ``image`` was
        left unbound and the function raised UnboundLocalError).
    """
    if not show_misclassify or not misclassify_count:
        return None

    # These were previously referenced here but only defined as locals of
    # process_images_gradcam; define them in this scope with the same values.
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    return display_cifar_misclassified_data(
        misclassified_data, classes, inv_normalize,
        number_of_samples=int(misclassify_count))
def predict_classes(upload_image, top_classes):
    """Classify an uploaded image with ``modelfin`` and list the top-k labels.

    Args:
        upload_image: PIL image from the ``gr.Image(type='pil')`` component, or
            ``None`` when nothing has been uploaded.
        top_classes: How many predictions to list. Gradio ``Number`` values may
            arrive as float or ``None`` while ``torch.topk`` requires an int, so
            the value is coerced to ``int`` (default 5) and clamped to 1..10.

    Returns:
        One ``"<class>: <prob>%"`` line per prediction, most probable first,
        each terminated by a newline; an empty string when no image is given.
    """
    # CIFAR-10 classes, in the model's output-index order (assumed).
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    if upload_image is None:
        return ""

    k = 5 if top_classes is None else int(top_classes)
    k = max(1, min(k, len(classes)))  # topk raises if k exceeds the class count

    transform = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize to 32x32 pixels
        transforms.ToTensor(),  # Convert image to tensor
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # CIFAR-10 normalization
                             std=[0.2023, 0.1994, 0.2010])])

    # PNG uploads can be RGBA or palette images; Normalize expects 3 channels.
    image = transform(upload_image.convert("RGB")).unsqueeze(0)
    device = next(modelfin.parameters()).device
    image = image.to(device)

    modelfin.eval()
    with torch.no_grad():
        output = modelfin(image)

    probabilities = torch.nn.functional.softmax(output, dim=1)
    top_prob, top_catid = torch.topk(probabilities, k)

    return "".join(
        f"{classes[idx]}: {prob * 100:.2f}%\n"
        for prob, idx in zip(top_prob[0].tolist(), top_catid[0].tolist()))
# Gradio UI: a GradCAM panel, a misclassified-samples panel and a
# single-image classifier, each wired to its handler function.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # --- GradCAM panel ---
            show_gradcam = gr.Checkbox(label="Show GradCAM Images?")
            gradcam_count = gr.Number(label="How many GradCAM images?", value=1, precision=0)
            # Default must be one of the (string) choices; the handler compares strings.
            gradcam_layer = gr.Radio(choices=["1", "2", "3", "4"], label="Choose a layer", value="4")
            gradcam_opacity = gr.Slider(minimum=0, maximum=1, label="Opacity of overlay", value=0.5)
            submit_button = gr.Button("GradCam")
            outputs = gr.Image(label="Output")

            # --- Misclassified-samples panel ---
            show_misclassify = gr.Checkbox(label="Show misclassified images?")
            # precision=0 makes Gradio deliver an int count rather than a float.
            misclassify_count = gr.Number(label="How many misclassified images?", value=1, precision=0)
            submit_button_misclass = gr.Button("Misclassified")
            outputs_misclass = gr.Image(label="Output")

            # --- Single-image classifier ---
            upload_image = gr.Image(label="Upload your image", interactive=True, type='pil')
            # torch.topk needs an int k in 1..10 (CIFAR-10 has 10 classes).
            top_classes = gr.Number(label="How many top classes would you like to see?",
                                    value=5, minimum=1, maximum=10, precision=0)
            upload_btn = gr.Button("Classify your image")
            show_classes = gr.Textbox(label="Your top classes", interactive=False)

    submit_button_misclass.click(
        process_images_misclass,
        inputs=[show_misclassify, misclassify_count],
        outputs=outputs_misclass
    )
    submit_button.click(
        process_images_gradcam,
        inputs=[show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity],
        outputs=outputs
    )
    upload_btn.click(
        predict_classes,
        inputs=[upload_image, top_classes],
        outputs=show_classes
    )

demo.launch()