File size: 4,196 Bytes
3de01bd
c2bec91
2060a55
6cd019c
c2bec91
3de01bd
 
8c51c26
3de01bd
 
13692c5
 
8c51c26
13692c5
3bd7f39
13692c5
8c51c26
13692c5
3bd7f39
 
 
bd3eb35
13692c5
8c51c26
 
bd3eb35
8c51c26
 
 
 
 
 
 
bd3eb35
8c51c26
 
 
 
bd3eb35
 
8c51c26
 
3de01bd
8c51c26
 
 
3de01bd
13692c5
bd3eb35
 
13692c5
3de01bd
bd3eb35
8c51c26
 
 
 
13692c5
bd3eb35
8c51c26
 
 
3bd7f39
13692c5
bd3eb35
13692c5
8c51c26
bd3eb35
8c51c26
bd3eb35
3de01bd
13692c5
bd3eb35
 
 
 
13692c5
 
 
 
bd3eb35
3de01bd
13692c5
3de01bd
13692c5
bd3eb35
8c51c26
 
 
 
 
bd3eb35
 
8c51c26
bd3eb35
8c51c26
bd3eb35
8c51c26
 
 
 
bd3eb35
13692c5
bd3eb35
 
 
 
8c51c26
44f8040
bd3eb35
44f8040
 
a3e94c1
8c51c26
44f8040
bd3eb35
 
 
8c51c26
bd3eb35
 
8c51c26
 
 
3bd7f39
13692c5
bd3eb35
13692c5
8c51c26
bd3eb35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
'''
This app classifies uploaded images, including photos of dogs, using a VGG16 Convolutional Neural Network (CNN)
pretrained on ImageNet. The model is used as-is and is not fine-tuned for dog breeds. However, ImageNet's 1,000
classes include many dog breeds, so it can still name the breed of many dogs. The model runs in the PyTorch
framework, and Gradio provides a user-friendly web interface for uploading images and viewing the predictions.
'''

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import requests
import numpy as np
from PIL import Image

# -----------------------------
#  PRELOAD MODEL & LABELS
# -----------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vanilla VGG16 with its ImageNet-1k weights; no dog-specific fine-tuning.
# The weights enum also carries the class names, reused below as an
# offline fallback for the label list.
VGG16_WEIGHTS = models.VGG16_Weights.IMAGENET1K_V1
model = models.vgg16(weights=VGG16_WEIGHTS).to(device)
model.eval()  # evaluation mode (disables dropout) for inference

# Human-friendly ImageNet labels: one string per model output index.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
NUM_CLASSES = len(VGG16_WEIGHTS.meta["categories"])


def _load_labels():
    """Return a list of NUM_CLASSES label strings indexed by class id.

    Tries the simplified labels at LABELS_URL first. On a network, HTTP or
    format error, it falls back to the category names bundled with the
    torchvision weights, so real class names are still shown offline.
    """
    try:
        response = requests.get(LABELS_URL, timeout=5)
        response.raise_for_status()  # don't try to parse a 404/500 page as labels
        labels = response.json()
        # A short or malformed list would otherwise cause IndexErrors later,
        # at classification time.
        if not isinstance(labels, list) or len(labels) != NUM_CLASSES:
            raise ValueError(
                f"expected a list of {NUM_CLASSES} labels, "
                f"got {type(labels).__name__}"
            )
        return [str(label) for label in labels]
    except (requests.RequestException, ValueError) as e:
        print(f"Could not fetch ImageNet labels: {e}")
        return list(VGG16_WEIGHTS.meta["categories"])


LABELS_CACHE = _load_labels()

