Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import pipeline | |
| from PIL import Image | |
| import requests | |
| import numpy as np | |
| import pandas as pd | |
| from plottable import Table | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| import random | |
# Hugging Face model used for zero-shot image classification.
_MODEL_NAME = "google/siglip-so400m-patch14-384"

# Pipelines already built in this process, keyed by model name, so the large
# model is loaded once instead of on every request.
_CLASSIFIER_CACHE = {}


def _get_classifier(model_name=_MODEL_NAME):
    """Return the zero-shot image-classification pipeline for *model_name*, building it on first use."""
    classifier = _CLASSIFIER_CACHE.get(model_name)
    if classifier is None:
        classifier = pipeline(task="zero-shot-image-classification", model=model_name)
        _CLASSIFIER_CACHE[model_name] = classifier
    return classifier


def _open_image(data):
    """Decode raw image bytes into an RGB PIL image, reporting undecodable data to the user."""
    try:
        # Convert palette/alpha/grayscale inputs to plain 3-channel RGB so the
        # model preprocessing always receives a consistent format.
        return Image.open(BytesIO(data)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise gr.Error("The provided file could not be read as an image.") from err


def _load_image(upload, url):
    """Return an RGB image from uploaded bytes or from *url*; None if neither was provided."""
    if upload is not None:
        return _open_image(upload)
    # Gradio textboxes yield "" (not None) when left empty.
    url = (url or "").strip()
    if not url:
        return None
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download the image: {err}") from err
    # .content is the decoded body; .raw may still be gzip/deflate-compressed.
    return _open_image(response.content)


def _parse_labels(labels):
    """Split a comma-separated string into stripped, non-empty, de-duplicated labels (order kept)."""
    stripped = (label.strip() for label in (labels or "").split(","))
    return list(dict.fromkeys(label for label in stripped if label))


def _normalize_scores(scores):
    """Rescale scores to percentages summing to 100, rounded to 2 decimals; an all-zero input stays zero."""
    total = sum(scores)
    if total == 0:
        return [0.0 for _ in scores]
    return [round(score * 100 / total, 2) for score in scores]


def _figure_to_image(fig):
    """Render a matplotlib Figure to PNG in memory and return it as a PIL image."""
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


def _render_bar_chart(labels, scores):
    """Draw a horizontal bar chart (one viridis colour per label, first label on top) and return it as an image."""
    # A standalone Figure is not registered with pyplot, so it is garbage
    # collected after the request instead of accumulating in pyplot's figure list.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    colors = [plt.cm.viridis(i / len(labels)) for i in range(len(labels))]
    ax.barh(labels, scores, color=colors)
    ax.set_xlabel('Score (%)')
    ax.set_ylabel('Labels')
    ax.set_title('Classification Results')
    ax.invert_yaxis()  # display labels from top to bottom in result order
    fig.tight_layout()
    return _figure_to_image(fig)


def _render_table(df):
    """Draw *df* as a matplotlib table with no visible axes and return it as an image."""
    from matplotlib.figure import Figure

    fig = Figure(figsize=(6, 5))
    ax = fig.subplots()
    ax.axis('tight')
    ax.axis('off')
    ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    return _figure_to_image(fig)


def classify_image(upload, url, labels):
    """
    Classify an image (uploaded file or URL) against user-supplied labels.

    Args:
        upload: Raw bytes of an uploaded image file, or None. Takes precedence over ``url``.
        url: Image URL entered by the user; may be empty or None.
        labels: Comma-separated candidate labels, e.g. "animal, human, building".

    Returns:
        A ``(bar_chart, table)`` tuple of PIL images. Both show each label with
        its score rescaled so that all scores sum to 100%.

    Raises:
        gr.Error: if no image or no labels were given, the URL could not be
            fetched, or the data is not a decodable image. Gradio shows the
            message in the UI.
    """
    image = _load_image(upload, url)
    if image is None:
        # The interface has two outputs, so a bare message string cannot be
        # returned; raising gr.Error surfaces the message to the user instead.
        raise gr.Error("Please upload an image or enter an image URL.")

    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise gr.Error("Please enter at least one label, separated by commas.")

    outputs = _get_classifier()(image, candidate_labels=candidate_labels)
    result_labels = [output["label"] for output in outputs]
    normalized_scores = _normalize_scores([output["score"] for output in outputs])

    df = pd.DataFrame({"Labels": result_labels, "Scores (%)": normalized_scores})
    return _render_bar_chart(result_labels, normalized_scores), _render_table(df)
# Gradio UI: an image source (file upload or URL) plus a comma-separated label
# list go in; two rendered result images come out.
_input_components = [
    gr.File(type="binary", label="Upload Image"),
    gr.Textbox(label="Or, enter Image URL"),
    gr.Textbox(label="Enter labels separated by commas (e.g., animal, human, building)"),
]
_output_components = [
    gr.Image(label="Classification Results (Bar Chart)"),
    gr.Image(label="Classification Results (Table)"),
]

interface = gr.Interface(
    fn=classify_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Image Classifier",
    description="Upload an image or enter an image URL, then specify labels to classify the image.",
)

# Start serving the app.
interface.launch()