| |
|
| | from copy import deepcopy |
| | import json |
| |
|
| | import os |
| | import argparse |
| | import torchvision.transforms.functional as F |
| | import torch |
| | import cv2 |
| | import numpy as np |
| | from tqdm import tqdm |
| | from pathlib import Path |
| | import sys |
| | sys.path.append('VISAM') |
| | from main import get_args_parser |
| | from models import build_model |
| | from util.tool import load_model |
| | from models.structures import Instances |
| |
|
| | from torch.utils.data import Dataset, DataLoader |
| |
|
| |
|
| | |
| | sys.path.append('segment_anything') |
| | from segment_anything import build_sam, SamPredictor |
| |
|
| |
|
class Colors:
    """Cyclic colour palette: ``colors(i)`` maps any integer id to one of 20 fixed colours."""

    def __init__(self):
        hex_codes = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334',
                     '00D4BB', '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF',
                     'FF95C8', 'FF37C7')
        self.palette = list(map(self.hex2rgb, ['#' + code for code in hex_codes]))
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return the colour for id ``i`` as (R, G, B), or as (B, G, R) when ``bgr`` is true.

        Ids wrap around the palette, so any integer (including negatives) is valid.
        """
        r, g, b = self.palette[int(i) % self.n]
        if bgr:
            return b, g, r
        return r, g, b

    @staticmethod
    def hex2rgb(h):
        """Convert a ``'#RRGGBB'`` string into an ``(R, G, B)`` tuple of ints."""
        digits = h[1:]
        return tuple(int(digits[k:k + 2], 16) for k in range(0, 6, 2))


# Shared palette instance used when drawing per-track colours.
colors = Colors()
| |
|
| |
|
class ListImgDataset(Dataset):
    """Frames of one sequence together with their precomputed detection proposals.

    Each item is ``(img, ori_img, proposals)``:
      * ``img`` -- (1, 3, H', W') float tensor, resized so the shorter side is
        ``img_height`` (longer side capped at ``img_width``) and normalised
        with ``self.mean`` / ``self.std``.
      * ``ori_img`` -- the full-resolution RGB image as returned by OpenCV.
      * ``proposals`` -- (N, 5) tensor of ``(cx, cy, w, h, score)`` with the box
        normalised by the original image width/height.
    """

    def __init__(self, mot_path, img_list, det_db) -> None:
        super().__init__()
        self.mot_path = mot_path  # root directory the img_list paths are relative to
        self.img_list = img_list  # frame paths, relative to mot_path
        # maps '<frame path without extension>.txt' -> list of 'l,t,w,h,score' strings
        self.det_db = det_db

        # common settings
        self.img_height = 800  # target length of the shorter image side
        self.img_width = 1536  # upper bound for the longer image side
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    @staticmethod
    def _parse_proposals(lines, im_w, im_h):
        """Parse ``'l,t,w,h,score'`` lines into an (N, 5) tensor of normalised ``(cx, cy, w, h, score)``."""
        proposals = []
        for line in lines:
            left, top, w, h, score = map(float, line.split(','))
            proposals.append([(left + w / 2) / im_w,
                              (top + h / 2) / im_h,
                              w / im_w,
                              h / im_h,
                              score])
        # reshape keeps a (0, 5) shape when the frame has no detections
        return torch.as_tensor(proposals).reshape(-1, 5)

    def load_img_from_file(self, f_path):
        """Read frame ``f_path`` (relative to ``mot_path``) as RGB and look up its proposals.

        Raises:
            OSError: if OpenCV cannot read the image.
            KeyError: if ``det_db`` has no entry for the frame.
        """
        full_path = os.path.join(self.mot_path, f_path)
        cur_img = cv2.imread(full_path)
        # explicit check instead of assert so it is not stripped under `python -O`
        if cur_img is None:
            raise OSError(f'could not read image: {full_path}')
        cur_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2RGB)
        im_h, im_w = cur_img.shape[:2]
        # splitext (not f_path[:-4]) so the key is right for any extension length
        det_key = os.path.splitext(f_path)[0] + '.txt'
        return cur_img, self._parse_proposals(self.det_db[det_key], im_w, im_h)

    def init_img(self, img, proposals):
        """Resize and normalise ``img`` for the model and keep an untouched copy.

        Returns ``(model_input, original_image, proposals)``. As a side effect the
        original size is stored in ``self.seq_h`` / ``self.seq_w``.
        """
        ori_img = img.copy()
        self.seq_h, self.seq_w = img.shape[:2]
        # scale the short side to img_height unless that pushes the long side past img_width
        scale = self.img_height / min(self.seq_h, self.seq_w)
        if max(self.seq_h, self.seq_w) * scale > self.img_width:
            scale = self.img_width / max(self.seq_h, self.seq_w)
        target_h = int(self.seq_h * scale)
        target_w = int(self.seq_w * scale)
        img = cv2.resize(img, (target_w, target_h))
        img = F.normalize(F.to_tensor(img), self.mean, self.std)
        return img.unsqueeze(0), ori_img, proposals

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img, proposals = self.load_img_from_file(self.img_list[index])
        return self.init_img(img, proposals)
| |
|
| |
|
class Detector(object):
    """Track one image sequence and visualise the tracks, optionally with SAM masks.

    Frames are read from ``<mot_path>/<vid>/img1`` (files containing 'jpg').
    For each frame, the tracker output is filtered by score and box area. Each
    kept box is optionally segmented with ``sam_predictor``, and the overlay is
    written as MJPG video to ``self.video_path``. Tracks are saved as lines of
    ``frame,id,x1,y1,w,h,1,-1,-1,-1`` to ``<output_dir>/<exp_name>/<seq>.txt``.
    """

    def __init__(self, args, model, vid, sam_predictor=None):
        """
        Args:
            args: parsed CLI namespace. ``mot_path``, ``output_dir`` and ``exp_name``
                are read here; ``det_db`` is read in ``detect``.
            model: tracker exposing ``inference_single_image``.
            vid: sequence directory, joined onto ``args.mot_path``.
            sam_predictor: optional SAM predictor used to draw instance masks.
        """
        self.args = args
        self.detr = model

        self.vid = vid
        # normpath so a trailing slash on vid does not produce an empty sequence name
        self.seq_num = os.path.basename(os.path.normpath(vid))
        img_dir = os.path.join(self.args.mot_path, vid, 'img1')
        self.img_list = sorted(os.path.join(vid, 'img1', name) for name in os.listdir(img_dir) if 'jpg' in name)
        self.img_len = len(self.img_list)

        self.predict_path = os.path.join(self.args.output_dir, args.exp_name)
        os.makedirs(self.predict_path, exist_ok=True)

        # The writer is opened in detect() with the size of the first rendered frame.
        # OpenCV does not write frames whose size differs from the size the writer was
        # opened with, so a fixed (1920, 1080) would lose frames of other resolutions.
        self.video_path = 'visam.avi'
        self.video_fps = 25
        self.videowriter = None

        self.sam_predictor = sam_predictor

    @staticmethod
    def filter_dt_by_score(dt_instances: "Instances", prob_threshold: float) -> "Instances":
        """Keep instances with score > ``prob_threshold`` and an assigned id (obj_idxes >= 0)."""
        keep = dt_instances.scores > prob_threshold
        keep &= dt_instances.obj_idxes >= 0
        return dt_instances[keep]

    @staticmethod
    def filter_dt_by_area(dt_instances: "Instances", area_threshold: float) -> "Instances":
        """Keep instances whose xyxy box area is > ``area_threshold``."""
        wh = dt_instances.boxes[:, 2:4] - dt_instances.boxes[:, 0:2]
        areas = wh[:, 0] * wh[:, 1]
        keep = areas > area_threshold
        return dt_instances[keep]

    @staticmethod
    def _format_lines(frame, bbox_xyxy, identities):
        """Format one frame's xyxy boxes as track-file lines, skipping None/negative ids."""
        save_format = '{frame},{id},{x1:.2f},{y1:.2f},{w:.2f},{h:.2f},1,-1,-1,-1\n'
        out = []
        for (x1, y1, x2, y2), track_id in zip(bbox_xyxy, identities):
            # None must be tested first: `None < 0` raises TypeError in Python 3
            if track_id is None or track_id < 0:
                continue
            out.append(save_format.format(frame=frame, id=track_id, x1=x1, y1=y1, w=x2 - x1, h=y2 - y1))
        return out

    def _render_frame(self, rgb, bbox_xyxy, identities):
        """Return a BGR uint8 frame: ``rgb`` at half brightness, per-id mask tint, red boxes.

        Without a SAM predictor (or with no boxes) only the dimming and boxes are applied.
        """
        img = rgb.copy()[..., ::-1]  # RGB -> BGR for OpenCV drawing/writing
        masks_sum = np.zeros((3,) + img.shape[:2], dtype=np.int64)
        if self.sam_predictor is not None and bbox_xyxy:
            self.sam_predictor.set_image(rgb.copy())
            for bbox, track_id in zip(np.array(bbox_xyxy), identities):
                masks, iou_predictions, _ = self.sam_predictor.predict(box=bbox)
                # use the candidate mask with the highest predicted IoU
                best = int(np.argmax(iou_predictions))
                mask = np.repeat(masks[best:best + 1], 3, axis=0).astype(np.int64)
                masks_sum += mask * np.array(colors(track_id))[:, None, None]
            self.sam_predictor.reset_image()
        img = (img * 0.5 + (masks_sum.transpose(1, 2, 0) * 30) % 128).astype(np.uint8)
        for x1, y1, x2, y2 in bbox_xyxy:
            cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thickness=3)
        return img

    def detect(self, prob_threshold=0.6, area_threshold=100, vis=False):
        """Run tracking over every frame, write the overlay video and the track file.

        Args:
            prob_threshold: keep tracks whose score is strictly greater than this.
            area_threshold: keep boxes whose area (x2-x1)*(y2-y1) is strictly greater than this.
            vis: unused; the overlay video is always written.
        """
        total_dts = 0
        total_occlusion_dts = 0  # never incremented; only reported in the summary line

        track_instances = None
        with open(os.path.join(self.args.mot_path, 'DanceTrack', self.args.det_db)) as f:
            det_db = json.load(f)
        loader = DataLoader(ListImgDataset(self.args.mot_path, self.img_list, det_db), 1, num_workers=2)
        lines = []
        try:
            for i, data in enumerate(tqdm(loader)):
                # batch size is 1: strip the batch dimension added by the DataLoader
                cur_img, ori_img, proposals = [d[0] for d in data]
                cur_img, proposals = cur_img.cuda(), proposals.cuda()

                # NOTE(review): boxes/labels are dropped before the instances are fed back
                # to the model -- presumably recomputed every frame; confirm in the model code.
                if track_instances is not None:
                    track_instances.remove('boxes')
                    track_instances.remove('labels')
                seq_h, seq_w, _ = ori_img.shape

                res = self.detr.inference_single_image(cur_img, (seq_h, seq_w), track_instances, proposals)
                track_instances = res['track_instances']

                # filter a copy so the tracker's own state is left untouched
                dt_instances = deepcopy(track_instances)
                dt_instances = self.filter_dt_by_score(dt_instances, prob_threshold)
                dt_instances = self.filter_dt_by_area(dt_instances, area_threshold)
                total_dts += len(dt_instances)

                bbox_xyxy = dt_instances.boxes.tolist()
                identities = dt_instances.obj_idxes.tolist()

                frame = self._render_frame(ori_img.cpu().numpy(), bbox_xyxy, identities)
                if self.videowriter is None:
                    frame_h, frame_w = frame.shape[:2]
                    self.videowriter = cv2.VideoWriter(
                        self.video_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                        self.video_fps, (frame_w, frame_h))
                self.videowriter.write(frame)

                lines.extend(self._format_lines(i + 1, bbox_xyxy, identities))
        finally:
            # release() finalises the container; without it the video file may be incomplete
            if self.videowriter is not None:
                self.videowriter.release()
                self.videowriter = None

        with open(os.path.join(self.predict_path, f'{self.seq_num}.txt'), 'w') as f:
            f.writelines(lines)
        print("totally {} dts {} occlusion dts".format(total_dts, total_occlusion_dts))
