Spaces:
Sleeping
Sleeping
import json
import urllib
import urllib.request

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Load the JSON file with the ImageNet class labels. The file maps each class
# index (as a string key, str(i)) to an entry whose element [1] is the label
# name used below.
imagenet_labels_json_url = 'https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json'
# Use the response as a context manager so the HTTP connection is closed
# deterministically, and bound the wait so app startup cannot hang forever.
with urllib.request.urlopen(imagenet_labels_json_url, timeout=30) as imagenet_labels_json:
    # Decode JSON to a Python dictionary
    imagenet_labels_dict = json.load(imagenet_labels_json)
# Convert the dictionary to a list of class labels, ordered by class index so
# that imagenet_labels[i] is the name of model output i.
imagenet_labels = [imagenet_labels_dict[str(i)][1] for i in range(len(imagenet_labels_dict))]
# Load a pre-trained ResNet18 with ImageNet weights and put it in evaluation
# mode (BatchNorm layers then use their stored running statistics).
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    # torchvision >= 0.13: the ``weights`` enum replaces the deprecated
    # ``pretrained`` flag. IMAGENET1K_V1 is the checkpoint pretrained=True loads.
    model = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    # Older torchvision without the weights enum API.
    model = models.resnet18(pretrained=True)
model.eval()
# Image preprocessing applied before inference: resize (shorter side to 256),
# take the central 224x224 crop, convert to a tensor, then normalize each
# channel with the standard ImageNet mean / standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Define the classification function and its helpers
def _to_rgb_pil(image):
    """Convert the numpy array supplied by Gradio into an RGB PIL image.

    Pillow infers the mode from the array shape ((H, W) -> L, (H, W, 3) -> RGB,
    (H, W, 4) -> RGBA) and the result is then converted to RGB. Forcing
    ``mode='RGB'`` instead would misread non-3-channel buffers, and recent
    Pillow releases deprecate that explicit ``mode`` argument.
    """
    return Image.fromarray(image.astype('uint8')).convert('RGB')


def _confidence_bar_html(confidence):
    """Render ``confidence`` (a percentage) as an HTML progress bar.

    The bar is red below 50% and green otherwise. Its width is the confidence
    itself, and it is labelled with the value to two decimal places.
    """
    bar_color = "#FF0000" if confidence < 50 else "#00FF00"
    return f"""
<div style="margin: 10px 0;">
    <label>Confidence Level:</label>
    <div style="width: 100%; background-color: #e0e0de;">
        <div style="width: {confidence}%; background-color: {bar_color}; padding: 10px; color: white; text-align: center;">
            {confidence:.2f}%
        </div>
    </div>
</div>
"""


def classify_image(image):
    """Classify ``image`` with the pretrained ResNet18 and report its confidence.

    Args:
        image: the array delivered by ``gr.Image()``, or ``None`` when the
            user submits without providing an image.

    Returns:
        A ``(class_name, progress_bar_html)`` pair: the ImageNet label of the
        top-1 prediction and an HTML bar showing its softmax probability as a
        percentage. When no image is given, returns a short message and an
        empty HTML string instead of raising.
    """
    if image is None:
        return "No image provided", ""

    # Preprocess and add the leading batch dimension the model expects.
    batch = transform(_to_rgb_pil(image)).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class dimension, scaled to percentages; [0] selects the
    # single image in the batch.
    confidences = torch.nn.functional.softmax(logits, dim=1)[0] * 100
    pred_index = confidences.argmax().item()
    confidence = confidences[pred_index].item()

    # Map the predicted index to its human-readable class name.
    pred_class_name = imagenet_labels[pred_index]
    return pred_class_name, _confidence_bar_html(confidence)
# Sample images (relative paths) offered as one-click inputs in the UI.
_example_stems = ("red_pandas", "turtle", "gazelle", "cat_dog")
examples = [f"examples/{stem}.jpg" for stem in _example_stems]
# Set up the Gradio interface: one image in; the predicted class name and an
# HTML confidence bar out; the sample images offered as clickable examples.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=[
        gr.Text(label="Class Name"),
        gr.HTML(label="Confidence Level"),
    ],
    examples=examples,
)

# Launch the application only when this file is run as a script, so that
# importing the module does not start a web server as a side effect.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link; hosted platforms
    # may ignore it — confirm it is still wanted.
    iface.launch(share=True)