rtik007 committed on
Commit
3bd7f39
·
verified ·
1 Parent(s): 2060a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -63
app.py CHANGED
@@ -5,89 +5,81 @@ Neural Network (CNN) within PyTorch framework. Additionally, Gradio is used to b
5
  interface for easy image uploads and breed predictions.
6
  '''
7
 
8
- #!pip install torch torchvision
9
- #!pip install matplotlib
10
- #!pip install gradio
11
-
12
  import numpy as np
13
  import torch
14
  import torchvision.models as models
15
- from PIL import Image
16
- import torchvision.transforms as transforms
17
- import requests
18
- import gradio as gr
19
- import os
20
-
21
- # Load pretrained VGG16 model
22
- VGG16 = models.vgg16(weights="IMAGENET1K_V1")
23
- use_cuda = torch.cuda.is_available()
24
- if use_cuda:
25
- VGG16 = VGG16.cuda()
26
-
27
- # Global cache for labels
28
- LABELS_CACHE = None
29
-
30
- def prefetch_labels():
31
- global LABELS_CACHE
32
- LABELS_MAP_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
33
- try:
34
- LABELS_CACHE = requests.get(LABELS_MAP_URL, timeout=5).json()
35
- except requests.exceptions.RequestException as e:
36
- LABELS_CACHE = None
37
- print(f"Error fetching labels: {e}")
38
-
39
- # Fetch labels on startup
40
- prefetch_labels()
 
41
 
42
  def load_convert_image_to_tensor(image):
 
43
  if isinstance(image, np.ndarray):
44
  image = Image.fromarray(image.astype('uint8'), 'RGB')
45
  elif isinstance(image, str):
46
  image = Image.open(image).convert('RGB')
47
-
48
- in_transform = transforms.Compose([
49
- transforms.Resize(size=(224, 224)),
50
- transforms.ToTensor()
51
- ])
52
- image = in_transform(image)[:3, :, :].unsqueeze(0)
53
- return image
54
-
55
- def get_human_readable_label_for_class_id(class_id, labels_cache=None):
56
- if labels_cache is None or class_id >= len(labels_cache):
57
- return f"Unknown class ID: {class_id}"
58
- return labels_cache[class_id]
59
 
60
  def classify_image(image, confidence_threshold=0.0):
61
- global LABELS_CACHE
62
- if LABELS_CACHE is None:
63
- return "Error: Labels not loaded"
64
-
65
  try:
66
  image_tensor = load_convert_image_to_tensor(image)
67
- if use_cuda:
68
- image_tensor = image_tensor.cuda()
 
 
69
 
70
- output = VGG16(image_tensor)
71
- softmax_output = torch.nn.functional.softmax(output, dim=1)
72
- top_probs, top_classes = torch.topk(softmax_output, 3)
73
- top_probs = top_probs.cpu().detach().numpy() if use_cuda else top_probs.detach().numpy()
74
- top_classes = top_classes.cpu().detach().numpy() if use_cuda else top_classes.detach().numpy()
75
 
76
  result = {}
77
- for prob, cls_id in zip(top_probs[0], top_classes[0]):
78
  if prob >= confidence_threshold:
79
- label = get_human_readable_label_for_class_id(int(cls_id), LABELS_CACHE)
80
- result[label] = prob
81
  return result if result else "No predictions above the confidence threshold."
82
  except Exception as e:
83
  return f"Error: {str(e)}"
84
 
85
- # Gradio Interface
86
  image_input = gr.Image()
87
- confidence_slider = gr.Slider(0, 1, 0.0, label="Confidence Threshold (Optional)") # Changed this line
88
- label_output = gr.Label(num_top_classes=3)
89
 
90
- interface = gr.Interface(fn=classify_image, inputs=[image_input, confidence_slider], outputs=label_output)
91
-
92
- # Launch Gradio with shareable link
93
- interface.launch(share=True)
 
 
 
 
 
 
 
5
  interface for easy image uploads and breed predictions.
6
  '''
7
 
8
+ import gradio as gr
9
+ import torchvision.transforms as transforms
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
  import numpy as np
13
  import torch
14
  import torchvision.models as models
15
+ import torch.nn as nn
16
+
17
# 1. Load the fine-tuned model
num_breeds = 120  # output size of the classifier head; must match the checkpoint

