paddeh committed on
Commit
5106f97
·
1 Parent(s): fa0ab0e

Import code from training notebook

Browse files
Files changed (4) hide show
  1. app.py +43 -3
  2. classes.json +10 -0
  3. functions.py +119 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,13 +1,53 @@
1
  import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoImageProcessor
3
- from PIL import Image
4
  import torch
 
 
5
 
6
- model_name = "paddeh/is-it-max"
7
 
 
 
 
 
 
 
 
 
8
  model = AutoModelForImageClassification.from_pretrained(model_name)
9
  processor = AutoImageProcessor.from_pretrained(model_name)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def classify_image(image):
12
  inputs = processor(images=image, return_tensors="pt")
13
  with torch.no_grad():
@@ -15,5 +55,5 @@ def classify_image(image):
15
  predicted_class = logits.argmax(-1).item()
16
  return f"Predicted class: {predicted_class}"
17
 
18
- iface = gr.Interface(fn=classify_image, inputs="image", outputs="text")
19
  iface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModelForImageClassification, AutoImageProcessor
 
3
  import torch
4
+ from torchvision import transforms, models
5
+ from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
6
 
 
7
 
8
# Local helpers. app.py runs as a top-level script, so the import must be
# absolute — the original relative form ("from .functions import ...") raises
# "ImportError: attempted relative import with no known parent package".
from functions import import_class_labels, segment_image, crop_dog

# Device used for both models and all input tensors below.
# (The original code referenced `device` later without ever defining it.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DeepLabV3 model for segmentation (eval mode: inference only).
seg_model = (
    models.segmentation
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
    .to(device)
    .eval()
)

# Load trained classification model and its preprocessor.
model_name = "paddeh/is-it-max"
model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
processor = AutoImageProcessor.from_pretrained(model_name)

# Class-index -> label list read from ./classes.json.
class_labels = import_class_labels('./')

# Target input size for the classifier, derived from its preprocessor config.
# The original referenced undefined names `model_img_size` and
# `feature_extractor`; both are obtained from `processor` here.
_size = processor.size
model_img_size = (
    (_size["height"], _size["width"]) if "height" in _size
    else _size["shortest_edge"]
)

