Spaces:
Sleeping
Sleeping
File size: 4,196 Bytes
3de01bd c2bec91 2060a55 6cd019c c2bec91 3de01bd 8c51c26 3de01bd 13692c5 8c51c26 13692c5 3bd7f39 13692c5 8c51c26 13692c5 3bd7f39 bd3eb35 13692c5 8c51c26 bd3eb35 8c51c26 bd3eb35 8c51c26 bd3eb35 8c51c26 3de01bd 8c51c26 3de01bd 13692c5 bd3eb35 13692c5 3de01bd bd3eb35 8c51c26 13692c5 bd3eb35 8c51c26 3bd7f39 13692c5 bd3eb35 13692c5 8c51c26 bd3eb35 8c51c26 bd3eb35 3de01bd 13692c5 bd3eb35 13692c5 bd3eb35 3de01bd 13692c5 3de01bd 13692c5 bd3eb35 8c51c26 bd3eb35 8c51c26 bd3eb35 8c51c26 bd3eb35 8c51c26 bd3eb35 13692c5 bd3eb35 8c51c26 44f8040 bd3eb35 44f8040 a3e94c1 8c51c26 44f8040 bd3eb35 8c51c26 bd3eb35 8c51c26 3bd7f39 13692c5 bd3eb35 13692c5 8c51c26 bd3eb35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
'''
Classify uploaded images with a pretrained VGG16 convolutional neural network.

The model is torchvision's stock VGG16 with ImageNet weights, run in PyTorch.
It is NOT fine-tuned for dog breeds, so its predictions range over the general
ImageNet label set rather than breeds alone. A Gradio web interface handles
image uploads and shows the top-3 predicted classes, optionally filtered by a
confidence threshold.
'''
import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import requests
import numpy as np
from PIL import Image
# -----------------------------
# PRELOAD MODEL & LABELS
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stock VGG16 with ImageNet-1k weights (not fine-tuned for dog breeds). The
# weights enum is kept so its bundled class names can serve as a label fallback.
VGG16_WEIGHTS = models.VGG16_Weights.IMAGENET1K_V1
model = models.vgg16(weights=VGG16_WEIGHTS).to(device)
model.eval()  # eval mode: dropout / normalization layers behave as at inference

# Human-readable ImageNet labels, looked up by the model's output class index.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
NUM_IMAGENET_CLASSES = 1000
try:
    _labels_response = requests.get(LABELS_URL, timeout=5)
    # Treat HTTP errors (404, rate limiting, ...) as failures instead of
    # trying to parse an error page as JSON.
    _labels_response.raise_for_status()
    LABELS_CACHE = _labels_response.json()
    # Index-based lookup is only meaningful if the list lines up with the
    # model's output vector.
    if not isinstance(LABELS_CACHE, list) or len(LABELS_CACHE) != NUM_IMAGENET_CLASSES:
        raise ValueError(
            f"expected a list of {NUM_IMAGENET_CLASSES} labels from {LABELS_URL}"
        )
except (requests.RequestException, ValueError) as e:
    # requests' JSON decode errors are ValueError subclasses, so bad payloads land here too.
    print(f"Could not fetch ImageNet labels: {e}")
    # Offline fallback: class names bundled with the torchvision weights,
    # then generic names if those are unavailable.
    LABELS_CACHE = list(
        VGG16_WEIGHTS.meta.get("categories")
        or [f"Class {i}" for i in range(NUM_IMAGENET_CLASSES)]
    )

# Reference inference preprocessing for these weights: resize the shorter side
# to 256 px (keeping aspect ratio), center-crop 224x224, scale to [0, 1], then
# normalize with the ImageNet channel mean/std.
transform_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# -----------------------------
# CLASSIFICATION FUNCTION
# -----------------------------
def _to_pil_rgb(image):
    """Convert a Gradio image input (numpy array, or path/file object) to an RGB PIL image."""
    if isinstance(image, np.ndarray):
        # Pillow cannot build an image from an (H, W, 1) array; drop the channel axis.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA) and
        # then normalize to 3-channel RGB, instead of forcing mode="RGB".
        return Image.fromarray(image.astype(np.uint8)).convert("RGB")
    # Anything else is handed to Image.open (as the original did); the context
    # manager closes the underlying file once the pixels have been converted.
    with Image.open(image) as img:
        return img.convert("RGB")


