|
|
import argparse
import json
import math
import os
import re
import sys

import cv2
import numpy as np
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm
|
|
|
|
|
# Module-level SAM setup. These statements run at import time: the model is
# built from the checkpoint file and moved to `device` before anything else.
model_type = "vit_h"
device = "cuda"
sam_checkpoint = "/mnt/lustre/zhouyue1.vendor/segment-anything-main/sam_vit_h_4b8939.pth"

# Look up the builder registered under `model_type` and instantiate the model
# with the weights stored at `sam_checkpoint`.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Shared predictor object; the script entry point calls set_image()/predict()
# on it for every evaluated sample.
predictor = SamPredictor(sam)
|
|
|
|
|
def calculate_iou(box1, box2):
    """
    Compute the IoU (Intersection over Union) of two axis-aligned boxes.

    Args:
        box1, box2: Boxes given as (x1, y1, x2, y2), i.e. the coordinates of
            the top-left and bottom-right corners. Only indices 0..3 are read.

    Returns:
        float: Intersection area divided by union area. Returns 0.0 when the
        intersection has non-positive width or height (boxes that merely
        touch count as non-overlapping).
    """
    # Corners of the candidate intersection rectangle.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    overlap_w = right - left
    overlap_h = bottom - top
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    overlap = overlap_w * overlap_h

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Union = sum of both areas minus the part counted twice.
    return overlap / (area1 + area2 - overlap)
|
|
|
|
|
|
|
|
def decode_string(s, length):
    """
    Decode a run-length encoded grid string into a list of integer rows.

    The input is a ';'-separated list of rows. Each row is a ','-separated
    list of ``value*count`` groups, e.g. ``"1*2,0*1;0*3"``. Whitespace around
    the numbers is ignored and empty rows are skipped.

    Args:
        s: Encoded string.
        length: Target width of every row and minimum number of rows.

    Returns:
        list: A list of rows (lists of int). Every row is truncated or
        zero-padded to exactly ``length`` entries, and all-zero rows are
        appended until there are at least ``length`` rows. Rows beyond
        ``length`` are kept as they are (the row count is never truncated).
    """
    result = []

    for row in s.strip().split(';'):
        if not row.strip():
            continue

        decoded_row = []
        for group in row.split(','):
            try:
                num, count = group.split('*')
                decoded_row.extend([int(num.strip())] * int(count.strip()))
            except ValueError:
                # Malformed group (missing/extra '*' or a non-integer field):
                # fill with zeros; the excess is cut off by the width
                # normalisation below. Only ValueError is handled so that
                # unrelated errors (KeyboardInterrupt, MemoryError, ...) are
                # no longer silently swallowed.
                decoded_row.extend([0] * length)

        # Normalise the row width: cut long rows, zero-pad short ones.
        decoded_row = decoded_row[:length]
        decoded_row.extend([0] * (length - len(decoded_row)))

        result.append(decoded_row)

    # Pad with all-zero rows up to `length` rows.
    result.extend([0] * length for _ in range(length - len(result)))

    return result
|
|
|
|
|
def extract_bboxes(output):
    """
    Extract every bracketed, comma-separated list of numbers from a string.

    A match is a substring of the form ``[ ... ]`` whose content consists
    only of digits, commas and spaces (so decimal points and minus signs do
    not match). The content is split on commas and each item is converted
    to ``float``.

    :param output: Text to scan, e.g. ``"[10, 20, 30, 40]"``.
    :return: List with one list of floats per match, in order of appearance;
             an empty list when nothing matches.
    :raises ValueError: If a matched list contains an item that is not a
             number, such as the empty item in ``"[1,,2]"``.
    """
    bboxes = []
    for found in re.finditer(r'\[([0-9, ]+)\]', output):
        coords = [float(token) for token in found.group(1).split(",")]
        bboxes.append(coords)
    return bboxes
|
|
|
|
|
|
|
|
def load_jsonl(filename):
    """
    Load a JSON Lines file into a list.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of raising a decode error.

    Args:
        filename: Path of the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        list: The parsed objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(filename, 'r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            line = line.strip()
            if not line:
                continue
            data.append(json.loads(line))
    return data
|
|
|
|
|
def folder_creat_if_not_exist(folder):
    """
    Create ``folder`` (including missing parent directories) if needed.

    Calling it for a directory that already exists is a no-op. The creation
    is done in one step with ``exist_ok=True`` so that a directory appearing
    between a separate existence check and the creation cannot cause a
    failure.

    Args:
        folder: Path of the directory to create.

    Raises:
        FileExistsError: If ``folder`` exists but is not a directory.
    """
    os.makedirs(folder, exist_ok=True)
|
|
|
|
|
def _parse_args():
    """Parse and validate the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(description='Process some paths.')
    parser.add_argument('--scale', required=True, type=int, help='Normalize scale')
    parser.add_argument('--image-folder', required=True, help='Image directory')
    parser.add_argument('--answers-file', required=True, help='Target jsonl directory')
    parser.add_argument('--vis-dir', default=None,
                        help='Directory for visualization outputs (optional)')
    args = parser.parse_args()
    if args.scale <= 0:
        # Predicted coordinates are divided by the scale, so it must be positive.
        parser.error('--scale must be a positive integer')
    return args


