Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,566 Bytes
39d9406 dc382c8 55d79e2 39d9406 d56b9d9 dc382c8 39d9406 dc382c8 1c1b97a 5bebd85 1c1b97a 5bebd85 dc382c8 d56b9d9 1c1b97a d56b9d9 dc382c8 d56b9d9 dc382c8 39d9406 55d79e2 5bebd85 55d79e2 5bebd85 39d9406 55d79e2 39d9406 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from functools import partial
from huggingface_hub import InferenceClient
from os import path, unlink
import gradio as gr
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from utils import save_image_to_temp_file, request_image
def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
    """Classify an image using the Hugging Face Inference API.

    The tab this backs describes the task as classifying a recyclable item into
    cardboard, glass, metal, paper, plastic, or other; the labels actually
    returned are whatever ``model`` produces. The image is first written to a
    temporary file because InferenceClient is given a file path rather than a
    PIL Image object directly.

    Args:
        client: Hugging Face InferenceClient instance used for the API call.
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify. ``None`` (e.g. when the user
            clicks "Classify" before loading an image) is rejected.

    Returns:
        Pandas DataFrame with two columns, present even when the model returns
        no classifications:
            - Label: The classification label (e.g., "cardboard", "glass").
            - Probability: The confidence score as a percentage string
              (e.g., "95.23%").

    Raises:
        gr.Error: If no image was provided.

    Note:
        The temporary file is removed after the API call, whether or not the
        call succeeded; errors while removing it are ignored (best effort).
    """
    if image is None:
        raise gr.Error("Please provide an image to classify.")

    # Bound before `try` so the `finally` clause can always reference it, even
    # when save_image_to_temp_file itself raises.
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(temp_file_path, model=model)
        rows = [
            {"Label": classification.label, "Probability": f"{classification.score:.2%}"}
            for classification in classifications
        ]
        # Explicit columns keep the table shape stable when `rows` is empty.
        return pd.DataFrame(rows, columns=["Label", "Probability"])
    finally:
        if temp_file_path:
            try:
                unlink(temp_file_path)
            except OSError:
                # Best-effort clean-up: an already-removed or locked file is not
                # worth failing the classification over.
                pass
def create_image_classification_tab(client: InferenceClient, model: str):
    """Build the Gradio components and event wiring for image classification.

    Components are created in this order: a description, an "Image URL"
    textbox, a "Get Image" button, an image preview, a "Classify" button, and
    a read-only results table. "Get Image" passes the URL to ``request_image``
    and shows the result in the preview. "Classify" passes the previewed
    image to ``image_classification`` and shows the result in the table.

    NOTE(review): no Blocks/Tab context is opened here, so this presumably
    expects to be called inside one set up by the caller — confirm.

    Args:
        client: Hugging Face InferenceClient bound into the classify callback.
        model: Hugging Face model ID bound into the classify callback.
    """
    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    # Step 1: load an image from a URL into the preview component.
    url_box = gr.Textbox(label="Image URL")
    fetch_btn = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    fetch_btn.click(fn=request_image, inputs=url_box, outputs=preview)

    # Step 2: classify whatever image is currently in the preview.
    classify_btn = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    classify_fn = partial(image_classification, client, model)
    classify_btn.click(fn=classify_fn, inputs=preview, outputs=results_table)
|