| |
|
| |
|
class RuntimeTrackerBase(object):
    """Online id bookkeeping for tracked queries.

    A query scoring at least ``score_thresh`` has its miss counter
    (``disappear_time``) reset. If it has no id yet (``-1``), it receives a new
    one. A query that already has an id but scores below ``filter_score_thresh``
    has its miss counter incremented. Once that counter reaches
    ``miss_tolerance``, the query's id is set back to ``-1``.
    """

    def __init__(self, score_thresh=0.6, filter_score_thresh=0.5, miss_tolerance=10):
        self.score_thresh = score_thresh
        self.filter_score_thresh = filter_score_thresh
        self.miss_tolerance = miss_tolerance
        self.max_obj_id = 0  # next id to hand out

    def clear(self):
        """Restart id numbering from zero."""
        self.max_obj_id = 0

    def update(self, track_instances: "Instances"):
        """Assign new ids and retire long-missing ones, mutating ``track_instances`` in place."""
        ids = track_instances.obj_idxes
        scores = track_instances.scores
        missed = track_instances.disappear_time

        confident = scores >= self.score_thresh
        missed[confident] = 0

        # both masks are taken before any id is modified below
        is_new = confident & (ids == -1)
        is_missing = (ids >= 0) & (scores < self.filter_score_thresh)

        n_new = int(is_new.sum())
        ids[is_new] = self.max_obj_id + torch.arange(n_new, device=ids.device)
        self.max_obj_id += n_new

        missed[is_missing] += 1
        ids[is_missing & (missed >= self.miss_tolerance)] = -1
