rtik007 committed on
Commit
bd3eb35
·
verified ·
1 Parent(s): 8c51c26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -71
app.py CHANGED
@@ -19,10 +19,11 @@ from PIL import Image
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # Load a vanilla VGG16 model pretrained on ImageNet
23
  model = models.vgg16(weights="IMAGENET1K_V1").to(device)
24
  model.eval()
25
 
 
26
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
27
  try:
28
  LABELS_CACHE = requests.get(LABELS_URL, timeout=5).json()
@@ -30,13 +31,13 @@ except Exception as e:
30
  print(f"Could not fetch ImageNet labels: {e}")
31
  LABELS_CACHE = [f"Class {i}" for i in range(1000)]
32
 
33
- # Transformation pipeline for input images
34
  transform_pipeline = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
37
  transforms.Normalize(
38
- mean=[0.485, 0.456, 0.406], # ImageNet means
39
- std=[0.229, 0.224, 0.225] # ImageNet std
40
  )
41
  ])
42
 
@@ -45,117 +46,86 @@ transform_pipeline = transforms.Compose([
45
  # -----------------------------
46
  def classify_image(image, confidence_threshold=0.0):
47
  """
48
- Classify an image using pretrained VGG16 on ImageNet.
49
- Returns the top-3 predictions above confidence_threshold.
50
  """
51
  try:
52
- # Convert Gradio's image (numpy) to PIL
53
  if isinstance(image, np.ndarray):
54
  image_pil = Image.fromarray(image.astype('uint8'), 'RGB')
55
  else:
56
  image_pil = Image.open(image).convert('RGB')
57
 
58
- # Apply preprocessing
59
  input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)
60
 
61
  # Inference
62
  with torch.no_grad():
63
  output = model(input_tensor)
64
- probabilities = torch.nn.functional.softmax(output, dim=1)
65
 
66
  # Top-3 predictions
67
- top_probs, top_classes = torch.topk(probabilities, 3)
68
  top_probs = top_probs[0].cpu().numpy()
69
- top_classes = top_classes[0].cpu().numpy()
70
 
71
- # Build dictionary with label -> probability
72
  results = {}
73
- for prob, cls_idx in zip(top_probs, top_classes):
74
- if prob >= confidence_threshold:
75
- label = LABELS_CACHE[cls_idx] if LABELS_CACHE else f"Class {cls_idx}"
76
- results[label] = float(prob)
77
 
78
  if not results:
79
  return "No predictions above the confidence threshold."
80
  return results
 
81
  except Exception as e:
82
  return f"Error during classification: {str(e)}"
83
 
84
  # -----------------------------
85
- # CUSTOM CSS FOR BACKGROUND
86
  # -----------------------------
87
- # Replace the background color/image/gradient with whatever you prefer.
88
- # You can also style text, buttons, etc.
89
  custom_css = """
90
  body {
91
  margin: 0;
92
  padding: 0;
93
- background: linear-gradient(135deg, #f2f2f2, #dceeff);
94
- font-family: 'Arial', sans-serif;
95
- }
96
-
97
- #title {
98
- font-size: 2.5rem;
99
- text-align: center;
100
- margin-top: 20px;
101
- font-weight: bold;
102
- color: #333;
103
  }
104
-
105
- #subtext {
106
  text-align: center;
107
- font-size: 1rem;
108
- color: #555;
109
- margin-bottom: 20px;
110
  }
111
  """
112
 
113
  # -----------------------------
114
- # BUILD GRADIO INTERFACE
115
  # -----------------------------
116
- def build_interface():
117
- # Inputs
118
- image_input = gr.Image(type="numpy", label="Upload an Image")
119
- confidence_slider = gr.Slider(
120
- minimum=0.0,
121
- maximum=1.0,
122
- value=0.0,
123
- step=0.01,
124
- label="Confidence Threshold"
125
- )
126
 
127
- # Outputs
128
- label_output = gr.Label(num_top_classes=3)
 
 
 
129
 
130
- # An optional HTML block (header text, etc.)
131
- with gr.Blocks(css=custom_css) as demo:
132
- gr.HTML("<h1 id='title'>VGG16 ImageNet Classifier</h1>")
133
- gr.HTML("<p id='subtext'>Upload an image to see top ImageNet predictions from a pretrained VGG16 model.</p>")
134
-
135
- # Main interface
136
- with gr.Row():
137
- with gr.Column():
138
- image_in = image_input
139
- conf_slider = confidence_slider
140
- with gr.Column():
141
- label_out = label_output
142
-
143
- # Create the main Interface
144
- btn = gr.Button("Classify")
145
- btn.click(
146
  fn=classify_image,
147
- inputs=[image_in, conf_slider],
148
- outputs=label_out
149
  )
150
 
151
  return demo
152
 
153
  # -----------------------------
154
- # LAUNCH APP
155
  # -----------------------------
156
  if __name__ == "__main__":
157
- interface = build_interface()
158
- # You can set a Gradio theme if you want (e.g., 'Soft', 'Monochrome', 'Glass')
159
- interface.launch(share=True) # share=True if you want a shareable public link locally
160
-
161
-
 
19
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # Load vanilla VGG16 pretrained on ImageNet
23
  model = models.vgg16(weights="IMAGENET1K_V1").to(device)
24
  model.eval()
25
 
26
+ # Download ImageNet labels
27
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
28
  try:
29
  LABELS_CACHE = requests.get(LABELS_URL, timeout=5).json()
 
31
  print(f"Could not fetch ImageNet labels: {e}")
32
  LABELS_CACHE = [f"Class {i}" for i in range(1000)]
33
 
34
+ # Transform pipeline
35
  transform_pipeline = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
38
  transforms.Normalize(
39
+ mean=[0.485, 0.456, 0.406],
40
+ std=[0.229, 0.224, 0.225]
41
  )
42
  ])
