rtik007 committed on
Commit
13692c5
·
verified ·
1 Parent(s): 3bd7f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -46
app.py CHANGED
@@ -5,81 +5,147 @@ Neural Network (CNN) within PyTorch framework. Additionally, Gradio is used to b
5
  interface for easy image uploads and breed predictions.
6
  '''
7
 
8
- import gradio as gr
9
- import torchvision.transforms as transforms
10
- import torch.nn.functional as F
11
- from PIL import Image
 
 
 
 
12
  import numpy as np
13
  import torch
14
  import torchvision.models as models
15
- import torch.nn as nn
 
 
 
16
 
17
- # 1. Load your fine-tuned model
18
- num_breeds = 120 # Example: 120 dog breeds
19
- DOG_BREEDS = ["Chihuahua", "Japanese Spaniel", ..., "Mastiff"] # etc., in correct order
20
 
 
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
- # Start from VGG16 base
24
- fine_tuned_model = models.vgg16(weights="IMAGENET1K_V1")
25
- # Replace classifier (should match whatever you used during training)
26
- fine_tuned_model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_breeds)
27
- fine_tuned_model.to(device)
28
-
29
- # Load the trained weights
30
- fine_tuned_model.load_state_dict(torch.load("dog_breed_vgg16.pth", map_location=device))
31
- fine_tuned_model.eval()
32
-
33
- # 2. Define transforms, including normalization
34
- in_transform = transforms.Compose([
35
- transforms.Resize((224, 224)),
36
- transforms.ToTensor(),
37
- transforms.Normalize(
38
- mean=[0.485, 0.456, 0.406], # ImageNet means
39
- std=[0.229, 0.224, 0.225] # ImageNet std
40
- )
41
- ])
 
 
 
42
 
43
  def load_convert_image_to_tensor(image):
44
- """Converts image (numpy/PIL) to a PyTorch tensor, normalized for VGG16."""
 
 
 
45
  if isinstance(image, np.ndarray):
46
  image = Image.fromarray(image.astype('uint8'), 'RGB')
47
  elif isinstance(image, str):
48
  image = Image.open(image).convert('RGB')
49
- return in_transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def classify_image(image, confidence_threshold=0.0):
52
- """Classify the image as one of the dog breeds."""
 
 
 
 
 
 
53
  try:
54
- image_tensor = load_convert_image_to_tensor(image)
 
 
 
55
  with torch.no_grad():
56
- output = fine_tuned_model(image_tensor)
57
- softmax_output = F.softmax(output, dim=1)
58
- top_probs, top_classes = torch.topk(softmax_output, 3)
 
 
 
 
59
 
 
60
  top_probs = top_probs.cpu().numpy()[0]
61
  top_classes = top_classes.cpu().numpy()[0]
62
 
63
- result = {}
64
- for prob, cls_id in zip(top_probs, top_classes):
 
65
  if prob >= confidence_threshold:
66
- breed_label = DOG_BREEDS[cls_id]
67
- result[breed_label] = float(prob)
68
- return result if result else "No predictions above the confidence threshold."
 
 
 
 
 
69
  except Exception as e:
70
- return f"Error: {str(e)}"
71
 
72
- # Gradio interface
 
 
73
  image_input = gr.Image()
74
- confidence_slider = gr.Slider(0, 1, value=0.1, label="Confidence Threshold")
75
-
 
 
 
 
76
  label_output = gr.Label(num_top_classes=3)
 
77
  interface = gr.Interface(
78
- fn=classify_image,
79
  inputs=[image_input, confidence_slider],
80
  outputs=label_output,
81
- title="Dog Breed Classifier",
82
- description="Upload an image of a dog to see the predicted breed(s)."
83
  )
84
 
 
 
 
85
  interface.launch(share=True)
 
 
5
  interface for easy image uploads and breed predictions.
6
  '''
7
 
8
+ # -----------------------------
9
+ # INSTALL DEPENDENCIES (if needed)
10
+ # -----------------------------
11
+ # !pip install torch torchvision
12
+ # !pip install gradio
13
+ # !pip install requests
14
+ # !pip install pillow
15
+
16
  import numpy as np
17
  import torch
18
  import torchvision.models as models
19
+ import torchvision.transforms as transforms
20
+ import requests
21
+ from PIL import Image
22
+ import gradio as gr
23
 
24
# -----------------------------
# SETUP
# -----------------------------

# Prefer GPU if available; the model and every input tensor are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained VGG16 model with the ImageNet weights bundled in torchvision.
model = models.vgg16(weights="IMAGENET1K_V1").to(device)
model.eval()  # Important: set to evaluation mode (disables dropout) for inference