def _format_predictions(top_probs, top_cls_idxs, confidence_threshold):
    """Build a {label: probability} dict from ranked predictions at or above the threshold."""
    results = {}
    for p, cidx in zip(top_probs, top_cls_idxs):
        if p < confidence_threshold:
            continue
        cidx = int(cidx)
        # Guard against a label list that is shorter than the model's output.
        label = LABELS_CACHE[cidx] if cidx < len(LABELS_CACHE) else f"Class {cidx}"
        # Label names are not guaranteed unique; keep both predictions visible
        # instead of letting the lower-ranked one overwrite the higher one.
        if label in results:
            label = f"{label} (class {cidx})"
        results[label] = float(p)
    return results


def classify_image(image, confidence_threshold=0.0, top_k=3):
    """
    Classify an image using the pretrained VGG16 on ImageNet.

    Args:
        image: The uploaded image: a numpy array (Gradio ``type="numpy"``) or a
            path/file object accepted by ``PIL.Image.open``. ``None`` (nothing
            uploaded) returns a prompt message.
        confidence_threshold: Minimum softmax probability for a class to be
            reported. ``None`` is treated as 0.0.
        top_k: How many of the highest-probability classes to consider
            (default 3, clamped to the number of model outputs).

    Returns:
        A ``{label: probability}`` dict (the format ``gr.Label`` displays), or a
        message string when no image was given, no class clears the threshold,
        or classification failed.
    """
    if image is None:
        return "Please upload an image."
    try:
        threshold = 0.0 if confidence_threshold is None else float(confidence_threshold)
        image_pil = _to_pil_rgb(image)
        # Preprocess and add a batch dimension of 1.
        input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)
        # Inference without autograd bookkeeping.
        with torch.no_grad():
            output = model(input_tensor)
        probs = torch.nn.functional.softmax(output, dim=1)
        k = max(1, min(int(top_k), probs.shape[1]))
        top_probs, top_cls_idxs = torch.topk(probs, k)
        results = _format_predictions(
            top_probs[0].cpu().numpy(),
            top_cls_idxs[0].cpu().numpy(),
            threshold,
        )
        if not results:
            return "No predictions above the confidence threshold."
        return results
    # UI boundary: report any failure in the output widget rather than raising
    # into Gradio's event handler.
    except Exception as e:
        return f"Error during classification: {str(e)}"
# -----------------------------
# (OPTIONAL) CUSTOM CSS
# -----------------------------
# Page-level stylesheet passed to gr.Blocks(css=...) in build_app: a light
# diagonal gradient background, a Helvetica/sans-serif font, and centered
# h1/p text with spacing underneath.
custom_css = (
    "\n"
    "body {\n"
    "margin: 0;\n"
    "padding: 0;\n"
    "background: linear-gradient(135deg, #f6f9fc, #ddeefc);\n"
    'font-family: "Helvetica", sans-serif;\n'
    "}\n"
    "h1, p {\n"
    "text-align: center;\n"
    "margin-bottom: 1rem;\n"
    "}\n"
)
# -----------------------------
# BUILD THE GRADIO APP
# -----------------------------
def build_app():
    """
    Assemble the Gradio Blocks UI.

    The page has a title and subtitle, an image upload, a confidence-threshold
    slider, a classify button and a label panel showing up to three classes.
    Clicking the button runs ``classify_image`` on the image and slider value.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    title_html = "<h1>VGG16 ImageNet Classifier</h1>"
    subtitle_html = "<p>Upload an image to see the top 3 predicted ImageNet classes.</p>"

    with gr.Blocks(css=custom_css) as demo:
        gr.HTML(title_html)
        gr.HTML(subtitle_html)

        # Components are created directly in the Blocks context, so they are
        # stacked vertically in creation order.
        image_input = gr.Image(type="numpy", label="Upload an Image")
        confidence_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.0,
            step=0.01,
            label="Confidence Threshold",
        )
        classify_button = gr.Button("What Breed of Dog is That?")
        label_output = gr.Label(num_top_classes=3, label="Prediction Results")

        # The image and slider value become classify_image's first two
        # positional arguments (image, confidence_threshold).
        classify_button.click(
            classify_image,
            [image_input, confidence_slider],
            label_output,
        )
    return demo
# -----------------------------
# LAUNCH
# -----------------------------
# Build and serve the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo = build_app()
    demo.launch()  # start the Gradio server with default launch settings
|