# Author: rtik007
# Last commit: "Update app.py" (a3e94c1, verified)
'''
This app identifies dog breeds (and other objects) in uploaded images using VGG16, a Convolutional
Neural Network (CNN) pretrained on ImageNet and run with the PyTorch framework. The model is used as-is,
without dog-specific fine-tuning: ImageNet's 1000 classes include many dog breeds, so breed names appear
among its predictions. Gradio provides a user-friendly web interface for uploading images and viewing the
top predictions.
'''
import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import requests
import numpy as np
from PIL import Image
# -----------------------------
# PRELOAD MODEL & LABELS
# -----------------------------
# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stock torchvision VGG16 with its ImageNet-1k weights (no extra fine-tuning),
# moved to the chosen device and switched to inference mode once at startup.
model = models.vgg16(weights="IMAGENET1K_V1")
model = model.to(device)
model.eval()
# Download ImageNet labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"


def _load_imagenet_labels(url, num_classes=1000, timeout=5):
    """
    Fetch the human-readable ImageNet class names from ``url``.

    The payload must be a JSON list with exactly ``num_classes`` entries, so
    that list index ``i`` names model output ``i``. On any failure (network
    error, HTTP error status, malformed or wrong-sized payload), a message is
    printed and generic ``"Class {i}"`` names are returned instead. The app
    can therefore still start offline.

    Args:
        url: Location of the JSON label list.
        num_classes: Number of labels the model's output layer expects.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        A list of ``num_classes`` label strings.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Fail fast on 4xx/5xx instead of trying to parse an error body as labels.
        response.raise_for_status()
        labels = response.json()
        # A dict or wrong-length list would silently mislabel classes or make
        # index lookups fail later, so reject it here and use the fallback.
        if not isinstance(labels, list) or len(labels) != num_classes:
            raise ValueError(
                f"expected a JSON list of {num_classes} labels, "
                f"got {type(labels).__name__}"
            )
        return [str(label) for label in labels]
    except Exception as e:
        print(f"Could not fetch ImageNet labels: {e}")
        return [f"Class {i}" for i in range(num_classes)]


LABELS_CACHE = _load_imagenet_labels(LABELS_URL)
# Transform pipeline
# Matches the evaluation preprocessing torchvision documents for the
# IMAGENET1K_V1 VGG16 weights:
#   - resize the shorter side to 256 (bilinear), keeping the aspect ratio;
#   - center-crop to 224x224;
#   - convert to a [0, 1] tensor and normalize with the ImageNet channel
#     mean and std.
# Resizing straight to (224, 224) would distort the aspect ratio and feed the
# model inputs unlike those it was evaluated on.
transform_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# -----------------------------
# CLASSIFICATION FUNCTION
# -----------------------------
def _to_rgb_pil(image):
    """Convert a numpy array, path or file-like object into an RGB PIL image."""
    if isinstance(image, np.ndarray):
        # Pillow cannot infer a mode for an (H, W, 1) array; drop the
        # singleton channel so it is read as grayscale.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        # Let Pillow infer the mode from the array shape (L, RGB, RGBA, ...),
        # then normalize to RGB. Forcing mode 'RGB' misreads 1- or 4-channel
        # buffers.
        return Image.fromarray(image.astype("uint8")).convert("RGB")
    # The context manager closes the underlying file once the RGB copy exists.
    with Image.open(image) as img:
        return img.convert("RGB")


def _label_for(class_idx):
    """Return the label for an ImageNet class index, or a generic name if unknown."""
    if 0 <= class_idx < len(LABELS_CACHE):
        return LABELS_CACHE[class_idx]
    return f"Class {class_idx}"


def classify_image(image, confidence_threshold=0.0, top_k=3):
    """
    Classify an image using the pretrained VGG16 on ImageNet.

    Args:
        image: The uploaded image. ``gr.Image(type="numpy")`` passes a numpy
            array, or ``None`` when nothing has been uploaded. Any other value
            is treated as a path or file-like object for ``PIL.Image.open``.
        confidence_threshold: Minimum softmax probability a class needs in
            order to be included in the result.
        top_k: How many of the highest-probability classes to consider. The
            default of 3 matches the UI's ``gr.Label(num_top_classes=3)``.

    Returns:
        A ``{label: probability}`` dict of the top-k predictions whose
        probability is at or above ``confidence_threshold``. Returns a
        message string instead when no image was supplied, nothing passed the
        threshold, or classification raised an error.
    """
    # Gradio hands over None when the button is clicked before an upload.
    if image is None:
        return "Please upload an image first."
    try:
        image_pil = _to_rgb_pil(image)
        # Preprocess: add a batch dimension of 1 and move to the model's device.
        input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)
        # Inference without autograd bookkeeping.
        with torch.no_grad():
            logits = model(input_tensor)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        # Never ask topk for more classes than the model outputs.
        k = min(int(top_k), probs.numel())
        top_probs, top_cls_idxs = torch.topk(probs, k)
        results = {}
        # .tolist() copies to host memory and yields plain Python floats/ints.
        for p, cidx in zip(top_probs.tolist(), top_cls_idxs.tolist()):
            if p >= confidence_threshold:
                results[_label_for(cidx)] = float(p)
        if not results:
            return "No predictions above the confidence threshold."
        return results
    except Exception as e:
        # UI boundary: surface the failure in the output widget instead of crashing.
        return f"Error during classification: {str(e)}"
# -----------------------------
# (OPTIONAL) CUSTOM CSS
# -----------------------------
# Stylesheet passed to gr.Blocks(css=...) in build_app(). It sets a light
# diagonal gradient page background and a Helvetica/sans-serif font, and it
# centers h1/p text with some space below.
# NOTE(review): Gradio's theme styles its own container, so the `body` rules
# may be partly overridden depending on the Gradio version/theme. Confirm
# visually.
custom_css = """
body {
margin: 0;
padding: 0;
background: linear-gradient(135deg, #f6f9fc, #ddeefc);
font-family: "Helvetica", sans-serif;
}
h1, p {
text-align: center;
margin-bottom: 1rem;
}
"""
# -----------------------------
# BUILD THE GRADIO APP
# -----------------------------
def build_app():
    """
    Assemble the Gradio Blocks UI and connect its button to ``classify_image``.

    The UI has a title, a one-line hint, an image upload, a confidence slider,
    a button and a label widget showing up to three predictions. They are
    stacked vertically in that order.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks(css=custom_css) as app:
        # Header and usage hint.
        gr.HTML("<h1>VGG16 ImageNet Classifier</h1>")
        gr.HTML("<p>Upload an image to see the top 3 predicted ImageNet classes.</p>")

        # Inputs, trigger and output; Blocks lays them out in creation order.
        uploaded = gr.Image(label="Upload an Image", type="numpy")
        threshold = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.0,
            step=0.01,
            label="Confidence Threshold",
        )
        run_button = gr.Button("What Breed of Dog is That?")
        predictions = gr.Label(label="Prediction Results", num_top_classes=3)

        # A click sends (image, threshold) to the classifier and shows its result.
        run_button.click(
            fn=classify_image,
            inputs=[uploaded, threshold],
            outputs=predictions,
        )
    return app
# -----------------------------
# LAUNCH
# -----------------------------
if __name__ == "__main__":
    # Build the UI, keep it bound to the module-level name `demo`, and start
    # the Gradio server.
    (demo := build_app()).launch()