# Define image transformations matching the classifier's training pipeline.
transform = transforms.Compose([
    transforms.Resize(model_img_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
27
+
28
def classify_image_with_cropping(image):
    """Segment the image, crop to the detected dog, and classify the crop.

    Args:
        image: PIL image supplied by the Gradio "image" input.

    Returns:
        (cropped_image, label_text) on success, or (None, 'unknown') when
        segmentation finds no foreground object.
    """
    # 1. Segment the image
    image, mask = segment_image(image, seg_model)

    if mask is None:
        print("Skipping due to failed segmentation.")
        return None, 'unknown'

    # 2. Crop to the dog (if found)
    cropped_image = crop_dog(image, mask)

    # 3. Preprocess and classify the cropped image
    input_tensor = transform(cropped_image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    predicted_label = class_labels[predicted_class_idx]

    # BUG FIX: the original formatted `predicted_class`, which is undefined
    # in this function (NameError); the looked-up label is what was intended.
    return cropped_image, f"Predicted class: {predicted_label}"
49
+
50
+
51
  def classify_image(image):
52
  inputs = processor(images=image, return_tensors="pt")
53
  with torch.no_grad():
 
55
  predicted_class = logits.argmax(-1).item()
56
  return f"Predicted class: {predicted_class}"
57
 
58
# Gradio expects `outputs` to be a list of component names; the original
# passed the single string "image, text", which is not a valid component.
iface = gr.Interface(
    fn=classify_image_with_cropping,
    inputs="image",
    outputs=["image", "text"],
)
iface.launch()
classes.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "class_names": [
3
+ "max",
4
+ "not_max"
5
+ ],
6
+ "class_to_idx": {
7
+ "max": 0,
8
+ "not_max": 1
9
+ }
10
+ }
functions.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ from torchvision import transforms
5
+ import numpy as np
6
+ import cv2
7
+ import skimage.segmentation as seg
8
+
9
+ dog_class = 12
10
+
11
+
12
def import_class_labels(model_path):
    """Import class labels from classes.json, ordered by class index.

    Args:
        model_path: directory containing a `classes.json` file with a
            "class_to_idx" mapping of label name -> integer index.

    Returns:
        List of class names sorted ascending by their index, so the list
        position of each name equals the model's output index.
    """
    classes_file_path = os.path.join(model_path, "classes.json")

    with open(classes_file_path, "r") as f:
        class_data = json.load(f)

    # Sort (name, idx) pairs by index so position i holds the label for
    # model output i. (The original also read an unused "class_names" local
    # and built/sorted intermediate tuple lists by hand.)
    class_to_idx = class_data["class_to_idx"]
    return [name for name, _ in sorted(class_to_idx.items(), key=lambda kv: kv[1])]
33
+
34
+
35
def refine_dog_mask(mask, image):
    """Merge, close, and dilate the dog regions of a segmentation mask.

    Args:
        mask: 2-D integer class mask (dog pixels == `dog_class`).
        image: PIL image the mask belongs to (used for superpixel refinement).

    Returns:
        The mask as np.uint8 with the expanded dog region set to `dog_class`.
    """
    # Merge all dog segments into one binary mask. Vectorized equivalent of
    # the original per-class loop, and uses the module-level `dog_class`
    # constant instead of a hard-coded 12.
    dog_mask = (mask == dog_class).astype(np.uint8)

    # Apply morphological operations to connect fragmented segments
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)  # Close gaps
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)  # Expand segmentation

    # Refine mask using superpixel segmentation
    segments = seg.slic(np.array(image), n_segments=100, compactness=10)
    refined_dog_mask = np.where(dog_mask == 1, segments, 0)

    # Restore the dog class label in refined regions.
    # NOTE(review): this overwrite makes `refined_dog_mask > 0` identical to
    # `dog_mask == 1`, so the slic step currently has no effect on the
    # result; kept for behavioral parity — confirm intent.
    refined_dog_mask[dog_mask == 1] = dog_class

    # Write the (dilated) dog region back into the full class mask.
    mask[refined_dog_mask > 0] = dog_class

    # Convert to np.uint8 for downstream OpenCV / PIL consumers.
    return mask.astype(np.uint8)
60
+
61
+
62
def segment_image(image, seg_model):
    """Run semantic segmentation and refine the dog region of the mask.

    Args:
        image: input PIL image (converted to RGB here).
        seg_model: torchvision DeepLabV3 model, already on its target device.

    Returns:
        (image, mask) where mask is a refined np.uint8 class mask, or
        (image, None) when no non-background class is detected.
    """
    image = image.convert("RGB")

    # Run on whatever device the model's parameters live on; the original
    # referenced an undefined global `device` (and an unused `orig_size`).
    device = next(seg_model.parameters()).device
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
        mask = output.argmax(0)  # Keep on GPU

    # Determine whether any non-background class is present.
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]  # Remove background class (0)
    if len(unique_classes) == 0:
        # The original formatted an undefined `image_path` here (NameError);
        # this function only receives the in-memory image, not a path.
        print("No segmentation found for image")
        return image, None  # Skip image if no valid segmentation found

    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)

    return image, mask
85
+
86
def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop `image` to the dog's bounding box, aspect-adjusted and padded.

    Args:
        image: PIL image.
        mask: class mask aligned with `image` (dog pixels == `dog_class`).
        target_aspect: desired width/height ratio of the crop.
        padding: extra pixels added on every side (clamped to image bounds).

    Returns:
        The cropped PIL image, or the original image unchanged when no
        (usable) dog region is found.
    """
    # Get bounding box of the dog
    y_indices, x_indices = np.where(mask == dog_class)  # Dog class pixels
    if len(y_indices) == 0 or len(x_indices) == 0:
        return image  # No dog detected

    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    # Calculate aspect ratio of the detected box
    width = x_max - x_min
    height = y_max - y_min

    # BUG FIX: a single-row or single-column detection gives width/height of
    # zero and the original divided by zero below; treat it as "no dog".
    if width == 0 or height == 0:
        return image

    current_aspect = width / height

    # Grow the shorter dimension so the box matches the target aspect ratio,
    # clamped to the mask bounds.
    if current_aspect > target_aspect:
        new_height = width / target_aspect
        diff = (new_height - height) / 2
        y_min = max(0, int(y_min - diff))
        y_max = min(mask.shape[0], int(y_max + diff))
    else:
        new_width = height * target_aspect
        diff = (new_width - width) / 2
        x_min = max(0, int(x_min - diff))
        x_max = min(mask.shape[1], int(x_max + diff))

    # Apply padding (clamped to the mask bounds)
    x_min = max(0, x_min - padding)
    x_max = min(mask.shape[1], x_max + padding)
    y_min = max(0, y_min - padding)
    y_max = min(mask.shape[0], y_max + padding)

    cropped_image = image.crop((x_min, y_min, x_max, y_max))
    return cropped_image
requirements.txt CHANGED
@@ -2,3 +2,4 @@ transformers
2
  torch
3
  gradio
4
  Pillow
 
 
2
  torch
3
  gradio
4
  Pillow
5
+ torchvision