"""Score saved mask predictions against ground-truth masks.

The module also provides helpers to restore a trained model from the
``logs/<checkpoint_id>`` directory and to search for the best binarisation
threshold of a model on a set of datasets.
"""

import inspect
import json
import os
import random
import sys
from os.path import isfile, join, realpath
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# The search path has to be extended before the project-local imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), "evaluations"))
from clipseg_eval.general_utils import (  # noqa: E402
    AttributeDict,
    filter_args,
    get_attribute,
    score_config_from_cli_args,
)

sys.path.append(str(Path(__file__).resolve().parent.parent))
from datasets import build_dataset  # noqa: E402
from detectron2.data.detection_utils import annotations_to_instances  # noqa: E402

# Maps (split, image_size, label_support, mask) to an already-built dataset.
DATASET_CACHE = dict()


def load_model(
    checkpoint_id, weights_file=None, strict=True, model_args="from_config", with_config=False, ignore_weights=False
):
    """Instantiate the model described by ``logs/<checkpoint_id>/config.json``.

    Args:
        checkpoint_id: Name of the sub-directory of ``logs`` holding the run.
        weights_file: File name of the weights inside the run directory;
            ``weights.pth`` is used when this is ``None``.
        strict: Forwarded to ``model.load_state_dict``.
        model_args: Either the string ``"from_config"`` (constructor arguments
            are taken from the stored config) or a dict of constructor
            arguments.
        with_config: When true, return ``(model, config)`` instead of the
            model alone.
        ignore_weights: When true, no weights are loaded and a missing weights
            file is not an error.

    Returns:
        The model, or a ``(model, config)`` tuple if ``with_config`` is set.

    Raises:
        ValueError: If ``model_args`` is neither ``"from_config"`` nor a dict.
        FileNotFoundError: If the weights file is missing and
            ``ignore_weights`` is false.
    """
    # Use a context manager so the config file handle is always closed.
    with open(join("logs", checkpoint_id, "config.json"), encoding="utf-8") as config_file:
        config = json.load(config_file)

    if model_args != "from_config" and not isinstance(model_args, dict):
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config["model"])

    if model_args == "from_config":
        # Presumably keeps only the config entries that match the constructor
        # signature -- confirm against filter_args.
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)

    model = model_cls(**model_args)

    weights_name = "weights.pth" if weights_file is None else weights_file
    weights_file = realpath(join("logs", checkpoint_id, weights_name))

    if not ignore_weights:
        if not isfile(weights_file):
            raise FileNotFoundError(f"model checkpoint {weights_file} was not found")

        # Load onto the CPU so that checkpoints written on a GPU machine can be
        # restored anywhere; load_state_dict copies the values into the
        # parameters that the model already owns.
        weights = torch.load(weights_file, map_location="cpu")
        for w in weights.values():
            assert not torch.any(torch.isnan(w)), "weights contain NaNs"
        model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config

    return model


def read_pred_json(json_file_path, image_size=(256, 256), mask_format="bitmask"):
    """Load predictions from a JSON file and convert them to instances.

    Args:
        json_file_path: Path of the JSON file holding the predictions. Each
            entry must provide a ``"segmentation"`` value.
        image_size: Forwarded to ``annotations_to_instances``.
        mask_format: Forwarded to ``annotations_to_instances``.

    Returns:
        The result of ``annotations_to_instances`` for the loaded predictions.
    """
    with open(json_file_path, "r", encoding="utf-8") as file:
        predictions = json.load(file)

    # Every segmentation is replaced by a one-element list that holds the
    # flattened coordinate array.
    for prediction in predictions:
        prediction["segmentation"] = [np.array(prediction["segmentation"]).flatten()]

    return annotations_to_instances(predictions, image_size, mask_format, no_boxes=True)


def compute_shift2(model, datasets, seed=123, repetitions=1):
    """Return the threshold with the highest ``fgiou_scores`` value.

    The model is run (in eval mode, on the GPU) over every dataset, and the
    collected predictions are evaluated with ``FixedIntervalMetrics``.

    Args:
        model: Model called as ``model(x0, x1, x2)``; it must return a
            one-element sequence.
        datasets: Iterable of datasets; each must expose
            ``dataset.dataset.data_list``.
        seed: Seed passed to ``random.seed``.
        repetitions: Factor applied to the length of ``data_list`` to obtain
            the maximum number of batches read per dataset.

    Returns:
        The entry of the threshold grid that maximises ``fgiou_scores``.
    """
    model.eval()
    model.cuda()

    random.seed(seed)

    preds, gts = [], []
    for dataset in datasets:
        loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
        max_iterations = int(repetitions * len(dataset.dataset.data_list))

        with torch.no_grad():
            # The counter must be an integer: the previous ``i = []`` made
            # ``i += 1`` raise a TypeError on the very first batch.
            for step, (data_x, data_y) in enumerate(loader, start=1):
                data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]

                (pred,) = model(data_x[0], data_x[1], data_x[2])
                preds.append(pred.detach())
                gts.append(data_y)

                # A limit of zero means "no limit".
                if max_iterations and step >= max_iterations:
                    break

    # Imported lazily so the module can be used without the metrics package.
    from metrics import FixedIntervalMetrics

    n_values = 25  # 51
    # The two end points of the grid (0 and 1) are not candidates.
    thresholds = np.linspace(0, 1, n_values)[1:-1]
    metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)

    for p, y in zip(preds, gts):
        metric.add(p.unsqueeze(1), y)

    best_idx = np.argmax(metric.value()["fgiou_scores"])
    return thresholds[best_idx]


