visualise-segmentation

#1
by paddeh - opened
Files changed (6) hide show
  1. .gitignore +3 -1
  2. app.py +36 -53
  3. classification.py +38 -0
  4. functions.py +0 -95
  5. requirements.txt +1 -0
  6. segmentation.py +156 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  venv/
2
  __pycache__/
3
- .gradio/
 
 
 
1
  venv/
2
  __pycache__/
3
+ .gradio/
4
+
5
+ *.iml
app.py CHANGED
@@ -1,72 +1,55 @@
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoImageProcessor
3
- import torch
4
- from torchvision import transforms, models
5
- from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
6
  import numpy as np
7
  from PIL import Image
8
 
9
- from functions import import_class_labels, segment_image, crop_dog
 
10
 
11
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
- print(f"Using device: {device}")
13
 
14
- # Load DeepLabV3 model for segmentation
15
- seg_model = models.segmentation \
16
- .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
17
- .to(device) \
18
- .eval()
19
 
20
- # Load trained model and feature extractor
21
- model_name = "paddeh/is-it-max"
22
- model_img_size = (224,224)
23
- model = AutoModelForImageClassification.from_pretrained(model_name) \
24
- .to(device) \
25
- .eval()
26
- processor = AutoImageProcessor.from_pretrained(model_name)
27
 
28
- class_labels = import_class_labels('./')
29
-
30
- # Define image transformations
31
- transform = transforms.Compose([
32
- transforms.Resize(model_img_size, interpolation=transforms.InterpolationMode.BICUBIC),
33
- transforms.ToTensor(),
34
- transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
35
- ])
36
-
37
- def classify_image_with_cropping(image):
38
- if isinstance(image, np.ndarray):
39
- image = Image.fromarray(image) # Convert ndarray to PIL Image
40
 
41
  # 1. Segment the image
42
- print("Segmenting...")
43
- image, mask = segment_image(image, seg_model)
44
-
45
- if mask is None:
46
- print(f"Skipping due to failed segmentation.")
47
- return None, 'unknown'
48
-
49
- # 2. Crop to the dog (if found)
50
- print("Cropping...")
51
- cropped_image = crop_dog(image, mask)
 
 
 
 
 
 
52
 
53
  # 3. Preprocess and classify the cropped image
54
- input_tensor = transform(cropped_image).unsqueeze(0).to(device)
 
55
 
56
- print("Running model...")
57
- with torch.no_grad():
58
- outputs = model(input_tensor)
59
 
60
- predicted_class_idx = outputs.logits.argmax(-1).item()
61
- predicted_label = class_labels[predicted_class_idx]
62
-
63
- return cropped_image, f"Predicted class: {predicted_label}"
64
 
65
  iface = gr.Interface(
66
  fn=classify_image_with_cropping,
67
- inputs="image",
68
- outputs=[gr.Image(type="pil"), gr.Text()]
 
 
 
69
  )
70
- iface.launch()
71
 
72
- # TODO: Add option to visualise segmentation step
 
1
  import gradio as gr
 
 
 
 
2
  import numpy as np
3
  from PIL import Image
4
 
5
+ from segmentation import segment_image, crop_dog, visualize_segmentation
6
+ from classification import classify
7
 
8
+ # config
9
+ pre_scale_size = (2048, 2048)
10
 
 
 
 
 
 
11
 
12
def classify_image_with_cropping(original_image, pre_segment):
    """Gradio handler: optionally isolate the dog, then classify.

    Args:
        original_image: uploaded image (PIL.Image or numpy ndarray).
        pre_segment: when True, run segmentation and crop to the detected
            dog before classifying.

    Returns:
        (visualised_image, cropped_image, predicted_label) — the
        segmentation overlay (None when segmentation was skipped or
        failed), the image actually classified, and the predicted label.
    """
    if isinstance(original_image, np.ndarray):
        original_image = Image.fromarray(original_image)  # Convert ndarray to PIL Image

    # 1. Pre-scale very large uploads to bound segmentation cost
    if original_image.width > pre_scale_size[0] or original_image.height > pre_scale_size[1]:
        original_image.thumbnail(pre_scale_size, Image.LANCZOS)

    # 2. Segment the image and crop to the dog (if requested)
    if pre_segment:
        print("Segmenting...")
        segmented_image, mask = segment_image(original_image)

        if mask is not None:
            # BUG FIX: the success/failure branches were swapped — on a
            # successful segmentation the code kept the original image,
            # and on failure it called crop_dog with mask=None (crash).
            print("Cropping...")
            visualised_image = visualize_segmentation(original_image, mask)
            cropped_image = crop_dog(segmented_image, mask)
        else:
            print("Failed segmentation, using original image")
            visualised_image = None
            cropped_image = original_image
    else:
        visualised_image = None
        cropped_image = original_image

    # 3. Preprocess and classify the (possibly cropped) image
    print("Running classifier...")
    predicted_class_idx, predicted_label = classify(cropped_image)

    print("Done.")
    return visualised_image, cropped_image, predicted_label
 