43
 
 
46
  # -----------------------------
47
  def classify_image(image, confidence_threshold=0.0):
48
  """
49
+ Classify an image using the pretrained VGG16 on ImageNet.
50
+ Returns top-3 predictions above the given confidence_threshold.
51
  """
52
  try:
53
+ # Convert Gradio's numpy image to PIL
54
  if isinstance(image, np.ndarray):
55
  image_pil = Image.fromarray(image.astype('uint8'), 'RGB')
56
  else:
57
  image_pil = Image.open(image).convert('RGB')
58
 
59
+ # Preprocess
60
  input_tensor = transform_pipeline(image_pil).unsqueeze(0).to(device)
61
 
62
  # Inference
63
  with torch.no_grad():
64
  output = model(input_tensor)
65
+ probs = torch.nn.functional.softmax(output, dim=1)
66
 
67
  # Top-3 predictions
68
+ top_probs, top_cls_idxs = torch.topk(probs, 3)
69
  top_probs = top_probs[0].cpu().numpy()
70
+ top_cls_idxs = top_cls_idxs[0].cpu().numpy()
71
 
 
72
  results = {}
73
+ for p, cidx in zip(top_probs, top_cls_idxs):
74
+ if p >= confidence_threshold:
75
+ label = LABELS_CACHE[cidx] if LABELS_CACHE else f"Class {cidx}"
76
+ results[label] = float(p)
77
 
78
  if not results:
79
  return "No predictions above the confidence threshold."
80
  return results
81
+
82
  except Exception as e:
83
  return f"Error during classification: {str(e)}"
84
 
85
  # -----------------------------
86
+ # (OPTIONAL) CUSTOM CSS
87
  # -----------------------------
 
 
88
  custom_css = """
89
  body {
90
  margin: 0;
91
  padding: 0;
92
+ background: linear-gradient(135deg, #f6f9fc, #ddeefc);
93
+ font-family: "Helvetica", sans-serif;
 
 
 
 
 
 
 
 
94
  }
95
+ h1, p {
 
96
  text-align: center;
97
+ margin-bottom: 1rem;
 
 
98
  }
99
  """
100
 
101
  # -----------------------------
102
+ # BUILD THE GRADIO APP
103
  # -----------------------------
104
def build_app():
    """
    Build the Gradio Blocks UI for the VGG16 classifier.

    The layout is a heading and description, a group holding the image
    upload, confidence-threshold slider and "Classify" button, and a label
    component that shows the result of ``classify_image``.

    Returns:
        The ``gr.Blocks`` instance; the caller is expected to ``launch()`` it.
    """
    with gr.Blocks(css=custom_css) as demo:
        gr.HTML("<h1>VGG16 ImageNet Classifier</h1>")
        gr.HTML("<p>Upload an image to see the top 3 predicted ImageNet classes.</p>")

        # gr.Box was deprecated in Gradio 3.x and removed in 4.x; gr.Group is
        # the supported container and is available in both major versions.
        with gr.Group():
            # Widgets are stacked vertically inside the group.
            image_input = gr.Image(type="numpy", label="Upload an Image")
            confidence_slider = gr.Slider(
                0.0, 1.0, value=0.0, step=0.01, label="Confidence Threshold"
            )
            classify_button = gr.Button("Classify")

        label_output = gr.Label(num_top_classes=3, label="Prediction Results")

        # Connect button click to classification.
        classify_button.click(
            fn=classify_image,
            inputs=[image_input, confidence_slider],
            outputs=label_output,
        )

    return demo
125
 
126
  # -----------------------------
127
+ # LAUNCH
128
  # -----------------------------
129
  if __name__ == "__main__":
130
+ demo = build_app()
131
+ demo.launch()