Spaces:
Sleeping
Sleeping
Tony Lian
committed on
Commit
·
93de48e
1
Parent(s):
0cbad80
Apply batching to SAM to reduce the memory cost with many objects
Browse files- generation.py +13 -7
generation.py
CHANGED
|
@@ -53,7 +53,7 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
|
|
| 53 |
batch_size = input_len
|
| 54 |
|
| 55 |
run_times = int(np.ceil(input_len / batch_size))
|
| 56 |
-
|
| 57 |
for batch_idx in range(run_times):
|
| 58 |
input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
|
| 59 |
bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
|
|
@@ -68,17 +68,23 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
|
|
| 68 |
gc.collect()
|
| 69 |
torch.cuda.empty_cache()
|
| 70 |
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
|
| 73 |
latents_all.append(latents_all_batch)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
|
| 81 |
-
mask_selected
|
| 82 |
|
| 83 |
mask_selected_tensor = torch.tensor(mask_selected)
|
| 84 |
|
|
|
|
| 53 |
batch_size = input_len
|
| 54 |
|
| 55 |
run_times = int(np.ceil(input_len / batch_size))
|
| 56 |
+
mask_selected_list, single_object_pil_images_box_ann, latents_all = [], [], []
|
| 57 |
for batch_idx in range(run_times):
|
| 58 |
input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
|
| 59 |
bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
|
|
|
|
| 68 |
gc.collect()
|
| 69 |
torch.cuda.empty_cache()
|
| 70 |
|
| 71 |
+
# `sam_refine_boxes` also calls `empty_cache` so we don't need to explicitly empty the cache again.
|
| 72 |
+
mask_selected, _ = sam.sam_refine_boxes(sam_input_images=single_object_images_batch, boxes=bboxes_batch, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
|
| 73 |
+
|
| 74 |
+
mask_selected_list.append(np.array(mask_selected)[:, 0])
|
| 75 |
single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
|
| 76 |
latents_all.append(latents_all_batch)
|
| 77 |
|
| 78 |
+
single_object_pil_images_box_ann, latents_all = sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
|
| 79 |
+
|
| 80 |
+
# mask_selected_list: List(batch)[List(image)[List(box)[Array of shape (64, 64)]]]
|
| 81 |
+
|
| 82 |
+
mask_selected = np.concatenate(mask_selected_list, axis=0)
|
| 83 |
+
mask_selected = mask_selected.reshape((-1, *mask_selected.shape[-2:]))
|
| 84 |
|
| 85 |
+
assert mask_selected.shape[0] == input_latents.shape[0], f"{mask_selected.shape[0]} != {input_latents.shape[0]}"
|
| 86 |
|
| 87 |
+
print(mask_selected.shape)
|
| 88 |
|
| 89 |
mask_selected_tensor = torch.tensor(mask_selected)
|
| 90 |
|