def _largest_component_box(mask_uint8):
    """Keep the largest external contour of a binary mask.

    Returns ``(largest_mask, box)``: a mask with only that contour filled
    (255) and its bounding box as ``[x1, y1, x2, y2]``. When the mask has no
    contour at all, the returned mask is all zeros and ``box`` is ``None``.
    """
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_mask = np.zeros_like(mask_uint8)
    if not contours:
        return largest_mask, None

    max_contour = max(contours, key=cv2.contourArea)
    cv2.drawContours(largest_mask, [max_contour], -1, 255, thickness=cv2.FILLED)
    x, y, w, h = cv2.boundingRect(largest_mask)
    return largest_mask, [x, y, x + w, y + h]


def _segment_box(image_path, box):
    """Prompt the module-level SAM ``predictor`` with ``box`` on the image.

    Returns the ``(largest_mask, refined_box)`` pair produced by
    ``_largest_component_box`` for the single predicted mask.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError(f'cannot read image: {image_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        box=np.array(box),
        multimask_output=False,
    )
    mask_uint8 = masks[0].astype(np.uint8) * 255
    return _largest_component_box(mask_uint8)


def _save_visualization(img, vis_dir, image_id, index, gt_bbox, pred_bbox,
                        pred_bbox_new, largest_mask):
    """Write the mask (PNG) and the image with the boxes drawn on it (JPG).

    Red = ground truth, blue = predicted box, green = box around the SAM mask
    (skipped when no mask region was found).
    """
    stem = image_id.split('.')[0] + f'_{index}'
    cv2.imwrite(os.path.join(vis_dir, stem + '.png'), largest_mask)

    draw = ImageDraw.Draw(img)
    draw.rectangle(gt_bbox, outline="red", width=5)
    draw.rectangle(pred_bbox, outline="blue", width=5)
    if pred_bbox_new is not None:
        draw.rectangle(pred_bbox_new, outline="green", width=5)
    img.save(os.path.join(vis_dir, stem + '.jpg'))


def main():
    """Evaluate predicted boxes from a JSONL answers file.

    For every record the first bracketed box in ``answer`` is rescaled from
    the ``--scale`` coordinate range to image pixels, used as a SAM box
    prompt, and compared with the record's ``bbox`` by IoU. A sample counts
    as correct when IoU >= 0.5. Samples whose processing fails are counted
    as format errors and the reason is reported on stderr.
    """
    args = _parse_args()
    scale = args.scale
    vis_dir = args.vis_dir
    if vis_dir:
        folder_creat_if_not_exist(vis_dir)

    records = load_jsonl(args.answers_file)
    total_cnt = len(records)
    correct = 0
    format_error = 0

    for i, record in tqdm(enumerate(records), total=total_cnt):
        answer = record['answer'].strip()
        gt_bbox = record['bbox']

        try:
            predict_boxes = extract_bboxes(answer)
            if not predict_boxes:
                # extract_bboxes returns an empty list (never None) when the
                # answer contains no box.
                format_error += 1
                continue

            ori_img_path = os.path.join(args.image_folder, record['image_id'])
            with Image.open(ori_img_path) as img:
                width, height = img.size

                # Rescale x by the image width and y by the image height.
                # Fewer than four values raises IndexError -> format error.
                pred_bbox = predict_boxes[0]
                for k, extent in enumerate((width, height, width, height)):
                    pred_bbox[k] = pred_bbox[k] / scale * extent

                largest_mask, pred_bbox_new = _segment_box(ori_img_path, pred_bbox)

                # NOTE(review): the score uses the raw predicted box; the
                # SAM-refined box (pred_bbox_new) is only visualized. Confirm
                # that this is the intended metric.
                iou_score = calculate_iou(gt_bbox, pred_bbox)
                if iou_score >= 0.5:
                    correct += 1

                if vis_dir:
                    try:
                        _save_visualization(img, vis_dir, record['image_id'], i,
                                            gt_bbox, pred_bbox, pred_bbox_new,
                                            largest_mask)
                    except Exception as err:
                        # A failed visualization must not change the score of
                        # a sample that has already been evaluated.
                        tqdm.write(f'[sample {i}] visualization failed '
                                   f'({type(err).__name__}: {err})', file=sys.stderr)
        except Exception as err:
            # Best-effort evaluation: any failure while processing a sample
            # is counted as a format error, but Ctrl-C / SystemExit are no
            # longer swallowed and the reason is made visible.
            format_error += 1
            tqdm.write(f'[sample {i}] counted as format error '
                       f'({type(err).__name__}: {err})', file=sys.stderr)
            continue

    print("Evaluating ...")
    if total_cnt == 0:
        print('No predictions found in the answers file; nothing to evaluate.')
        return
    print(f'Precision @ 0.5: {correct / total_cnt} \n')
    print(f'Format error ratio: {format_error / total_cnt} \n')


if __name__ == "__main__":
    main()
|
|
|