def get_cached_pascal_pfe(split, config):
    """Return the ``PFEPascalWrapper`` for the given split, building it once.

    Datasets are cached in ``DATASET_CACHE`` under the key
    ``(split, config.image_size, config.label_support, config.mask)``.
    """
    from datasets.pfe_dataset import PFEPascalWrapper

    cache_key = (split, config.image_size, config.label_support, config.mask)
    try:
        return DATASET_CACHE[cache_key]
    except KeyError:
        dataset = PFEPascalWrapper(
            mode="val", split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support
        )
        DATASET_CACHE[cache_key] = dataset
        return dataset


def main():
    """Read the scoring configuration from the command line and print the scores."""
    config, train_checkpoint_id = score_config_from_cli_args()
    metrics = score(config, train_checkpoint_id, None)
    print(metrics)


def score(config, train_checkpoint_id, train_config):
    """Evaluate stored predictions against the ground truth of the test set.

    Only ``config.test_dataset == "waffle"`` is handled. Predictions are read
    from ``<config.pred_json_root>/<image stem>.json``.

    Args:
        config: Mapping with the scoring options; it is wrapped in an
            ``AttributeDict``.
        train_checkpoint_id: Not used by this function.
        train_config: Not used by this function.

    Returns:
        ``{name: scores}`` where ``name`` is ``config["name"]`` (``"coco"`` if
        absent) and ``scores`` is ``metric.scores()``; ``None`` for any other
        test dataset.
    """
    config = AttributeDict(config)
    print(config)

    metric_args = dict()

    # The threshold is only forwarded to SkLearnMetrics.
    if "threshold" in config and config.metric.split(".")[-1] == "SkLearnMetrics":
        metric_args["threshold"] = config.threshold

    for option in ("resize_to", "sigmoid", "custom_threshold"):
        if option in config:
            metric_args[option] = config[option]

    if config.test_dataset != "waffle":
        return None

    coco_dataset = build_dataset(image_set="test", args=config)
    # NOTE(review): the first item is fetched and discarded -- presumably a
    # sanity check or a warm-up of the dataset; confirm before removing.
    coco_dataset[0]

    def trivial_batch_collator(batch):
        """
        A batch collator that does nothing.
        """
        return batch

    loader = DataLoader(
        coco_dataset,
        batch_size=config.batch_size,
        num_workers=2,
        shuffle=False,
        drop_last=False,
        collate_fn=trivial_batch_collator,
    )

    metric = get_attribute(config.metric)(resize_pred=False, n_values=25, **metric_args)

    shift = config.shift if "shift" in config else 0
    pred_json_root = config.pred_json_root

    with torch.no_grad():
        for step, batch_data in enumerate(tqdm(loader), start=1):
            # Only the first sample of every batch is evaluated.
            sample = batch_data[0]
            image_path = sample["file_name"]

            data_y = sample["instances"].gt_masks.tensor[None, ...]
            gt_classes = sample["instances"].gt_classes[None, ...]

            # Keep the masks whose class id is 0 and merge them into one mask.
            interior_mask = gt_classes == 0
            data_y = data_y[interior_mask][None, ...]
            data_y = torch.sum(data_y, dim=1, keepdim=True).clamp(0, 1)  # Shape: Bx1xHxW

            # The prediction file is named after the image file name up to its
            # first dot.
            pred = read_pred_json(
                os.path.join(pred_json_root, os.path.basename(image_path).split(".")[0] + ".json"),
                image_size=(config.image_size, config.image_size),
                mask_format=config.mask_format,
            )
            if len(pred) == 0:
                # No predicted instance: score an all-zero mask.
                pred = torch.zeros_like(data_y)
            else:
                pred = pred.gt_masks.tensor[None, ...]
                pred = torch.sum(pred, dim=1, keepdim=True).clamp(0, 1)  # Shape: Bx1xHxW

            metric.add(pred + shift, data_y)

            # A limit of zero (or None) means "no limit".
            if config.max_iterations and step >= config.max_iterations:
                break

    key_prefix = config["name"] if "name" in config else "coco"
    # Compute the scores once and reuse them for printing and returning.
    scores = metric.scores()
    print(scores)
    return {key_prefix: scores}


if __name__ == "__main__":
    main()