vignesh-99 committed on
Commit
0fd2fa1
·
1 Parent(s): 5f8e3f9

removed redundant batching

Browse files
mask2former.py CHANGED
@@ -145,7 +145,7 @@ def get_model_input(img_arr, mask, image_id):
145
  instance_id_to_class_id[instance_id] = class_id
146
  instance_id += 1
147
 
148
- return {'image': img_arr, 'instance_id_to_class_id': instance_id_to_class_id, 'mask': np.astype(mask, dtype=np.int32), 'image_id' : image_id}
149
 
150
 
151
 
@@ -609,6 +609,9 @@ class Mask2FormerSahi(DetectionModel):
609
  return_binary_maps=True, target_sizes=target_sizes)
610
 
611
  self._original_predictions = post_processed_outputs
 
 
 
612
 
613
  def get_polygonal_predictions(self, post_processed_output) -> tuple:
614
 
 
145
  instance_id_to_class_id[instance_id] = class_id
146
  instance_id += 1
147
 
148
+ return {'image': img_arr, 'instance_id_to_class_id': instance_id_to_class_id, 'mask': np.astype(mask, np.int32), 'image_id' : image_id}
149
 
150
 
151
 
 
609
  return_binary_maps=True, target_sizes=target_sizes)
610
 
611
  self._original_predictions = post_processed_outputs
612
+
613
+ def perform_batch_inference(self, images: list[np.ndarray]) -> None:
614
+ return self.perform_inference(images)
615
 
616
  def get_polygonal_predictions(self, post_processed_output) -> tuple:
617
 
mask2former_app_predict.py CHANGED
@@ -1,7 +1,7 @@
1
  import mask2former
2
  import tree_commons as tc
3
  import torch
4
- from mask2former_sahi_predict_override import get_sliced_prediction
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
@@ -28,8 +28,7 @@ def predict(img_arr):
28
  slice_width=tc.CROPPED_IMAGE_WIDTH,
29
  overlap_height_ratio=0.2,
30
  overlap_width_ratio=0.2,
31
- num_batch=9,
32
- verbose=2,
33
  )
34
 
35
 
 
1
  import mask2former
2
  import tree_commons as tc
3
  import torch
4
+ from sahi.predict import get_sliced_prediction
5
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
 
28
  slice_width=tc.CROPPED_IMAGE_WIDTH,
29
  overlap_height_ratio=0.2,
30
  overlap_width_ratio=0.2,
31
+ batch_size=9
 
32
  )
33
 
34
 
