# Author: Hamzah-ALQadasi
# Last commit: 6ee6c1b ("Update app.py")
import json
import urllib
import urllib.request

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Load the JSON file with the ImageNet class labels.
imagenet_labels_json_url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
# Use the response as a context manager so the HTTP connection is closed
# deterministically, and set a timeout so start-up fails fast instead of
# hanging forever when the host is unreachable.
with urllib.request.urlopen(imagenet_labels_json_url, timeout=30) as imagenet_labels_json:
    # Decode JSON to a dictionary keyed by the class index as a string ("0", "1", ...).
    imagenet_labels_dict = json.load(imagenet_labels_json)
# Keep element [1] of each entry as the display name, ordered by class index so
# that imagenet_labels[i] corresponds to output index i of the model.
imagenet_labels = [imagenet_labels_dict[str(i)][1] for i in range(len(imagenet_labels_dict))]
# Load a pre-trained ResNet18 classifier.
# torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum and
# warns on the old flag; IMAGENET1K_V1 is the checkpoint `pretrained=True`
# selects, so both branches load the same weights. Older releases lack the
# enum and still need the legacy flag.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)
# Inference mode: switches layers such as batch-norm/dropout to eval behaviour.
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor, then normalise each channel with the standard
# ImageNet per-channel mean / std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def _render_confidence_bar(confidence):
    """Return an HTML progress bar showing *confidence* as a percentage.

    The bar is red when confidence is below 50 and green otherwise; its width
    and caption are the confidence value itself.
    """
    bar_color = "#FF0000" if confidence < 50 else "#00FF00"
    return f"""
<div style="margin: 10px 0;">
<label>Confidence Level:</label>
<div style="width: 100%; background-color: #e0e0de;">
<div style="width: {confidence}%; background-color: {bar_color}; padding: 10px; color: white; text-align: center;">
{confidence:.2f}%
</div>
</div>
</div>
"""


def classify_image(image):
    """Classify one image with the pre-trained model.

    Args:
        image: the array delivered by the ``gr.Image()`` input (assumed to be a
            numpy array, Gradio's default); ``None`` when the user submits
            without uploading anything.

    Returns:
        A ``(class_name, html)`` tuple: the top-1 ImageNet label and an HTML
        progress bar showing its softmax probability as a percentage.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")
    # Let Pillow infer the mode from the array shape and then convert. Forcing
    # mode='RGB' misreads or rejects grayscale (H, W) and RGBA (H, W, 4)
    # arrays, and recent Pillow releases deprecate the `mode` argument.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Add a leading batch dimension of size 1 for the model.
    batch = transform(pil_image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Softmax over the class dimension, scaled to percentages, for the single item.
    confidences = torch.nn.functional.softmax(logits, dim=1)[0] * 100
    pred_index = confidences.argmax().item()
    confidence = confidences[pred_index].item()
    # Get the human-readable class name.
    pred_class_name = imagenet_labels[pred_index]
    return pred_class_name, _render_confidence_bar(confidence)
# Sample inputs offered in the interface's Examples panel.
examples = [
    "examples/red_pandas.jpg",
    "examples/turtle.jpg",
    "examples/gazelle.jpg",
    "examples/cat_dog.jpg",
]

# Set up the Gradio interface: one image in; class name and HTML confidence bar out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=[
        gr.Text(label="Class Name"),
        gr.HTML(label="Confidence Level"),
    ],
    examples=examples,
)

# Launch the (blocking) web server only when run as a script, so importing this
# module (e.g. from tests or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True)