Spaces:
Runtime error
Runtime error
app gradio file with all required fields
Browse files- .gitignore +2 -1
- app-2.py +0 -65
- app.py +94 -45
- assets/0001.jpg +0 -0
- assets/0002.jpg +0 -0
- assets/0003.jpg +0 -0
.gitignore
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
-
__pycache__
|
| 2 |
data
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
data
|
| 3 |
+
flagged
|
app-2.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
# Importing necessary libraries
|
| 2 |
-
import gradio as gr
|
| 3 |
-
import torch
|
| 4 |
-
import torchvision.transforms as transforms
|
| 5 |
-
from torchvision.models import resnet18
|
| 6 |
-
from torchvision.datasets import CIFAR10
|
| 7 |
-
from torch.nn import functional as F
|
| 8 |
-
import numpy as np
|
| 9 |
-
import matplotlib.pyplot as plt
|
| 10 |
-
from PIL import Image
|
| 11 |
-
|
| 12 |
-
# Load CIFAR10 pretrained model
|
| 13 |
-
model = resnet18(pretrained=True)
|
| 14 |
-
model.eval()
|
| 15 |
-
|
| 16 |
-
# Define transformation for CIFAR10
|
| 17 |
-
transform = transforms.Compose([
|
| 18 |
-
transforms.Resize((224, 224)),
|
| 19 |
-
transforms.ToTensor(),
|
| 20 |
-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
| 21 |
-
])
|
| 22 |
-
|
| 23 |
-
# Load CIFAR10 dataset for example images
|
| 24 |
-
cifar10_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
|
| 25 |
-
example_images = [cifar10_dataset[i][0] for i in range(10)]
|
| 26 |
-
|
| 27 |
-
# Define class names for CIFAR10
|
| 28 |
-
class_names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
| 29 |
-
|
| 30 |
-
def predict(img, gradcam, num_gradcam, layer, opacity, misclassified, num_misclassified, top_classes):
|
| 31 |
-
# Transform and predict
|
| 32 |
-
img_tensor = transform(img).unsqueeze(0)
|
| 33 |
-
outputs = model(img_tensor)
|
| 34 |
-
_, predicted = outputs.max(1)
|
| 35 |
-
probs = F.softmax(outputs, dim=1)[0] * 100
|
| 36 |
-
|
| 37 |
-
# Get top classes
|
| 38 |
-
top_probs, top_labels = torch.topk(probs, min(top_classes, 10))
|
| 39 |
-
top_classes = [class_names[label] for label in top_labels]
|
| 40 |
-
|
| 41 |
-
# GradCAM
|
| 42 |
-
gradcam_images = []
|
| 43 |
-
if gradcam:
|
| 44 |
-
# TODO: Implement GradCAM
|
| 45 |
-
pass
|
| 46 |
-
|
| 47 |
-
# Misclassified images
|
| 48 |
-
misclassified_images = []
|
| 49 |
-
if misclassified:
|
| 50 |
-
# TODO: Get misclassified images
|
| 51 |
-
pass
|
| 52 |
-
|
| 53 |
-
return {'Prediction': top_classes, 'Probabilities': top_probs.tolist(), 'GradCAM': gradcam_images, 'Misclassified': misclassified_images}
|
| 54 |
-
|
| 55 |
-
# Gradio Interface
|
| 56 |
-
image = gr.inputs.Image()
|
| 57 |
-
gradcam = gr.inputs.Checkbox(label='Show GradCAM Images')
|
| 58 |
-
num_gradcam = gr.inputs.Number(label='Number of GradCAM Images', default=1, minimum=1, maximum=10)
|
| 59 |
-
layer = gr.inputs.Dropdown(choices=['layer1', 'layer2', 'layer3', 'layer4'], label='GradCAM Layer')
|
| 60 |
-
opacity = gr.inputs.Slider(minimum=0, maximum=1, default=0.5, label='Opacity')
|
| 61 |
-
misclassified = gr.inputs.Checkbox(label='Show Misclassified Images')
|
| 62 |
-
num_misclassified = gr.inputs.Number(label='Number of Misclassified Images', default=1, minimum=1, maximum=10)
|
| 63 |
-
top_classes = gr.inputs.Number(label='Number of Top Classes to Show', default=1, minimum=1, maximum=10)
|
| 64 |
-
|
| 65 |
-
gr.Interface(fn=predict, inputs=[image, gradcam, num_gradcam, layer, opacity, misclassified, num_misclassified, top_classes], outputs='json').launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,51 +1,100 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
if
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
inputs=[
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
gr.inputs.Number(default=1, label="Number of top classes to show", min=1, max=10),
|
| 38 |
-
],
|
| 39 |
-
outputs=[
|
| 40 |
-
gr.outputs.Label(num_top_classes=10, label="Predictions"),
|
| 41 |
-
gr.outputs.Image(type="numpy", label="GradCAM Images"),
|
| 42 |
-
gr.outputs.Image(type="numpy", label="Misclassified Images"),
|
| 43 |
],
|
|
|
|
| 44 |
examples=[
|
| 45 |
-
|
| 46 |
-
[
|
| 47 |
-
[
|
| 48 |
-
]
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torchvision import models, transforms
|
| 5 |
+
from PIL import Image
|
| 6 |
|
| 7 |
+
|
| 8 |
+
# Run Interface
|
| 9 |
+
def run_inference(input_image, gradcam=False, gradcam_layer=3, gradcam_num = 3, gradcam_opacity=0.5, misclassified_num=5, top_classes=10):
|
| 10 |
+
"""Run inference on a CIFAR-10 image.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
image: The image to be classified.
|
| 14 |
+
gradcam: Whether to show GradCAM images.
|
| 15 |
+
gradcam_layer: The layer from which to generate GradCAM images.
|
| 16 |
+
gradcam_opacity: The opacity of the GradCAM images.
|
| 17 |
+
misclassified: Whether to show misclassified images.
|
| 18 |
+
misclassified_num: The number of misclassified images to show.
|
| 19 |
+
top_classes: The number of top classes to show.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
The classification results, including the predicted class, the top classes, and the GradCAM images (if requested).
|
| 23 |
+
"""
|
| 24 |
+
# # Load the CIFAR-10 model
|
| 25 |
+
|
| 26 |
+
# # Classify the image
|
| 27 |
+
# prediction = model.predict(image)
|
| 28 |
+
# predicted_class = np.argmax(prediction)
|
| 29 |
+
|
| 30 |
+
# # Get the top classes
|
| 31 |
+
# top_classes = prediction.argsort()[-top_classes:][::-1]
|
| 32 |
+
|
| 33 |
+
# # Generate GradCAM images, if requested
|
| 34 |
+
# if gradcam:
|
| 35 |
+
# gradcam_images = []
|
| 36 |
+
# for layer in range(model.layers.shape[0]):
|
| 37 |
+
# gradcam_image = gradcam(model, image, layer, gradcam_opacity)
|
| 38 |
+
# gradcam_images.append(gradcam_image)
|
| 39 |
+
|
| 40 |
+
# # Get the misclassified images, if requested
|
| 41 |
+
# misclassified_images = []
|
| 42 |
+
# for i in range(len(prediction)):
|
| 43 |
+
# if prediction[i] != y_test[i]:
|
| 44 |
+
# misclassified_images.append(image[i])
|
| 45 |
+
|
| 46 |
+
# Placeholder for top classes
|
| 47 |
+
top_classes = {"dog" : 0.90, "cat": 0.10}
|
| 48 |
+
|
| 49 |
+
# Placeholder for GradCAM images
|
| 50 |
+
gradcam_images = []
|
| 51 |
+
if gradcam:
|
| 52 |
+
# Generate GradCAM images for the specified layer and number
|
| 53 |
+
for i in range(gradcam_num):
|
| 54 |
+
gradcam_images.append(np.random.rand(32, 32, 3)) # Example random image
|
| 55 |
+
|
| 56 |
+
# Placeholder for misclassified images
|
| 57 |
+
misclassified_images = []
|
| 58 |
+
if misclassified_num > 0:
|
| 59 |
+
# Get misclassified images
|
| 60 |
+
for i in range(misclassified_num):
|
| 61 |
+
misclassified_images.append((np.random.rand(32, 32, 3), 'caption')) # Example random image
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Return the classification results
|
| 65 |
+
return top_classes, gradcam_images if gradcam else [], misclassified_images
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Gradio Interface
|
| 69 |
+
input_image = gr.Image(shape=(32, 32), label="Upload Image", info="Upload a CIFAR-10 image to be classified.")
|
| 70 |
+
gradcam= gr.Checkbox(label="View GradCAM images?", info="Whether to show GradCAM images.")
|
| 71 |
+
gradcam_layer = gr.Dropdown(["1", "2", "3"], value="2", label="GradCAM Layer", info="The layer from which to generate GradCAM images.")
|
| 72 |
+
gradcam_num = gr.Slider(label="Number of GradCAM images", minimum=1, maximum=10, step=1, info="The number of GradCAM images to show.")
|
| 73 |
+
gradcam_opacity = gr.Slider(label="GradCAM opacity", minimum=0.0, maximum=1.0, step=0.01, info="The opacity of the GradCAM images.")
|
| 74 |
+
misclassified_num = gr.Slider(label="Number of Misclassified images", minimum=0, maximum=10, step=1, info="The number of misclassified images to show.")
|
| 75 |
+
top_classes = gr.Slider(label="Number of top classes to show", minimum=1, maximum=10, step=1, info="The number of top classes to show.")
|
| 76 |
+
|
| 77 |
+
output_label = gr.Label(num_top_classes=3, label="Top Classes")
|
| 78 |
+
output_gradcam_gallery = gr.Gallery(object_fit="fit", columns=4, height=280, label="GradCam Galery")
|
| 79 |
+
output_misclassified_gallery = gr.Gallery(object_fit="fit", columns=4, height=280, label="Misclassified Images")
|
| 80 |
+
|
| 81 |
+
interface = gr.Interface(
|
| 82 |
+
fn=run_inference,
|
| 83 |
inputs=[
|
| 84 |
+
input_image,
|
| 85 |
+
gradcam,
|
| 86 |
+
gradcam_layer,
|
| 87 |
+
gradcam_num,
|
| 88 |
+
gradcam_opacity,
|
| 89 |
+
misclassified_num,
|
| 90 |
+
top_classes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
],
|
| 92 |
+
outputs=[output_label, output_gradcam_gallery, output_misclassified_gallery],
|
| 93 |
examples=[
|
| 94 |
+
['assets/0001.jpg', True,"3", 4, 0.5, 3, 3, 2],
|
| 95 |
+
['assets/0002.jpg', False, "2", 1, 0.3, 1, 2, 2],
|
| 96 |
+
['assets/0003.jpg', True, "2", 1, 0.3, 1, 2, 2],
|
| 97 |
+
],
|
| 98 |
+
title="Cifar-10 Inference with GradCAM",
|
| 99 |
+
description="This is a CIFAR-10 image classifier using custom resnet. Upload a CIFAR-10 image and it will be classified into one of 10 categories.",)
|
| 100 |
+
interface.launch(share=False)
|
assets/0001.jpg
ADDED
|
assets/0002.jpg
ADDED
|
assets/0003.jpg
ADDED
|