Spaces:
Sleeping
Sleeping
'''
This app classifies uploaded images, with dog breed identification as the intended use, by running a VGG16
Convolutional Neural Network (CNN) pretrained on ImageNet within the PyTorch framework. The model is used
exactly as loaded (it is not fine-tuned for dog breeds), so predictions are drawn from all ImageNet classes.
Additionally, Gradio is used to build a user-friendly web-based interface for easy image uploads and
predictions.
'''
| import gradio as gr | |
| import torch | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| import requests | |
| import numpy as np | |
| from PIL import Image | |
| # ----------------------------- | |
| # PRELOAD MODEL & LABELS | |
| # ----------------------------- | |
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Stock VGG16 with the ImageNet-1k (V1) weights, moved to the chosen device
# and switched to evaluation mode for inference.
model = models.vgg16(weights="IMAGENET1K_V1")
model = model.to(device)
model.eval()
# Download the human-readable ImageNet class names (a JSON array indexed by
# class id). This is best effort: on any failure fall back to generic
# "Class <i>" names so the app can still start.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
_NUM_IMAGENET_CLASSES = 1000
try:
    _labels_response = requests.get(LABELS_URL, timeout=5)
    # Reject HTTP error responses instead of parsing an error body as labels.
    _labels_response.raise_for_status()
    LABELS_CACHE = _labels_response.json()
    # Only accept a list long enough to be indexed by every class id;
    # anything else (e.g. a JSON error object) would break later lookups.
    if not isinstance(LABELS_CACHE, list) or len(LABELS_CACHE) < _NUM_IMAGENET_CLASSES:
        raise ValueError("unexpected format for the ImageNet label list")
except Exception as e:
    print(f"Could not fetch ImageNet labels: {e}")
    LABELS_CACHE = [f"Class {i}" for i in range(_NUM_IMAGENET_CLASSES)]
# Preprocessing applied to every input image: resize to 224x224, convert to
# a tensor, then normalize each channel with the mean/std values below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform_pipeline = transforms.Compose(_preprocess_steps)
| # ----------------------------- | |
| # CLASSIFICATION FUNCTION | |
| # ----------------------------- | |
def classify_image(image, confidence_threshold=0.0):
    """
    Classify an image using the pretrained VGG16 on ImageNet.

    Args:
        image: Either a numpy array (the form the Gradio image widget is
            configured to pass in) or anything ``PIL.Image.open`` accepts.
            Grayscale and RGBA arrays are converted to RGB before
            preprocessing.
        confidence_threshold: Minimum softmax probability a prediction must
            reach to be included in the result.

    Returns:
        A dict mapping class label to probability for those of the top-3
        predictions that reach ``confidence_threshold``, or a message string
        when no image was given, no prediction qualifies, or classification
        failed.
    """
    if image is None:
        # Nothing to classify (e.g. the button was pressed before an image
        # was provided); give a clear message instead of an opaque error.
        return "Please upload an image first."

    try:
        if isinstance(image, np.ndarray):
            pixels = image.astype('uint8')
            # Drop a trailing single-channel axis so Pillow receives a plain
            # 2-D grayscale array.
            if pixels.ndim == 3 and pixels.shape[2] == 1:
                pixels = pixels[:, :, 0]
            # Let Pillow infer the mode from the array shape and convert
            # afterwards; forcing 'RGB' in fromarray fails or misreads
            # grayscale and RGBA buffers.
            image_pil = Image.fromarray(pixels).convert('RGB')
        else:
            image_pil = Image.open(image).convert('RGB')

        # Preprocess into a batch of one on the model's device.
        input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)

        # Inference without gradient tracking.
        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.nn.functional.softmax(output, dim=1)

        # Top-3 predictions for the single image in the batch.
        top_probs, top_cls_idxs = torch.topk(probs, 3)
        top_probs = top_probs[0].cpu().numpy()
        top_cls_idxs = top_cls_idxs[0].cpu().numpy()

        results = {}
        for p, cidx in zip(top_probs, top_cls_idxs):
            if p < confidence_threshold:
                continue
            cidx = int(cidx)
            # Fall back to a generic name when the label cache is empty,
            # too short, or not an indexable sequence.
            try:
                label = LABELS_CACHE[cidx]
            except (IndexError, KeyError, TypeError):
                label = f"Class {cidx}"
            results[label] = float(p)

        if not results:
            return "No predictions above the confidence threshold."
        return results
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return f"Error during classification: {e}"
| # ----------------------------- | |
| # (OPTIONAL) CUSTOM CSS | |
| # ----------------------------- | |
# Page-level styling handed to gr.Blocks: a light gradient background,
# a Helvetica font stack, and centered headings/paragraphs.
custom_css = """
body {
margin: 0;
padding: 0;
background: linear-gradient(135deg, #f6f9fc, #ddeefc);
font-family: "Helvetica", sans-serif;
}
h1, p {
text-align: center;
margin-bottom: 1rem;
}
"""
| # ----------------------------- | |
| # BUILD THE GRADIO APP | |
| # ----------------------------- | |
def build_app():
    """
    Assemble the Gradio Blocks interface and return it without launching.

    The layout is a single vertical column: title, short instructions, image
    upload, confidence-threshold slider, classify button, and a label widget
    showing the results. Clicking the button calls ``classify_image`` with
    the uploaded image and the slider value.
    """
    with gr.Blocks(css=custom_css) as app:
        # Header text.
        gr.HTML("<h1>VGG16 ImageNet Classifier</h1>")
        gr.HTML("<p>Upload an image to see the top 3 predicted ImageNet classes.</p>")

        # Widgets, stacked in creation order.
        uploaded_image = gr.Image(type="numpy", label="Upload an Image")
        threshold = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.0,
            step=0.01,
            label="Confidence Threshold",
        )
        run_button = gr.Button("What Breed of Dog is That?")
        prediction_view = gr.Label(num_top_classes=3, label="Prediction Results")

        # Wire the button to the classifier.
        run_button.click(
            classify_image,
            [uploaded_image, threshold],
            prediction_view,
        )
    return app
| # ----------------------------- | |
| # LAUNCH | |
| # ----------------------------- | |
if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    demo = build_app()
    demo.launch()