Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from models.custom_resnet import CustomResNet | |
# Run Inference
def run_inference(input_image, gradcam=False, gradcam_layer=3, gradcam_num=3,
                  gradcam_opacity=0.5, misclassified_num=5, top_classes=10):
    """Run inference on a CIFAR-10 image.

    NOTE: the CustomResNet model is not wired up yet, so every output is a
    placeholder: fixed class scores and random 32x32x3 float arrays.

    Args:
        input_image: The image to be classified (currently unused).
        gradcam: Whether to return GradCAM images.
        gradcam_layer: The layer from which to generate GradCAM images. The
            Gradio dropdown passes it as a string (currently unused).
        gradcam_num: The number of GradCAM images to return. Coerced with
            ``int()`` because Gradio sliders may deliver floats.
        gradcam_opacity: The opacity of the GradCAM images (currently unused).
        misclassified_num: The number of misclassified images to return;
            values <= 0 return none.
        top_classes: The maximum number of top classes to return.

    Returns:
        A tuple ``(class_scores, gradcam_images, misclassified_images)``:
        a ``{label: confidence}`` dict holding at most ``top_classes``
        entries in descending confidence order, a list of GradCAM images
        (empty when ``gradcam`` is false), and a list of ``(image, caption)``
        pairs for the misclassified-images gallery.
    """
    # TODO: replace the placeholders below with the real pipeline:
    #   1. build CustomResNet and load 'weight/epoch=2-step=294.ckpt';
    #   2. classify input_image and keep the top `top_classes` scores;
    #   3. render GradCAM overlays for `gradcam_layer` at `gradcam_opacity`;
    #   4. collect misclassified samples from the test set.

    # Placeholder scores. Keep only the highest `top_classes` entries so the
    # slider is honoured instead of being shadowed and ignored.
    placeholder_scores = {"dog": 0.90, "cat": 0.10}
    ranked = sorted(placeholder_scores.items(), key=lambda kv: kv[1], reverse=True)
    class_scores = dict(ranked[:max(int(top_classes), 0)])

    # Placeholder GradCAM images (random 32x32 RGB arrays), only when requested.
    gradcam_images = (
        [np.random.rand(32, 32, 3) for _ in range(int(gradcam_num))]
        if gradcam
        else []
    )

    # Placeholder misclassified images; range() of a non-positive count is empty.
    misclassified_images = [
        (np.random.rand(32, 32, 3), "caption")
        for _ in range(int(misclassified_num))
    ]

    return class_scores, gradcam_images, misclassified_images
# Gradio Interface
# Upper bound shared by the "top classes" slider and the Label output, so the
# Label never hides classes the user asked to see.
MAX_TOP_CLASSES = 10

input_image = gr.Image(shape=(32, 32), label="Upload Image", info="Upload a CIFAR-10 image to be classified.")
gradcam = gr.Checkbox(label="View GradCAM images?", info="Whether to show GradCAM images.")
gradcam_layer = gr.Dropdown(["1", "2", "3"], value="2", label="GradCAM Layer", info="The layer from which to generate GradCAM images.")
# Slider `value`s mirror run_inference's keyword defaults; without `value` a
# slider starts at its minimum (e.g. opacity 0.0, i.e. invisible overlays).
gradcam_num = gr.Slider(label="Number of GradCAM images", minimum=1, maximum=10, step=1, value=3, info="The number of GradCAM images to show.")
gradcam_opacity = gr.Slider(label="GradCAM opacity", minimum=0.0, maximum=1.0, step=0.01, value=0.5, info="The opacity of the GradCAM images.")
misclassified_num = gr.Slider(label="Number of Misclassified images", minimum=0, maximum=10, step=1, value=5, info="The number of misclassified images to show.")
top_classes = gr.Slider(label="Number of top classes to show", minimum=1, maximum=MAX_TOP_CLASSES, step=1, value=MAX_TOP_CLASSES, info="The number of top classes to show.")

output_label = gr.Label(num_top_classes=MAX_TOP_CLASSES, label="Top Classes")
# "contain" is a valid CSS object-fit value; the previous "fit" is not.
output_gradcam_gallery = gr.Gallery(object_fit="contain", columns=4, height=280, label="GradCAM Gallery")
output_misclassified_gallery = gr.Gallery(object_fit="contain", columns=4, height=280, label="Misclassified Images")

interface = gr.Interface(
    fn=run_inference,
    inputs=[
        input_image,
        gradcam,
        gradcam_layer,
        gradcam_num,
        gradcam_opacity,
        misclassified_num,
        top_classes,
    ],
    outputs=[output_label, output_gradcam_gallery, output_misclassified_gallery],
    # One value per input component, in the same order as `inputs`
    # (image, gradcam, layer, gradcam_num, opacity, misclassified_num, top_classes).
    examples=[
        ['assets/0001.jpg', True, "3", 4, 0.5, 3, 3],
        ['assets/0002.jpg', False, "2", 1, 0.3, 1, 2],
        ['assets/0003.jpg', True, "2", 1, 0.3, 1, 2],
    ],
    title="Cifar-10 Inference with GradCAM",
    description="This is a CIFAR-10 image classifier using custom resnet. Upload a CIFAR-10 image and it will be classified into one of 10 categories.",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    interface.launch(share=False)