# Copyright (c) 2023 Dhruba Ghosh
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/djghosh13/geneval/blob/main/LICENSE.
#
# This modified file is released under the same license.
import argparse
import json
import os
import re
import sys
import time

from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from PIL import Image, ImageOps
import torch
import torch.distributed as dist
import mmdet
from mmdet.apis import inference_detector, init_detector

import open_clip
from clip_benchmark.metrics import zeroshot_classification as zsc
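# Silence the progress bars that clip_benchmark would otherwise print for
# every batch of crops during zero-shot classification: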
zsc.tqdm = lambda it, *args, **kwargs: it


def setup_distributed():
    """Initialize the distributed environment."""
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
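
# NOTE: this script expects the RANK/LOCAL_RANK/WORLD_SIZE environment
# variables that a launcher such as torchrun sets. A typical launch might
# look like this (a sketch, assuming 8 GPUs and that this file is saved as
# evaluate_images.py):
#   torchrun --nproc_per_node=8 evaluate_images.py IMAGE_DIR --outfile results.jsonl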


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("imagedir", type=str)
    parser.add_argument("--outfile", type=str, default="results.jsonl")
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--model-path", type=str, default="./")
    # Other arguments
    parser.add_argument("--options", nargs="*", type=str, default=[])
    args = parser.parse_args()
    args.options = dict(opt.split("=", 1) for opt in args.options)
    if args.model_config is None:
        args.model_config = os.path.join(
            os.path.dirname(mmdet.__file__),
            "../configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
        )
    return args
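
# The free-form --options flag takes KEY=VALUE pairs; the keys actually read
# below are model, clip_model, bgcolor, crop, threshold, counting_threshold,
# max_objects, max_overlap, and position_threshold, e.g.:
#   --options threshold=0.3 max_objects=16 bgcolor=original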


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "CUDA is required: the NCCL backend and the detector run on GPU"


def timed(fn):
    def wrapper(*args, **kwargs):
        startt = time.time()
        result = fn(*args, **kwargs)
        endt = time.time()
        print(f'Function {fn.__name__!r} executed in {endt - startt:.3f}s', file=sys.stderr)
        return result
    return wrapper


# Load models
def load_models(args):
    CONFIG_PATH = args.model_config
    OBJECT_DETECTOR = args.options.get('model', "mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco")
    CKPT_PATH = os.path.join(args.model_path, f"{OBJECT_DETECTOR}.pth")
    object_detector = init_detector(CONFIG_PATH, CKPT_PATH, device=DEVICE)
    clip_arch = args.options.get('clip_model', "ViT-L-14")
    clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained="openai", device=DEVICE)
    tokenizer = open_clip.get_tokenizer(clip_arch)
    with open(os.path.join(os.path.dirname(__file__), "object_names.txt")) as cls_file:
        classnames = [line.strip() for line in cls_file]
    return object_detector, (clip_model, transform, tokenizer), classnames
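
# The detector checkpoint is looked up as <model-path>/<model>.pth, so the
# default expects mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.pth (the COCO
# Mask2Former checkpoint used by GenEval) inside the --model-path directory.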


COLORS = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"]
COLOR_CLASSIFIERS = {}


# Evaluation parts
class ImageCrops(torch.utils.data.Dataset):
    def __init__(self, image: Image.Image, objects):
        self._image = image.convert("RGB")
        bgcolor = args.options.get('bgcolor', "#999")
        if bgcolor == "original":
            self._blank = self._image.copy()
        else:
            self._blank = Image.new("RGB", image.size, color=bgcolor)
        self._objects = objects

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, index):
        box, mask = self._objects[index]
        if mask is not None:
            assert tuple(self._image.size[::-1]) == tuple(mask.shape), (index, self._image.size[::-1], mask.shape)
            image = Image.composite(self._image, self._blank, Image.fromarray(mask))
        else:
            image = self._image
        if args.options.get('crop', '1') == '1':
            image = image.crop(box[:4])
        # if args.save:
        #     base_count = len(os.listdir(args.save))
        #     image.save(os.path.join(args.save, f"cropped_{base_count:05}.png"))
        return (transform(image), 0)


def color_classification(image, bboxes, classname):
    if classname not in COLOR_CLASSIFIERS:
        COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier(
            clip_model, tokenizer, COLORS,
            [
                f"a photo of a {{c}} {classname}",
                f"a photo of a {{c}}-colored {classname}",
                f"a photo of a {{c}} object"
            ],
            DEVICE
        )
    clf = COLOR_CLASSIFIERS[classname]
    dataloader = torch.utils.data.DataLoader(
        ImageCrops(image, bboxes),
        batch_size=16, num_workers=4
    )
    with torch.no_grad():
        pred, _ = zsc.run_classification(clip_model, clf, dataloader, DEVICE)
        return [COLORS[index.item()] for index in pred.argmax(1)]
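
# Each masked/cropped detection is scored against the per-color prompts above;
# the argmax over the ten COLORS entries is the predicted color for that object.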


def compute_iou(box_a, box_b):
    area_fn = lambda box: max(box[2] - box[0] + 1, 0) * max(box[3] - box[1] + 1, 0)
    i_area = area_fn([
        max(box_a[0], box_b[0]), max(box_a[1], box_b[1]),
        min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    ])
    u_area = area_fn(box_a) + area_fn(box_b) - i_area
    return i_area / u_area if u_area else 0
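
# Worked example (inclusive pixel coordinates, as area_fn assumes):
#   compute_iou([0, 0, 9, 9], [5, 5, 14, 14])
# Each box covers 10 * 10 = 100 pixels and they overlap on a 5 x 5 patch,
# so IoU = 25 / (100 + 100 - 25) ~= 0.143.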


def relative_position(obj_a, obj_b):
    """Give position of A relative to B, factoring in object dimensions"""
    boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
    center_a, center_b = boxes.mean(axis=-2)
    dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
    offset = center_a - center_b
    # Shrink the offset by a margin proportional to the combined box sizes
    revised_offset = np.maximum(np.abs(offset) - POSITION_THRESHOLD * (dim_a + dim_b), 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()
    # Normalize by the raw offset magnitude; diagonal placements can match two relations
    dx, dy = revised_offset / np.linalg.norm(offset)
    relations = set()
    if dx < -0.5: relations.add("left of")
    if dx > 0.5: relations.add("right of")
    if dy < -0.5: relations.add("above")
    if dy > 0.5: relations.add("below")
    return relations
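
# Example reading: with the default position_threshold=0.1, A must be offset
# from B by more than 10% of their summed width/height in a given direction
# before "left of"/"right of"/"above"/"below" is credited; near-coincident
# centers return the empty set.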


def evaluate(image, objects, metadata):
    """
    Evaluate the detected objects in the given image against the metadata specification.
    Assumptions:
    * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
    * All clauses are independent, i.e., duplicating a clause has no effect on correctness
    * CHANGED: Color and position are only evaluated on the most confidently predicted
      objects; therefore, objects are expected to appear in sorted order
    """
    correct = True
    reason = []
    matched_groups = []
    # Check for expected objects
    for req in metadata.get('include', []):
        classname = req['class']
        matched = True
        found_objects = objects.get(classname, [])[:req['count']]
        if len(found_objects) < req['count']:
            correct = matched = False
            reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
        else:
            if 'color' in req:
                # Color check
                colors = color_classification(image, found_objects, classname)
                if colors.count(req['color']) < req['count']:
                    correct = matched = False
                    reason.append(
                        f"expected {req['color']} {classname}>={req['count']}, found " +
                        f"{colors.count(req['color'])} {req['color']}; and " +
                        ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                    )
            if 'position' in req and matched:
                # Relative position check
                expected_rel, target_group = req['position']
                if matched_groups[target_group] is None:
                    correct = matched = False
                    reason.append(f"no target for {classname} to be {expected_rel}")
                else:
                    for obj in found_objects:
                        for target_obj in matched_groups[target_group]:
                            true_rels = relative_position(obj, target_obj)
                            if expected_rel not in true_rels:
                                correct = matched = False
                                reason.append(
                                    f"expected {classname} {expected_rel} target, found " +
                                    f"{' and '.join(true_rels)} target"
                                )
                                break
                        if not matched:
                            break
        if matched:
            matched_groups.append(found_objects)
        else:
            matched_groups.append(None)
    # Check for non-expected objects
    for req in metadata.get('exclude', []):
        classname = req['class']
        if len(objects.get(classname, [])) >= req['count']:
            correct = False
            reason.append(f"expected {classname}<{req['count']}, found {len(objects[classname])}")
    return correct, "\n".join(reason)


def evaluate_image(filepath, metadata):
    result = inference_detector(object_detector, filepath)
    bbox = result[0] if isinstance(result, tuple) else result
    segm = result[1] if isinstance(result, tuple) and len(result) > 1 else None
    image = ImageOps.exif_transpose(Image.open(filepath))
    detected = {}
    # Determine bounding boxes to keep
    confidence_threshold = THRESHOLD if metadata['tag'] != "counting" else COUNTING_THRESHOLD
    for index, classname in enumerate(classnames):
        ordering = np.argsort(bbox[index][:, 4])[::-1]
        ordering = ordering[bbox[index][ordering, 4] > confidence_threshold]  # Threshold
        ordering = ordering[:MAX_OBJECTS].tolist()  # Limit number of detected objects per class
        detected[classname] = []
        # Greedy non-maximum suppression: keep the most confident box, then
        # drop any remaining box that overlaps it by at least NMS_THRESHOLD
        while ordering:
            max_obj = ordering.pop(0)
            detected[classname].append((bbox[index][max_obj], None if segm is None else segm[index][max_obj]))
            ordering = [
                obj for obj in ordering
                if NMS_THRESHOLD == 1 or compute_iou(bbox[index][max_obj], bbox[index][obj]) < NMS_THRESHOLD
            ]
        if not detected[classname]:
            del detected[classname]
    # Evaluate
    is_correct, reason = evaluate(image, detected, metadata)
    return {
        'filename': filepath,
        'tag': metadata['tag'],
        'prompt': metadata['prompt'],
        'correct': is_correct,
        'reason': reason,
        'metadata': json.dumps(metadata),
        'details': json.dumps({
            key: [box.tolist() for box, _ in value]
            for key, value in detected.items()
        })
    }
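
# A returned record looks roughly like this (a sketch; values are illustrative):
#   {"filename": ".../samples/0000.png", "tag": "counting", "prompt": "...",
#    "correct": false, "reason": "expected dog>=2, found 1",
#    "metadata": "{...}", "details": "{\"dog\": [[x0, y0, x1, y1, score]]}"}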


if __name__ == "__main__":
    args = parse_args()
    THRESHOLD = float(args.options.get('threshold', 0.3))
    COUNTING_THRESHOLD = float(args.options.get('counting_threshold', 0.9))
    MAX_OBJECTS = int(args.options.get('max_objects', 16))
    NMS_THRESHOLD = float(args.options.get('max_overlap', 1.0))
    POSITION_THRESHOLD = float(args.options.get('position_threshold', 0.1))

    # Initialize distributed environment
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = f"cuda:{rank}"

    # Load models
    if rank == 0:
        print("[Rank 0] Loading model...")
    object_detector, (clip_model, transform, tokenizer), classnames = load_models(args)

    full_results = []
    # Sort so that every rank sees the same ordering when partitioning the work
    subfolders = sorted(f for f in os.listdir(args.imagedir) if os.path.isdir(os.path.join(args.imagedir, f)) and f.isdigit())
    total_subfolders = len(subfolders)

    # Divide subfolders to process by GPU
    subfolders_per_gpu = (total_subfolders + world_size - 1) // world_size
    start = rank * subfolders_per_gpu
    end = min(start + subfolders_per_gpu, total_subfolders)
    print(f"GPU {rank}: Processing {end - start} subfolders (index {start} to {end - 1})")

    for subfolder in tqdm(subfolders[start:end]):
        folderpath = os.path.join(args.imagedir, subfolder)
        with open(os.path.join(folderpath, "metadata.jsonl")) as fp:
            metadata = json.load(fp)
        # Evaluate each image
        for imagename in os.listdir(os.path.join(folderpath, "samples")):
            imagepath = os.path.join(folderpath, "samples", imagename)
            if not os.path.isfile(imagepath) or not re.match(r"\d+\.png", imagename):
                continue
            result = evaluate_image(imagepath, metadata)
            full_results.append(result)

    # Synchronize results from all GPUs
    all_results = [None] * world_size
    dist.all_gather_object(all_results, full_results)

    if rank == 0:
        # Merge results from all GPUs
        final_results = []
        for results in all_results:
            final_results.extend(results)
        # Save results
        if os.path.dirname(args.outfile):
            os.makedirs(os.path.dirname(args.outfile), exist_ok=True)
        with open(args.outfile, "w") as fp:
            pd.DataFrame(final_results).to_json(fp, orient="records", lines=True)
        print("All GPUs have completed their tasks and the final results have been saved.")
    else:
        print(f"GPU {rank} has completed all tasks")