# GeoGround / inference_scripts / generate_mask_sam_by_box.py
# Author: erenzhou — "Upload folder using huggingface_hub" (commit 4963c36, verified)
import re
import math
from tqdm import tqdm
import os
from PIL import Image, ImageDraw
import numpy as np
import json
import argparse
import cv2
from segment_anything import sam_model_registry, SamPredictor
# Build the Segment-Anything (ViT-H) predictor used by the evaluation loop.
# This runs at import time. The checkpoint path and the device default to the
# original hard-coded values, but can be overridden through the SAM_CHECKPOINT
# and SAM_DEVICE environment variables instead of editing the source.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT",
    "/mnt/lustre/zhouyue1.vendor/segment-anything-main/sam_vit_h_4b8939.pth",
)
model_type = "vit_h"  # must match the architecture of the checkpoint above
device = os.environ.get("SAM_DEVICE", "cuda")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
def calculate_iou(box1, box2):
    """Return the Intersection-over-Union of two axis-aligned boxes.

    Args:
        box1, box2: Boxes indexed as ``(x1, y1, x2, y2)``, i.e. the top-left
            and bottom-right corners. Only the first four entries are read.

    Returns:
        The IoU as a float; ``0.0`` when the boxes do not overlap.
    """
    ax1, ay1, ax2, ay2 = box1[0], box1[1], box1[2], box1[3]
    bx1, by1, bx2, by2 = box2[0], box2[1], box2[2], box2[3]

    # Width/height of the overlap rectangle (non-positive means no overlap).
    overlap_w = min(ax2, bx2) - max(ax1, bx1)
    overlap_h = min(ay2, by2) - max(ay1, by1)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0.0

    overlap = overlap_w * overlap_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    # A positive overlap implies both boxes have positive area, so the union is > 0.
    return overlap / (area_a + area_b - overlap)
def decode_string(s, length):
    """Decode a run-length-encoded mask string into a ``length`` x ``length`` grid.

    Rows are separated by ``';'`` and groups within a row by ``','``. Each
    group has the form ``"value*count"`` (whitespace around either number is
    allowed) and expands to ``value`` repeated ``count`` times.

    * Empty rows are skipped.
    * A malformed group fills the rest of its row with zeros; later groups in
      that row are ignored.
    * Each row is truncated / zero-padded to exactly ``length`` values.
    * Rows beyond ``length`` are dropped and missing rows are zero-filled.

    Args:
        s: Encoded mask string, e.g. ``"1*2,0*1;0*3"``.
        length: Side length of the square output grid.

    Returns:
        A list of ``length`` rows, each a list of ``length`` ints.
    """
    result = []
    for row in s.strip().split(';'):
        if not row.strip():  # e.g. a trailing ';'
            continue
        decoded_row = []
        for group in row.split(','):
            remaining = length - len(decoded_row)
            if remaining <= 0:
                break  # anything past ``length`` would be truncated anyway
            try:
                num, count = group.split('*')
                value, repeat = int(num.strip()), int(count.strip())
            except ValueError:
                # Malformed group (no '*', extra '*', or non-integer): zero the rest of the row.
                decoded_row.extend([0] * remaining)
                continue
            # Cap the repeat so an absurd count in model output cannot
            # allocate a huge list (or raise MemoryError/OverflowError).
            decoded_row.extend([value] * min(repeat, remaining))
        decoded_row.extend([0] * (length - len(decoded_row)))
        result.append(decoded_row)

    # Enforce the square shape: drop surplus rows, pad missing ones with
    # distinct zero rows (not aliases of one list).
    result = result[:length]
    result.extend([0] * length for _ in range(length - len(result)))
    return result
def extract_bboxes(output):
    """Extract every bracketed list of numbers from a model answer.

    A match is a ``[`` ... ``]`` span containing one or more comma-separated
    numbers (optionally signed, optionally with a decimal part), e.g.
    ``"[12, 34, 56, 78]"`` or ``"[0.5,1.5,2,3]"``. Brackets holding anything
    else (empty lists, stray commas, words) are ignored rather than raising.

    Args:
        output: Model answer text.

    Returns:
        A list of boxes in order of appearance, each a list of floats
        (typically ``[x_left, y_top, x_right, y_bottom]``). Empty if none found.
    """
    number = r'-?[0-9]+(?:\.[0-9]+)?'
    # Only the outer group captures; inner groups are non-capturing so that
    # re.findall returns the whole comma-separated list per match.
    pattern = rf'\[\s*({number}(?:\s*,\s*{number})*)\s*\]'
    return [[float(value) for value in match.split(',')]
            for match in re.findall(pattern, output)]