# Global holding the ImageNet labels once downloaded — presumably a list of
# 1000 strings whose index equals the class ID (populated by prefetch_labels();
# stays None if the download fails).
LABELS_CACHE = None
38
def prefetch_labels():
    """Download human-readable ImageNet class labels into LABELS_CACHE.

    Fetches the anishathalye/imagenet-simple-labels JSON list from GitHub
    (index-aligned with torchvision's ImageNet class IDs).  On any network,
    HTTP, or JSON-decoding failure, LABELS_CACHE is left as None and the
    error is printed; classify_image() then reports the problem to the user
    instead of crashing.
    """
    global LABELS_CACHE
    labels_map_url = (
        "https://raw.githubusercontent.com/anishathalye/"
        "imagenet-simple-labels/master/imagenet-simple-labels.json"
    )
    try:
        response = requests.get(labels_map_url, timeout=5)
        # Treat HTTP errors (404, 500, ...) as fetch failures rather than
        # feeding an error page to the JSON decoder.
        response.raise_for_status()
        LABELS_CACHE = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON
        # (requests' JSONDecodeError subclasses it); the original code
        # only caught RequestException, so a bad body crashed at startup.
        LABELS_CACHE = None
        print(f"Error fetching labels: {e}")

# Fetch labels once when the script starts so the UI is ready immediately.
prefetch_labels()
53
 
54
# Standard ImageNet preprocessing, built ONCE at import time — the original
# reconstructed this Compose pipeline on every single call.
_IMAGENET_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet channel means
        std=[0.229, 0.224, 0.225],   # ImageNet channel stds
    ),
])

def load_convert_image_to_tensor(image):
    """Convert a Gradio-supplied image into a normalized input tensor.

    Parameters
    ----------
    image : numpy.ndarray | str | PIL.Image.Image
        The image as delivered by gr.Image (numpy array), a file path,
        or an already-open PIL image (passed through unchanged).

    Returns
    -------
    torch.Tensor
        A batched (unsqueeze(0)) ImageNet-normalized tensor on the
        module-level ``device``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    elif isinstance(image, str):
        image = Image.open(image).convert('RGB')
    return _IMAGENET_PREPROCESS(image).unsqueeze(0).to(device)
76
+
77
def get_human_readable_label_for_class_id(class_id):
    """Map an ImageNet class ID (0-999) to its human-readable label.

    Returns an "Unknown class ID: ..." string when the labels were never
    downloaded or the ID is out of range.  The range check now also rejects
    negative IDs, which previously slipped through and silently indexed
    from the END of the label list via Python's negative indexing.
    """
    if LABELS_CACHE is None or not 0 <= class_id < len(LABELS_CACHE):
        return f"Unknown class ID: {class_id}"
    return LABELS_CACHE[class_id]
85
 
86
def classify_image(image, confidence_threshold=0.0):
    """Classify *image* with the pretrained VGG16 into ImageNet classes.

    Parameters
    ----------
    image : numpy.ndarray | str
        Image from the Gradio input component (array or file path).
    confidence_threshold : float
        Minimum softmax probability a prediction must reach to be shown.

    Returns
    -------
    dict | str
        A {label: probability} dict of the top-3 predictions clearing the
        threshold (consumed by gr.Label), or a message string when labels
        are unavailable, nothing clears the threshold, or an error occurs.
    """
    # Bail out early if the startup label download failed.
    if LABELS_CACHE is None:
        return "Error: ImageNet labels not loaded."

    try:
        # Preprocess into a normalized, batched tensor on the right device.
        tensor = load_convert_image_to_tensor(image)

        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = model(tensor)

        # Softmax over the class axis, then keep the three best classes.
        scores = torch.nn.functional.softmax(logits, dim=1)
        best_probs, best_classes = torch.topk(scores, 3)

        # Drop the batch dimension and leave the GPU for plain iteration.
        best_probs = best_probs.cpu().numpy()[0]
        best_classes = best_classes.cpu().numpy()[0]

        # Keep only predictions that clear the confidence threshold.
        predictions = {
            get_human_readable_label_for_class_id(int(cid)): float(p)
            for p, cid in zip(best_probs, best_classes)
            if p >= confidence_threshold
        }
        if predictions:
            return predictions
        return "No predictions above the confidence threshold."
    except Exception as e:
        # Surface any failure as a message rather than crashing the UI.
        return f"Error during classification: {str(e)}"
126
 
127
# -----------------------------
# BUILD THE GRADIO INTERFACE
# -----------------------------
# One image input, a threshold slider, and a label widget showing the
# top-3 predictions, all wired into a single Interface.
image_input = gr.Image()
confidence_slider = gr.Slider(
    minimum=0.0, maximum=1.0, value=0.0, label="Confidence Threshold"
)
label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=classify_image,
    inputs=[image_input, confidence_slider],
    outputs=label_output,
    title="VGG16 ImageNet Classifier",
    description=(
        "Upload an image to see the top ImageNet predictions "
        "from a pretrained VGG16 model."
    ),
)

# -----------------------------
# LAUNCH THE APP
# -----------------------------
# share=True additionally exposes a temporary public Gradio link.
interface.launch(share=True)