File size: 1,610 Bytes
2c32f6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    ConvNextForImageClassification,
)

# Checkpoint to load from the Hugging Face Hub. Both the processor and the
# model are loaded through Auto* classes, which pick the architecture from the
# checkpoint's config, so this string can be swapped for any other
# image-classification checkpoint (not just ConvNeXt ones).
model_name = "facebook/convnext-base-224-22k-1k"

# Pre-trained preprocessing pipeline (resize/normalise) and classifier head
# that belong to the checkpoint above.
image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Prediction function used by the Gradio interface (top-k classes).
def classify_image(img, *, top_k=5):
    """Classify an image and return its most probable labels.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading anything.
        top_k: Maximum number of classes to return. Keyword-only so the
            Gradio interface still sees a single-input function. Clamped to
            the number of classes the model outputs.

    Returns:
        A dict mapping label name to probability (Python floats), inserted
        from most to least likely, in the format ``gr.Label`` displays.
        Empty when ``img`` is ``None``.
    """
    # An empty image box yields None; the processor cannot handle that, so
    # return an empty result instead of surfacing a traceback in the UI.
    if img is None:
        return {}

    # Bring RGBA / greyscale / palette images to 3-channel RGB before
    # preprocessing, so the per-channel normalisation sees the expected input.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    inputs = image_processor(images=img, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # One image was passed, so take the first (only) row: [num_classes].
    probs = torch.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k=k)

    results = {}
    for score, idx in zip(top_probs.tolist(), top_indices.tolist()):
        label = model.config.id2label[idx]
        # Some label sets reuse the same name for different ids. A plain
        # assignment would let a lower-scoring duplicate overwrite the higher
        # score, so disambiguate repeats with their class index.
        if label in results:
            label = f"{label} ({idx})"
        results[label] = score

    return results


# Wire the classifier into a web UI. The upload widget hands classify_image a
# PIL image, and the label widget renders the returned {label: probability}
# dict, showing at most five entries.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

interface = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_label_output,
    title="Image Classification with ConvNeXt (Top-5)",
    description=(
        "Upload an image to see the top 5 predicted classes "
        "using a ConvNeXt image model from Hugging Face."
    ),
)

def main():
    """Launch the Gradio app defined by ``interface``."""
    interface.launch()


# Start the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()