44
 
 
 
 
 
45
 
46
# Gradio UI: image + pre-segmentation toggle in, overlay + crop + label out.
iface = gr.Interface(
    fn=classify_image_with_cropping,
    inputs=[gr.Image(type="pil"),
            gr.Checkbox(label="Try to isolate dog (pre-segmentation)", value=True)],
    outputs=[gr.Image(type="pil", label="Segmented image"),
             gr.Image(type="pil", label="Predicted image"),
             # Fixed typo: label read "Predicated class"
             gr.Textbox(label="Predicted class")]
)

iface.launch()
classification.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
from torchvision import transforms, models  # NOTE(review): `models` appears unused here — confirm before removing

from functions import import_class_labels

# Classification runs on GPU when available; this module picks its
# device independently of segmentation.py.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for classification")

# Input resolution expected by the fine-tuned classifier.
model_img_size = (224, 224)
# Index -> label mapping read from the local model directory.
class_labels = import_class_labels('./')

# Load trained model and feature extractor
model_name = "paddeh/is-it-max"
print(f"Loading classifier model {model_name}")
model = AutoModelForImageClassification.from_pretrained(model_name) \
    .to(device) \
    .eval()
processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

# Define image transformations: resize to the model's input size, then
# normalize with the processor's per-channel statistics.
transform = transforms.Compose([
    transforms.Resize(model_img_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
27
+
28
+
29
def classify(image):
    """Classify a PIL image with the fine-tuned model.

    Args:
        image: PIL image to classify (resized/normalized internally).

    Returns:
        (predicted_class_idx, predicted_label) tuple.
    """
    # Preprocess into a single-image batch on the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(batch).logits

    idx = logits.argmax(-1).item()
    return idx, class_labels[idx]
functions.py CHANGED
@@ -1,13 +1,5 @@
1
  import os
2
  import json
3
- import torch
4
- from torchvision import transforms
5
- import numpy as np
6
- import cv2
7
- import skimage.segmentation as seg
8
-
9
- dog_class = 12
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
 
13
  def import_class_labels(model_path):
@@ -31,90 +23,3 @@ def import_class_labels(model_path):
31
  sorted_class_names = [class_name for _, class_name in idx_class_pairs]
32
 
33
  return sorted_class_names
34
-
35
-
36
- def refine_dog_mask(mask, image):
37
-
38
- # Merge all dog segments together
39
- dog_mask = np.zeros_like(mask, dtype=np.uint8)
40
- for class_id in np.unique(mask):
41
- if class_id == 12: # Dog class
42
- dog_mask[mask == class_id] = 1
43
-
44
- # Apply morphological operations to connect fragmented segments
45
- kernel = np.ones((15, 15), np.uint8)
46
- dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel) # Close gaps
47
- dog_mask = cv2.dilate(dog_mask, kernel, iterations=2) # Expand segmentation
48
-
49
- # Refine mask using superpixel segmentation
50
- segments = seg.slic(np.array(image), n_segments=100, compactness=10)
51
- refined_dog_mask = np.where(dog_mask == 1, segments, 0)
52
-
53
- # Restore the dog class label (12) in refined regions
54
- refined_dog_mask[dog_mask == 1] = dog_class
55
-
56
- # Restore the dog class label (12) in refined regions
57
- mask[refined_dog_mask > 0] = dog_class
58
-
59
- # Convert mask to np.uint8 if necessary
60
- return mask.astype(np.uint8)
61
-
62
-
63
- def segment_image(image, seg_model):
64
- image = image.convert("RGB")
65
- orig_size = image.size
66
- transform = transforms.Compose([
67
- transforms.ToTensor()
68
- ])
69
- image_tensor = transform(image).unsqueeze(0).to(device)
70
-
71
- with torch.no_grad():
72
- output = seg_model(image_tensor)['out'][0]
73
- mask = output.argmax(0) # Keep on GPU
74
-
75
- # Dynamically determine the main object class
76
- unique_classes = mask.unique()
77
- unique_classes = unique_classes[unique_classes != 0] # Remove background class (0)
78
- if len(unique_classes) == 0:
79
- print(f'No segmentation found')
80
- return image, None # Skip image if no valid segmentation found
81
-
82
- mask = mask.cpu().numpy() # Move to CPU only when needed
83
- mask = refine_dog_mask(mask, image)
84
-
85
- return image, mask
86
-
87
- def crop_dog(image, mask, target_aspect=1, padding=20):
88
- # Get bounding box of the dog
89
- y_indices, x_indices = np.where(mask == dog_class) # Dog class pixels
90
- if len(y_indices) == 0 or len(x_indices) == 0:
91
- return image # No dog detected
92
-
93
- x_min, x_max = x_indices.min(), x_indices.max()
94
- y_min, y_max = y_indices.min(), y_indices.max()
95
-
96
- # Calculate aspect ratio of resize target
97
- width = x_max - x_min
98
- height = y_max - y_min
99
- current_aspect = width / height
100
-
101
- # Adjust bounding box to match target aspect ratio
102
- if current_aspect > target_aspect:
103
- new_height = width / target_aspect
104
- diff = (new_height - height) / 2
105
- y_min = max(0, int(y_min - diff))
106
- y_max = min(mask.shape[0], int(y_max + diff))
107
- else:
108
- new_width = height * target_aspect
109
- diff = (new_width - width) / 2
110
- x_min = max(0, int(x_min - diff))
111
- x_max = min(mask.shape[1], int(x_max + diff))
112
-
113
- # Apply padding
114
- x_min = max(0, x_min - padding)
115
- x_max = min(mask.shape[1], x_max + padding)
116
- y_min = max(0, y_min - padding)
117
- y_max = min(mask.shape[0], y_max + padding)
118
-
119
- cropped_image = image.crop((x_min, y_min, x_max, y_max))
120
- return cropped_image
 
1
  import os
2
  import json
 
 
 
 
 
 
 
 
3
 
4
 
5
  def import_class_labels(model_path):
 
23
  sorted_class_names = [class_name for _, class_name in idx_class_pairs]
24
 
25
  return sorted_class_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,3 +5,4 @@ Pillow
5
  torchvision
6
  opencv-python-headless
7
  scikit-image
 
 
5
  torchvision
6
  opencv-python-headless
7
  scikit-image
8
+ numpy
segmentation.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
import numpy as np
import cv2
import skimage.segmentation as seg
from PIL import Image, ImageDraw, ImageFont

# Class index for "dog" in the model's output
# (torchvision DeepLabV3 uses Pascal-VOC label indices — confirm).
dog_class = 12

# Segmentation runs on GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for segmentation")

# Load DeepLabV3 model for segmentation (eval mode; weights downloaded
# on first use).
print("Loading resnet101 segmentation model...")
seg_model = models.segmentation \
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
    .to(device) \
    .eval()
20
+
21
+
22
def refine_dog_mask(mask, image):
    """Refine the raw segmentation mask around the dog region.

    Connects fragmented dog segments with morphological close/dilate and
    snaps the result to superpixel boundaries before writing it back into
    the full class-id mask.

    Args:
        mask: 2-D integer class-id mask from the segmentation model
            (modified in place for the dog region).
        image: PIL image the mask was produced from (used for SLIC).

    Returns:
        The mask as np.uint8 with the refined region set to ``dog_class``.
    """
    # Binary mask of every dog pixel. (Was a redundant loop over
    # np.unique(mask) testing the hard-coded magic number 12 instead of
    # the module's ``dog_class`` constant.)
    dog_mask = (mask == dog_class).astype(np.uint8)

    # Apply morphological operations to connect fragmented segments
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)  # Close gaps
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)  # Expand segmentation

    # Refine mask using superpixel segmentation
    segments = seg.slic(np.array(image), n_segments=100, compactness=10)
    refined_dog_mask = np.where(dog_mask == 1, segments, 0)

    # Restore the dog class label in the refined regions
    refined_dog_mask[dog_mask == 1] = dog_class

    # Write the refined region back into the full mask
    mask[refined_dog_mask > 0] = dog_class

    return mask.astype(np.uint8)