| |
|
| |
|
if __name__ == "__main__":
    # CLI entry point: build SAM and the tracker, then track and visualise one sequence.
    parser = argparse.ArgumentParser("Grounded-Segment-Anything VISAM Demo", parents=[get_args_parser()])
    parser.add_argument('--score_threshold', default=0.5, type=float)
    parser.add_argument('--update_score_threshold', default=0.5, type=float)
    parser.add_argument('--miss_tolerance', default=20, type=int)
    parser.add_argument(
        "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--video_path", type=str, required=True,
        help="sequence directory (joined onto mot_path) containing an img1/ folder of .jpg frames",
    )
    args = parser.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    sam_predictor = SamPredictor(build_sam(checkpoint=args.sam_checkpoint))
    sam_predictor.model.to(device='cuda')

    detr, _, _ = build_model(args)
    detr.track_embed.score_thr = args.update_score_threshold
    # score_threshold serves both as the new-id threshold and the "missing" threshold
    detr.track_base = RuntimeTrackerBase(args.score_threshold, args.score_threshold, args.miss_tolerance)
    # load_model(detr, args.resume) loads the checkpoint itself; no separate
    # torch.load is needed (it would only deserialise the file a second time).
    detr = load_model(detr, args.resume)
    detr.eval()
    detr = detr.cuda()

    det = Detector(args, model=detr, vid=args.video_path, sam_predictor=sam_predictor)
    det.detect(args.score_threshold)
| |
|