Hang Zhou committed on
Commit
9aff0cd
·
verified ·
1 Parent(s): 9756ed1

Upload folder using huggingface_hub

Browse files
scripts/annotate_sam.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ from segment_anything import sam_model_registry, SamPredictor
7
+ from lvis import LVIS
8
+ import copy
9
+ from pathlib import Path
10
+
11
+
12
class Objects365SAM():
    """Generate SAM segmentation masks for Objects365 images from COCO-style box annotations.

    Loads a SAM ViT-H checkpoint, then for each image in a configurable index
    range prompts SAM with the image's ground-truth boxes and writes the
    resulting polygon masks to one JSON file per image.
    """

    def __init__(self, index_low, index_high):
        """Load the SAM model and remember the image-index shard to process.

        Args:
            index_low: lower bound (inclusive) into the dataset's image list.
            index_high: upper bound (exclusive) into the dataset's image list.
        """
        # Prefer GPU: ViT-H SAM is impractically slow on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)

        # Shard bounds: several processes can annotate disjoint index ranges
        # in parallel (see --index_low / --index_high CLI flags).
        self.index_low = index_low
        self.index_high = index_high

    def load_annotations(self, annotation_file):
        """Load a COCO-format annotation JSON (must contain 'images' and 'annotations')."""
        with open(annotation_file, 'r') as f:
            self.json_data = json.load(f)

    def process_annotations_with_sam(self, images_dir, output_dir):
        """Run box-prompted SAM segmentation over this shard of the dataset.

        For each image, converts its COCO boxes to xyxy prompts, predicts one
        mask per box, keeps the largest contour of each mask, and writes the
        results via save_annotations_to_json() to
        output_dir/<subset>/<image stem>.json.
        """
        image_info_list = self.json_data['images']

        # Group annotations by image once up front. The previous version
        # rescanned the full 'annotations' list for every image, which is
        # O(images * annotations) and dominated runtime on Objects365.
        annos_by_image = {}
        for anno in self.json_data['annotations']:
            annos_by_image.setdefault(anno['image_id'], []).append(anno)

        counter = 0
        for image_info in image_info_list[self.index_low:self.index_high]:
            image_id = image_info['id']
            # file_name looks like '<subset>/<image_name>'; keep only the
            # last two path components.
            file_parts = image_info['file_name'].split('/')
            image_name = file_parts[-1]
            image_subset = file_parts[-2]

            output_json_dir = Path(os.path.join(output_dir, image_subset))
            # parents=True: a missing intermediate directory is not fatal
            # (the old exist_ok-only call raised FileNotFoundError).
            output_json_dir.mkdir(parents=True, exist_ok=True)

            image_path = os.path.join(images_dir, image_subset, image_name)

            image = cv2.imread(image_path)
            if image is None:
                print(f"Image not found: {image_path}")
                continue
            self.predictor.set_image(image)

            image_annotations = annos_by_image.get(image_id, [])
            # Nothing to segment: skip instead of handing SAM an empty box
            # tensor, which fails inside predict_torch.
            if not image_annotations:
                counter += 1
                continue

            # COCO [x, y, w, h] boxes -> SAM [xmin, ymin, xmax, ymax] prompts.
            bounding_boxes = []
            for anno in image_annotations:
                xmin, ymin, box_w, box_h = anno['bbox']
                bounding_boxes.append([xmin, ymin, xmin + box_w, ymin + box_h])

            input_boxes = torch.tensor(bounding_boxes, device=self.device).float()
            transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

            # One mask per prompt box; no gradients needed for inference.
            with torch.no_grad():
                masks, scores, logits = self.predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes,
                    multimask_output=False,
                )

            # Convert each binary mask to a COCO-style polygon annotation,
            # keeping only the largest connected component.
            mask_annotations = []
            for mask in masks:
                binary_mask = mask.squeeze().cpu().numpy().astype(np.uint8)
                contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if len(contours) == 0:
                    continue
                largest_contour = max(contours, key=cv2.contourArea)
                segmentation = largest_contour.flatten().tolist()
                x, y, w, h = cv2.boundingRect(largest_contour)
                area = float(cv2.contourArea(largest_contour))
                mask_annotations.append({
                    "segmentation": [segmentation],
                    "bbox": [x, y, w, h],
                    "area": area,
                    # NOTE(review): category_id is hard-coded to 1 — the
                    # source box's category is discarded. Confirm downstream
                    # consumers are class-agnostic.
                    "category_id": 1
                })

            # Path(...).stem handles any extension length; the previous
            # image_name[:-4] broke on e.g. '.jpeg'.
            save_annotations_to_json(image_id,
                                     mask_annotations,
                                     os.path.join(output_json_dir, Path(image_name).stem + '.json')
                                     )
            torch.cuda.empty_cache()
            counter += 1
            print('Done image index: ', counter)
101
+
102
def save_annotations_to_json(image_id, mask_annotations, output_file):
    """Write one image's mask annotations to *output_file* as COCO-style JSON."""
    payload = {
        "image_id": image_id,
        "annotations": mask_annotations
    }
    with open(output_file, 'w') as handle:
        json.dump(payload, handle)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ '''
114
+ Image number: train/test: 1742292/80000
115
+ '''
116
+ import argparse
117
+
118
+ parser = argparse.ArgumentParser(description="Annotate labels with Segment Anything")
119
+ parser.add_argument('--is_train', action='store_true', help="Train/Test")
120
+ parser.add_argument("--index_low", type=int, default=0, help="Lower bound of indexes for processing Objects365 dataset.")
121
+ parser.add_argument("--index_high", type=int, default=1742292, help="Upper bound of indexes for processing Objects365 dataset.")
122
+ args = parser.parse_args()
123
+
124
+ if args.is_train:
125
+ input_json_dir = '../data/object365/zhiyuan_objv2_train.json'
126
+ input_image_dir = '../data/object365/images/train/'
127
+ output_dir = Path('../data/object365/labels/train/')
128
+ else:
129
+ input_json_dir = '../data/object365/zhiyuan_objv2_val.json'
130
+ input_image_dir = '../data/object365/images/val/'
131
+ output_dir = Path('../data/object365/labels/val/')
132
+
133
+ output_dir.mkdir(exist_ok=True)
134
+
135
+ sam_annotator = Objects365SAM(args.index_low, args.index_high)
136
+ sam_annotator.load_annotations(input_json_dir)
137
+ sam_annotator.process_annotations_with_sam(input_image_dir, output_dir)
138
+
139
+
scripts/data_construction.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LVIS train set
2
+ python -m datasets.lvis \
3
+ --dataset_dir "/path/to/raw_data" \
4
+ --construct_dataset_dir "data/train/LVIS" \
5
+ --area_ratio 0.02 \
6
+ --is_build_data \
7
+ --is_train
8
+
9
+ # LVIS test set
10
+ python -m datasets.lvis \
11
+ --dataset_dir "/path/to/raw_data" \
12
+ --construct_dataset_dir "data/test/LVIS" \
13
+ --area_ratio 0.02 \
14
+ --is_build_data
15
+
scripts/inference.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ python run_test.py \
3
+ --input "sample" \
4
+ --output "results/sample" \
5
+ --obj_thr 2
scripts/train.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train on the whole dataset
2
+ python run_train.py \
3
+ --root_dir 'LOGS/all_data' \
4
+ --batch_size 16 \
5
+ --logger_freq 1000 \
6
+ --is_joint
7
+
8
+ python run_train.py \
9
+ --root_dir 'LOGS/lvis' \
10
+ --batch_size 16 \
11
+ --logger_freq 1000 \
12
+ --dataset_name lvis
13
+