# NOTE(review): placeholder. The literal `...` below is Python's Ellipsis
# object, so this list currently holds only four entries. It has to be replaced
# by all `num_breeds` names, in the class-index order used during training.
DOG_BREEDS = ["Chihuahua", "Japanese Spaniel", ..., "Mastiff"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the bare VGG16 architecture. No ImageNet weights are requested because
# load_state_dict() below (strict by default) overwrites every parameter, so
# downloading the pretrained checkpoint first would be wasted work.
fine_tuned_model = models.vgg16(weights=None)
# Replace the last classifier layer (should match whatever was used in training)
fine_tuned_model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_breeds)
fine_tuned_model.to(device)

# Load the trained weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True where the installed torch has it).
fine_tuned_model.load_state_dict(torch.load("dog_breed_vgg16.pth", map_location=device))
fine_tuned_model.eval()  # inference mode: disables dropout in the classifier

# 2. Preprocessing: resize, convert to tensor, normalize with ImageNet statistics
in_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet means
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])
42
 
43
def load_convert_image_to_tensor(image):
    """Convert an input image into a normalized, batched tensor on `device`.

    Args:
        image: A NumPy array of pixel values, a path to an image file, or an
            already-loaded PIL image.

    Returns:
        The output of `in_transform` with a leading batch dimension added,
        moved to `device`.
    """
    if isinstance(image, np.ndarray):
        # Let PIL infer the mode from the array shape instead of forcing
        # 'RGB', so grayscale (H, W) and RGBA (H, W, 4) arrays are accepted.
        image = Image.fromarray(image.astype('uint8'))
    elif isinstance(image, str):
        image = Image.open(image)
    # Normalize() in `in_transform` is configured for exactly three channels,
    # so force RGB for every input kind, including PIL images passed directly.
    image = image.convert('RGB')
    return in_transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
def classify_image(image, confidence_threshold=0.0):
    """Classify an image and return the most likely dog breeds.

    Args:
        image: Any input accepted by `load_convert_image_to_tensor`.
        confidence_threshold: Minimum softmax probability a prediction needs
            in order to be included in the result.

    Returns:
        A dict mapping label -> probability (Python float) for up to three
        top predictions that meet the threshold. A class index with no entry
        in `DOG_BREEDS` is reported as "Unknown class ID: <index>".
        If nothing meets the threshold, the string
        "No predictions above the confidence threshold." is returned.
        If anything raises, an "Error: ..." string is returned instead.
    """
    try:
        image_tensor = load_convert_image_to_tensor(image)
        with torch.no_grad():
            output = fine_tuned_model(image_tensor)
            softmax_output = F.softmax(output, dim=1)
            # topk raises when k exceeds the number of classes, so clamp it.
            k = min(3, softmax_output.shape[1])
            top_probs, top_classes = torch.topk(softmax_output, k)

        # tolist() yields plain Python floats/ints for the single batch item.
        probs = top_probs[0].tolist()
        class_ids = top_classes[0].tolist()

        result = {}
        for prob, cls_id in zip(probs, class_ids):
            if prob >= confidence_threshold:
                # One unmapped index must not discard the other predictions.
                if 0 <= cls_id < len(DOG_BREEDS):
                    breed_label = DOG_BREEDS[cls_id]
                else:
                    breed_label = f"Unknown class ID: {cls_id}"
                result[breed_label] = prob
        return result if result else "No predictions above the confidence threshold."
    except Exception as e:
        # Deliberate catch-all at the UI boundary: the failure is reported as
        # text to the caller rather than raised.
        return f"Error: {str(e)}"
71
 
72
# Gradio UI: an image upload and a threshold slider feed classify_image;
# the result is rendered as a label widget showing up to three classes.
label_output = gr.Label(num_top_classes=3)
confidence_slider = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.1,
    label="Confidence Threshold",
)
image_input = gr.Image()

interface = gr.Interface(
    fn=classify_image,
    inputs=[image_input, confidence_slider],
    outputs=label_output,
    title="Dog Breed Classifier",
    description="Upload an image of a dog to see the predicted breed(s).",
)

# Start the app and request a public share link.
interface.launch(share=True)