mask2former_sahi_predict_override.py DELETED
@@ -1,245 +0,0 @@
1
- from __future__ import annotations
2
-
3
-
4
- import time
5
-
6
- from functools import cmp_to_key
7
-
8
- from sahi import ObjectPrediction
9
- from tqdm import tqdm
10
- from sahi.predict import filter_predictions
11
-
12
-
13
- from sahi.logger import logger
14
-
15
- from sahi.postprocess.combine import (
16
- GreedyNMMPostprocess,
17
- LSNMSPostprocess,
18
- NMMPostprocess,
19
- NMSPostprocess,
20
- PostprocessPredictions,
21
- )
22
- from sahi.prediction import PredictionResult
23
- from sahi.slicing import slice_image
24
-
25
-
26
- from sahi.utils.import_utils import check_requirements
27
-
28
- POSTPROCESS_NAME_TO_CLASS = {
29
- "GREEDYNMM": GreedyNMMPostprocess,
30
- "NMM": NMMPostprocess,
31
- "NMS": NMSPostprocess,
32
- "LSNMS": LSNMSPostprocess,
33
- }
34
-
35
- LOW_MODEL_CONFIDENCE = 0.1
36
-
37
- def get_prediction(
38
- image_list,
39
- detection_model,
40
- shift_amount: list = [0, 0],
41
- full_shape=None,
42
- postprocess: PostprocessPredictions | None = None,
43
- verbose: int = 0,
44
- exclude_classes_by_name: list[str] | None = None,
45
- exclude_classes_by_id: list[int] | None = None,
46
- ) -> list[PredictionResult]:
47
- """Function for performing prediction for given image using given detection_model.
48
-
49
- Arguments:
50
- image: str or np.ndarray
51
- Location of image or numpy image matrix to slice
52
- detection_model: model.DetectionMode
53
- shift_amount: List
54
- To shift the box and mask predictions from sliced image to full
55
- sized image, should be in the form of [shift_x, shift_y]
56
- full_shape: List
57
- Size of the full image, should be in the form of [height, width]
58
- postprocess: sahi.postprocess.combine.PostprocessPredictions
59
- verbose: int
60
- 0: no print (default)
61
- 1: print prediction duration
62
- exclude_classes_by_name: Optional[List[str]]
63
- None: if no classes are excluded
64
- List[str]: set of classes to exclude using its/their class label name/s
65
- exclude_classes_by_id: Optional[List[int]]
66
- None: if no classes are excluded
67
- List[int]: set of classes to exclude using one or more IDs
68
- Returns:
69
- A dict with fields:
70
- object_prediction_list: a list of ObjectPrediction
71
- durations_in_seconds: a dict containing elapsed times for profiling
72
- """
73
- durations_in_seconds = dict()
74
-
75
-
76
- # get prediction
77
- time_start = time.time()
78
- detection_model.perform_inference(image_list)
79
- time_end = time.time() - time_start
80
- durations_in_seconds["prediction"] = time_end
81
-
82
-
83
-
84
- # process prediction
85
- time_start = time.time()
86
- # works only with 1 batch
87
- detection_model.convert_original_predictions(
88
- shift_amount=shift_amount,
89
- full_shape=full_shape,
90
-
91
- )
92
- object_prediction_list_per_image = detection_model.object_prediction_list_per_image
93
-
94
-
95
- time_end = time.time() - time_start
96
- durations_in_seconds["postprocess"] = time_end
97
-
98
- if verbose == 1:
99
- print(
100
- "Prediction performed in",
101
- durations_in_seconds["prediction"],
102
- "seconds.",
103
- )
104
-
105
- preds = []
106
- for image, object_prediction_list in zip(image_list, object_prediction_list_per_image):
107
-
108
- res = PredictionResult(
109
- image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
110
- )
111
- preds.append(res)
112
-
113
- return preds
114
-
115
-
116
- def get_sliced_prediction(
117
- image,
118
- detection_model=None,
119
- slice_height: int | None = None,
120
- slice_width: int | None = None,
121
- overlap_height_ratio: float = 0.2,
122
- overlap_width_ratio: float = 0.2,
123
- num_batch = 1,
124
- postprocess_type: str = "GREEDYNMM",
125
- postprocess_match_metric: str = "IOS",
126
- postprocess_match_threshold: float = 0.5,
127
- postprocess_class_agnostic: bool = False,
128
- verbose: int = 1,
129
- merge_buffer_length: int | None = None,
130
- auto_slice_resolution: bool = True,
131
- slice_export_prefix: str | None = None,
132
- slice_dir: str | None = None,
133
- exclude_classes_by_name: list[str] | None = None,
134
- exclude_classes_by_id: list[int] | None = None,
135
- ) -> PredictionResult:
136
-
137
- # for profiling
138
- durations_in_seconds = dict()
139
-
140
-
141
- # create slices from full image
142
- time_start = time.time()
143
- slice_image_result = slice_image(
144
- image=image,
145
- output_file_name=slice_export_prefix,
146
- output_dir=slice_dir,
147
- slice_height=slice_height,
148
- slice_width=slice_width,
149
- overlap_height_ratio=overlap_height_ratio,
150
- overlap_width_ratio=overlap_width_ratio,
151
- auto_slice_resolution=auto_slice_resolution,
152
- )
153
- from sahi.models.ultralytics import UltralyticsDetectionModel
154
-
155
- num_slices = len(slice_image_result)
156
-
157
- time_end = time.time() - time_start
158
- durations_in_seconds["slice"] = time_end
159
-
160
- if isinstance(detection_model, UltralyticsDetectionModel) and detection_model.is_obb:
161
- # Only NMS is supported for OBB model outputs
162
- postprocess_type = "NMS"
163
-
164
- # init match postprocess instance
165
- if postprocess_type not in POSTPROCESS_NAME_TO_CLASS.keys():
166
- raise ValueError(
167
- f"postprocess_type should be one of {list(POSTPROCESS_NAME_TO_CLASS.keys())} "
168
- f"but given as {postprocess_type}"
169
- )
170
- postprocess_constructor = POSTPROCESS_NAME_TO_CLASS[postprocess_type]
171
- postprocess = postprocess_constructor(
172
- match_threshold=postprocess_match_threshold,
173
- match_metric=postprocess_match_metric,
174
- class_agnostic=postprocess_class_agnostic,
175
- )
176
-
177
- postprocess_time = 0
178
- time_start = time.time()
179
-
180
- # create prediction input
181
- num_group = int(num_slices / num_batch)
182
- if verbose == 1 or verbose == 2:
183
- tqdm.write(f"Performing prediction on {num_slices} slices.")
184
- object_prediction_list = []
185
- # perform sliced prediction
186
- for group_ind in range(num_group):
187
- # prepare batch (currently supports only 1 batch)
188
- image_list = []
189
- shift_amount_list = []
190
- for image_ind in range(num_batch):
191
- image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
192
- shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
193
- # perform batch prediction
194
- prediction_results = get_prediction(
195
- image_list=image_list,
196
- detection_model=detection_model,
197
- shift_amount=shift_amount_list,
198
- full_shape=[[
199
- slice_image_result.original_image_height,
200
- slice_image_result.original_image_width,
201
- ]]*num_batch,
202
- exclude_classes_by_name=exclude_classes_by_name,
203
- exclude_classes_by_id=exclude_classes_by_id,
204
- )
205
- for prediction_result in prediction_results:
206
- # convert sliced predictions to full predictions
207
- for object_prediction in prediction_result.object_prediction_list:
208
- if object_prediction: # if not empty
209
- object_prediction_list.append(object_prediction.get_shifted_object_prediction())
210
-
211
- # merge matching predictions during sliced prediction
212
- if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
213
- postprocess_time_start = time.time()
214
- object_prediction_list = postprocess(object_prediction_list)
215
- postprocess_time += time.time() - postprocess_time_start
216
-
217
-
218
-
219
- time_end = time.time() - time_start
220
- durations_in_seconds["prediction"] = time_end - postprocess_time
221
- durations_in_seconds["postprocess"] = postprocess_time
222
-
223
- if verbose == 2:
224
- print(
225
- "Slicing performed in",
226
- durations_in_seconds["slice"],
227
- "seconds.",
228
- )
229
- print(
230
- "Prediction performed in",
231
- durations_in_seconds["prediction"],
232
- "seconds.",
233
- )
234
- print(
235
- "Postprocessing performed in",
236
- durations_in_seconds["postprocess"],
237
- "seconds.",
238
- )
239
-
240
- return PredictionResult(
241
- image=image, object_prediction_list=object_prediction_list, durations_in_seconds=durations_in_seconds
242
- )
243
-
244
-
245
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tree_commons.py CHANGED
@@ -3,6 +3,8 @@ import numpy as np
3
  import matplotlib.pyplot as plt
4
  from matplotlib.patches import Polygon
5
 
 
 
6
 
7
  MASK2FORMER_CHECKPOINT_DIR = f'mask2former_checkpoints'
8
 
 
3
  import matplotlib.pyplot as plt
4
  from matplotlib.patches import Polygon
5
 
6
+ from sahi.predict import get_sliced_prediction
7
+
8
 
9
  MASK2FORMER_CHECKPOINT_DIR = f'mask2former_checkpoints'
10
 
yolo_app_predict.py CHANGED
@@ -26,6 +26,7 @@ def predict(img_arr):
26
  slice_width=tc.CROPPED_IMAGE_WIDTH,
27
  overlap_height_ratio=0.2,
28
  overlap_width_ratio=0.2,
 
29
  )
30
 
31
 
 
26
  slice_width=tc.CROPPED_IMAGE_WIDTH,
27
  overlap_height_ratio=0.2,
28
  overlap_width_ratio=0.2,
29
+ batch_size=9
30
  )
31
 
32