def load_jsonl(filename):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with :func:`json.loads` (typically into a
    dict). Blank lines, e.g. a stray empty line in the middle of the file, are
    skipped instead of raising ``JSONDecodeError``. The file is read as UTF-8,
    the encoding JSON text is required to use.

    Args:
        filename: Path to the ``.jsonl`` file.

    Returns:
        The parsed records in file order.
    """
    with open(filename, 'r', encoding='utf-8') as jsonl_file:
        return [json.loads(line) for line in jsonl_file if line.strip()]
def folder_creat_if_not_exist(folder):
    """Create ``folder`` (and any missing parents) unless it already exists.

    ``exist_ok=True`` replaces the separate existence check, so two processes
    creating the same directory cannot race between "check" and "create".
    """
    os.makedirs(folder, exist_ok=True)
def _denormalize_box(box, scale, width, height):
    """Map the first four coords of ``box`` from ``[0, scale]`` to pixels.

    Returns ``[x1, y1, x2, y2]`` as floats; trailing values are ignored.
    """
    x1, y1, x2, y2 = box[:4]
    return [x1 / scale * width, y1 / scale * height,
            x2 / scale * width, y2 / scale * height]


def _largest_component(mask):
    """Keep only the largest external contour of a binary SAM mask.

    Returns:
        ``(largest_mask, bbox)``. ``largest_mask`` is a uint8 array of 0/255
        holding the filled largest contour. ``bbox`` is ``[x1, y1, x2, y2]``
        of that contour, or ``None`` when SAM returned an empty mask.
    """
    mask_uint8 = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_mask = np.zeros_like(mask_uint8)
    if not contours:  # max() over an empty sequence would raise
        return largest_mask, None
    max_contour = max(contours, key=cv2.contourArea)
    cv2.drawContours(largest_mask, [max_contour], -1, 255, thickness=cv2.FILLED)
    x, y, w, h = cv2.boundingRect(largest_mask)
    return largest_mask, [x, y, x + w, y + h]


def _segment_with_sam(image_path, box):
    """Prompt the module-level SAM ``predictor`` with a pixel ``box`` on one image.

    Returns the ``(mask, bbox)`` pair produced by ``_largest_component``.
    """
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread reports failure by returning None
        raise OSError(f"cv2.imread could not read {image_path}")
    predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    masks, _, _ = predictor.predict(box=np.array(box), multimask_output=False)
    return _largest_component(masks[0])


def _save_visualisation(img, image_path, image_id, index, gt_bbox, pred_bbox, vis_dir):
    """Write the SAM mask and a box overlay for one sample into ``vis_dir``.

    Files: ``<stem>_<index>.png`` (mask) and ``<stem>_<index>.jpg`` (overlay:
    red = ground truth, blue = prediction, green = SAM mask bounding box).
    """
    mask, sam_bbox = _segment_with_sam(image_path, pred_bbox)
    stem = os.path.join(vis_dir, f"{os.path.splitext(image_id)[0]}_{index}")
    # image_id may contain sub-directories; make sure they exist.
    os.makedirs(os.path.dirname(stem), exist_ok=True)
    if not cv2.imwrite(stem + '.png', mask):
        raise OSError(f"cv2.imwrite failed for {stem}.png")
    canvas = img.convert("RGB")  # JPEG cannot store RGBA / palette images
    draw = ImageDraw.Draw(canvas)
    draw.rectangle(gt_bbox, outline="red", width=5)
    draw.rectangle(pred_bbox, outline="blue", width=5)
    if sam_bbox is not None:
        draw.rectangle(sam_bbox, outline="green", width=5)
    canvas.save(stem + '.jpg')


def main():
    """Score predicted boxes (Precision@0.5) and optionally export SAM masks.

    Each JSONL record provides:
    * ``answer``: model text containing a box normalised to ``--scale``.
    * ``bbox``: the ground-truth box in pixels.
    * ``image_id``: the image path relative to ``--image-folder``.

    When ``--vis-dir`` is given, the SAM mask prompted by the predicted box and
    an overlay image are written for every sample.
    """
    parser = argparse.ArgumentParser(description='Process some paths.')
    parser.add_argument('--scale', required=True, type=int, help='Normalize scale')
    parser.add_argument('--image-folder', required=True, help='Image directory')
    parser.add_argument('--answers-file', required=True, help='Path to the predictions JSONL file')
    parser.add_argument('--vis-dir', default=None,
                        help='Optional output directory for SAM masks and box overlays')
    args = parser.parse_args()

    scale = args.scale
    vis_dir = args.vis_dir
    if vis_dir:
        folder_creat_if_not_exist(vis_dir)

    records = load_jsonl(args.answers_file)
    total_cnt = len(records)
    correct = 0
    format_error = 0
    processing_error = 0

    for i, record in tqdm(enumerate(records), total=total_cnt):
        answer = record['answer'].strip()
        gt_bbox = record['bbox']
        try:
            predict_boxes = extract_bboxes(answer)
        except ValueError:  # a bracketed span that is not a valid number list
            predict_boxes = []
        if not predict_boxes or len(predict_boxes[0]) < 4:
            format_error += 1
            continue

        try:
            image_id = record['image_id']
            image_path = os.path.join(args.image_folder, image_id)
            with Image.open(image_path) as img:
                width, height = img.size
                pred_bbox = _denormalize_box(predict_boxes[0], scale, width, height)
                # Score before running SAM so a SAM/IO failure cannot discard a
                # correct box. NOTE(review): the score uses the model's own box,
                # not the SAM-refined one; confirm that is the intended metric.
                if calculate_iou(gt_bbox, pred_bbox) >= 0.5:
                    correct += 1
                # SAM output is only ever written to disk, so skip it otherwise.
                if vis_dir:
                    _save_visualisation(img, image_path, image_id, i,
                                        gt_bbox, pred_bbox, vis_dir)
        except Exception as err:  # keep a long run going, but report the failure
            processing_error += 1
            tqdm.write(f"sample {i}: {type(err).__name__}: {err}")

    denom = max(total_cnt, 1)  # avoid ZeroDivisionError on an empty answers file
    print("Evaluating ...")
    print(f'Precision @ 0.5: {correct / denom} \n')
    print(f'Format error ratio: {format_error / denom} \n')
    print(f'Processing error ratio: {processing_error / denom} \n')


if __name__ == "__main__":
    main()