# Reference ImageNet evaluation preprocessing for these weights:
# 1. Resize the short side to 256, keeping the aspect ratio.
# 2. Centre-crop 224x224.
# 3. Normalize with the ImageNet channel statistics.
# Squashing straight to 224x224 distorts the image relative to that
# reference preprocessing.
transform_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# -----------------------------
#  CLASSIFICATION FUNCTION
# -----------------------------
def classify_image(image, confidence_threshold=0.0, top_k=3):
    """
    Classify an image using the pretrained VGG16 on ImageNet.

    Args:
        image: The uploaded image. This is either a numpy array (what
            ``gr.Image(type="numpy")`` delivers) or anything ``PIL.Image.open``
            accepts (a path or file-like object). ``None`` means nothing was
            uploaded; it yields a prompt message instead of an error.
        confidence_threshold: Minimum softmax probability a class must reach
            to be reported.
        top_k: How many of the highest-probability classes to consider.
            Defaults to 3, matching the original behaviour.

    Returns:
        A dict mapping label to probability (a Python float) for the top-k
        classes at or above the threshold. A str is returned instead when no
        image was given, when nothing passes the threshold, or when
        classification fails.
    """
    if image is None:
        return "Please upload an image first."

    try:
        if isinstance(image, np.ndarray):
            # Let PIL infer the mode from the array shape (e.g. L, RGB, RGBA),
            # then normalise to 3-channel RGB. Forcing mode 'RGB' misreads
            # grayscale or alpha-channel arrays.
            image_pil = Image.fromarray(image.astype('uint8')).convert('RGB')
        else:
            image_pil = Image.open(image).convert('RGB')

        # Preprocess and add a batch dimension.
        input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(input_tensor)
            probs = torch.nn.functional.softmax(logits, dim=1)

        # Never ask topk for more classes than the model outputs.
        k = min(int(top_k), probs.shape[1])
        top_probs, top_cls_idxs = torch.topk(probs, k)

        results = {}
        for p, cidx in zip(top_probs[0].tolist(), top_cls_idxs[0].tolist()):
            if p < confidence_threshold:
                continue
            # Guard against a label list that is shorter than the model output.
            if LABELS_CACHE and cidx < len(LABELS_CACHE):
                label = LABELS_CACHE[cidx]
            else:
                label = f"Class {cidx}"
            results[label] = float(p)

        if not results:
            return "No predictions above the confidence threshold."
        return results

    except Exception as e:
        # UI boundary: report the failure in the output panel instead of
        # crashing the Gradio callback.
        return f"Error during classification: {str(e)}"

# -----------------------------
#  (OPTIONAL) CUSTOM CSS
# -----------------------------
custom_css = """
body {
    margin: 0;
    padding: 0;
    background: linear-gradient(135deg, #f6f9fc, #ddeefc);
    font-family: "Helvetica", sans-serif;
}
h1, p {
    text-align: center;
    margin-bottom: 1rem;
}
"""

# -----------------------------
#  BUILD THE GRADIO APP
# -----------------------------
def build_app():
    """Assemble and return the Gradio Blocks UI without launching it.

    Layout, top to bottom:
    - title and short description;
    - image upload;
    - confidence-threshold slider;
    - classify button;
    - label panel showing up to 3 results.

    Clicking the button calls classify_image with the uploaded image and the
    slider value.
    """
    with gr.Blocks(css=custom_css) as app:
        # Header text; custom_css centres h1 and p elements.
        gr.HTML("<h1>VGG16 ImageNet Classifier</h1>")
        gr.HTML("<p>Upload an image to see the top 3 predicted ImageNet classes.</p>")

        # Input widgets, placed in creation order.
        uploaded = gr.Image(type="numpy", label="Upload an Image")
        threshold = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.0,
            step=0.01,
            label="Confidence Threshold",
        )
        run_btn = gr.Button("What Breed of Dog is That?")

        # Output panel; displays at most 3 classes.
        predictions = gr.Label(num_top_classes=3, label="Prediction Results")

        # Wire the button to the classifier.
        run_btn.click(
            fn=classify_image,
            inputs=[uploaded, threshold],
            outputs=predictions,
        )

    return app

# -----------------------------
#  LAUNCH
# -----------------------------
if __name__ == "__main__":
    demo = build_app()
    demo.launch()