File size: 3,566 Bytes
39d9406
dc382c8
55d79e2
39d9406
d56b9d9
dc382c8
 
39d9406
dc382c8
 
1c1b97a
5bebd85
 
 
 
 
 
 
 
 
1c1b97a
5bebd85
 
 
 
 
 
 
 
 
 
 
dc382c8
d56b9d9
1c1b97a
d56b9d9
 
 
 
 
 
dc382c8
d56b9d9
dc382c8
 
 
 
39d9406
 
55d79e2
5bebd85
 
 
 
 
 
 
 
 
 
55d79e2
5bebd85
39d9406
 
 
 
 
 
 
 
 
 
 
 
55d79e2
39d9406
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from functools import partial
from huggingface_hub import InferenceClient
from os import path, unlink
import gradio as gr
from PIL.Image import Image
import pandas as pd
from pandas import DataFrame
from utils import save_image_to_temp_file, request_image


def image_classification(client: InferenceClient, model: str, image: Image) -> DataFrame:
    """Classify an image using the Hugging Face Inference API.

    This function classifies a recyclable item image into categories:
    cardboard, glass, metal, paper, plastic, or other. The image is first
    written to a temporary file with ``save_image_to_temp_file`` and that file
    path is what gets sent to ``client.image_classification``.

    Args:
        client: Hugging Face InferenceClient instance for API calls.
        model: Hugging Face model ID to use for image classification.
        image: PIL Image object to classify.

    Returns:
        Pandas DataFrame with one row per returned classification and two columns:
            - Label: The classification label (e.g., "cardboard", "glass")
            - Probability: The confidence score as a percentage string (e.g., "95.23%")
        If the API returns no classifications, the DataFrame is empty but still
        has both columns.

    Note:
        The temporary file is removed in a ``finally`` clause, so clean-up runs
        whether the API call succeeds or raises. Clean-up is best-effort: an
        ``OSError`` while deleting the file is ignored so it cannot replace the
        classification result or the original error.
    """
    # Bound before `try` so the `finally` clause can always read it, even when
    # save_image_to_temp_file itself raises (otherwise an UnboundLocalError
    # would mask the real exception).
    temp_file_path = None
    try:
        # Needed because InferenceClient does not accept PIL Images directly.
        temp_file_path = save_image_to_temp_file(image)
        classifications = client.image_classification(temp_file_path, model=model)
        return pd.DataFrame(
            [
                {
                    "Label": classification.label,
                    "Probability": f"{classification.score:.2%}",
                }
                for classification in classifications
            ],
            # Explicit columns keep the schema stable when there are no results.
            columns=["Label", "Probability"],
        )
    finally:
        if temp_file_path and path.exists(temp_file_path):  # Clean up temporary file.
            try:
                unlink(temp_file_path)
            except OSError:
                pass  # Best-effort clean-up; a leftover temp file is not fatal.


def create_image_classification_tab(client: InferenceClient, model: str):
    """Build the UI components and event wiring for the image classification tab.

    Components are created in this order:
    - a Markdown description of the classifier
    - an "Image URL" textbox
    - a "Get Image" button that calls ``request_image`` with the URL and shows
      the result in a PIL image component
    - a "Classify" button that sends that image to ``image_classification``
    - a read-only dataframe with "Label" and "Probability" headers that
      receives the classification result

    Args:
        client: InferenceClient that is bound into the ``image_classification`` callback.
        model: Hugging Face model ID that is bound into the ``image_classification`` callback.
    """
    # Pre-bind client and model so the click handler only needs the image input.
    classify = partial(image_classification, client, model)

    gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")

    # URL -> image preview.
    url_box = gr.Textbox(label="Image URL")
    fetch_button = gr.Button("Get Image")
    preview = gr.Image(label="Image", type="pil")
    fetch_button.click(fn=request_image, inputs=url_box, outputs=preview)

    # Image preview -> classification table.
    classify_button = gr.Button("Classify")
    results_table = gr.Dataframe(
        label="Classification",
        headers=["Label", "Probability"],
        interactive=False,
    )
    classify_button.click(fn=classify, inputs=preview, outputs=results_table)