46
+
47
+
48
def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop *image* to the dog region of *mask*.

    The bounding box of all ``dog_class`` pixels is grown on its shorter
    side to match *target_aspect*, then padded, both clamped to the mask
    bounds. Returns the image unchanged when no dog pixels exist.
    """
    rows, cols = np.where(mask == dog_class)  # Dog class pixels
    if rows.size == 0 or cols.size == 0:
        # No dog detected — nothing to crop
        return image

    left, right = cols.min(), cols.max()
    top, bottom = rows.min(), rows.max()

    # Aspect ratio of the raw bounding box
    box_w = right - left
    box_h = bottom - top
    aspect = box_w / box_h

    # Grow the shorter side so the box matches the target aspect ratio
    if aspect > target_aspect:
        grow = (box_w / target_aspect - box_h) / 2
        top = max(0, int(top - grow))
        bottom = min(mask.shape[0], int(bottom + grow))
    else:
        grow = (box_h * target_aspect - box_w) / 2
        left = max(0, int(left - grow))
        right = min(mask.shape[1], int(right + grow))

    # Pad the box, clamped to the mask bounds
    left = max(0, left - padding)
    right = min(mask.shape[1], right + padding)
    top = max(0, top - padding)
    bottom = min(mask.shape[0], bottom + padding)

    return image.crop((left, top, right, bottom))
82
+
83
+
84
def segment_image(image):
    """Segment *image* with DeepLabV3 and refine the dog region.

    Args:
        image: PIL image in any mode (converted to RGB).

    Returns:
        (image, mask): the RGB-converted image and a np.uint8 class-id
        mask, or (image, None) when only background was predicted.
    """
    image = image.convert("RGB")
    # (Removed unused ``orig_size`` local and the one-element Compose.)
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
    mask = output.argmax(0)  # Keep on GPU until we know it's needed

    # Bail out when only the background class (0) was predicted
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]
    if len(unique_classes) == 0:
        print('No segmentation found')
        return image, None  # Skip image if no valid segmentation found

    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)

    return image, mask
107
+
108
+
109
def visualize_segmentation(image, mask):
    """Overlay the segmentation on *image* and label each region.

    Dog regions are drawn green, all other classes red; the overlay is
    alpha-blended onto the image and each large-enough contour gets its
    class id drawn at its center with a dark outline for contrast.

    Args:
        image: PIL image the mask belongs to.
        mask: 2-D integer class-id mask (same height/width as image).

    Returns:
        RGB PIL image with the blended overlay and class-id labels.
    """
    font_border = 2  # outline thickness (px) around the class-id text
    font_size_segment_pct = 0.25  # font size as fraction of contour height

    # Create color overlay for masks
    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    unique_classes = np.unique(mask)
    contours_dict = []  # (contour, class_id) pairs kept for labelling

    for class_id in unique_classes:
        if class_id == 0:
            continue  # Skip background
        mask_indices = np.argwhere(mask == class_id)
        if len(mask_indices) > 0:
            mask_binary = (mask == class_id).astype(np.uint8)
            contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                if cv2.contourArea(contour) > 100:  # Filter small segments
                    contours_dict.append((contour, class_id))
                    color = (0, 255, 0) if class_id == dog_class else (255, 0, 0)  # Green for dog, red for others
                    cv2.drawContours(overlay, [contour], -1, color, thickness=cv2.FILLED)

    # Convert overlay to PIL image with transparency
    overlay_img = Image.fromarray(overlay).convert("RGBA")
    image_rgba = image.convert("RGBA")
    blended = Image.blend(image_rgba, overlay_img, alpha=0.3)

    # Draw category ID inside masks
    draw = ImageDraw.Draw(blended)
    for contour, class_id in contours_dict:
        x, y, w, h = cv2.boundingRect(contour)
        font_size = max(10, int(h * font_size_segment_pct))

        # Fall back to PIL's built-in font when Arial is not installed
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()

        text_x = x + w // 2
        text_y = y + h // 2

        # Four black offset copies form an outline behind the white label
        draw.text((text_x - font_border, text_y), str(class_id), fill=(0, 0, 0, 255), font=font)
        draw.text((text_x + font_border, text_y), str(class_id), fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y - font_border), str(class_id), fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y + font_border), str(class_id), fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y), str(class_id), fill=(255, 255, 255, 255), font=font)

    return blended.convert("RGB")