import json
import urllib.request

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load the JSON file with the ImageNet class labels.
imagenet_labels_json_url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
# `import urllib` alone does not guarantee the `urllib.request` submodule is
# loaded, so it is imported explicitly above. The `with` block closes the HTTP
# response after parsing; the timeout keeps startup from hanging indefinitely.
with urllib.request.urlopen(imagenet_labels_json_url, timeout=30) as imagenet_labels_json:
    # Decode JSON to a Python dictionary keyed by the class index (as a string).
    imagenet_labels_dict = json.load(imagenet_labels_json)

# Convert the dictionary to a list of class labels ordered by class index.
# Element [1] of each entry is used as the human-readable name (entries are
# presumably [wordnet_id, label] pairs — confirm if the source file changes).
imagenet_labels = [imagenet_labels_dict[str(i)][1] for i in range(len(imagenet_labels_dict))]

# Load a pre-trained ResNet18. torchvision >= 0.13 deprecates `pretrained=True`
# in favour of the `weights=` enum; IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` maps to, so both branches load the same weights.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:  # older torchvision without the weights API
    model = models.resnet18(pretrained=True)
model.eval()

# Define the image transformation (standard ImageNet resize/crop/normalize).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image):
    """Classify an image with ResNet18 and render the confidence as an HTML bar.

    Args:
        image: Image array supplied by ``gr.Image()`` (assumed to be a NumPy
            array, as the ``.astype`` call requires), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``(class_name, progress_bar_html)`` tuple. Both are empty strings
        when no image was provided.
    """
    if image is None:
        # Gradio passes None when Submit is pressed with no image uploaded.
        return "", ""

    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale / RGBA inputs are handled instead of being misread as RGB.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = transform(pil_image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        preds = model(batch)

    # Get the index and the confidence (in percent) of the predicted class.
    confidences = torch.nn.functional.softmax(preds, dim=1)[0] * 100
    pred_index = confidences.argmax().item()
    confidence = confidences[pred_index].item()

    # Get the human-readable class name.
    pred_class_name = imagenet_labels[pred_index]

    # Determine the bar color based on the confidence.
    bar_color = "#FF0000" if confidence < 50 else "#00FF00"

    # Create a progress bar using HTML whose filled width equals the confidence.
    # Only numeric values and the fixed color above are interpolated, so no
    # escaping of untrusted text is needed.
    progress_bar_html = f"""
<div style="width: 100%; background-color: #ddd; border-radius: 5px;">
  <div style="width: {confidence:.2f}%; background-color: {bar_color}; height: 20px; line-height: 20px; border-radius: 5px; text-align: center; color: black; white-space: nowrap;">
    {confidence:.2f}%
  </div>
</div>
"""

    # Return the class name and the progress bar HTML.
    return pred_class_name, progress_bar_html


examples = [
    "examples/red_pandas.jpg",
    "examples/turtle.jpg",
    "examples/gazelle.jpg",
    "examples/cat_dog.jpg",
]

# Set up the Gradio interface.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=[
        gr.Text(label="Class Name"),
        gr.HTML(label="Confidence Level"),
    ],
    examples=examples,
)

# Launch the application only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch(share=True)