raster2seq / eval_seg.py
anas
Initial deployment of Raster2Seq floor plan vectorization API
fadb92b
import inspect
import json
import os
import sys
from os.path import isfile, join, realpath
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), "evaluations"))
from clipseg_eval.general_utils import (
AttributeDict,
filter_args,
get_attribute,
score_config_from_cli_args,
)
sys.path.append(str(Path(__file__).resolve().parent.parent))
from datasets import build_dataset
from detectron2.data.detection_utils import annotations_to_instances
# Module-level memo of already constructed datasets; get_cached_pascal_pfe
# stores and looks up entries keyed by the options that define a dataset.
DATASET_CACHE = {}
def load_model(
    checkpoint_id, weights_file=None, strict=True, model_args="from_config", with_config=False, ignore_weights=False
):
    """Build the model described by ``logs/<checkpoint_id>/config.json`` and load its weights.

    Args:
        checkpoint_id: Name of the run directory below ``logs``.
        weights_file: File name of the weights inside the run directory.
            ``None`` selects ``weights.pth``.
        strict: Forwarded to ``model.load_state_dict``.
        model_args: Either the string ``"from_config"`` (constructor arguments
            are taken from the stored config via ``filter_args``) or a
            dictionary of keyword arguments for the model class.
        with_config: When true, return ``(model, config)`` instead of the
            model alone.
        ignore_weights: When true, no weights are loaded and a missing weights
            file is not an error.

    Returns:
        The constructed model, or ``(model, config)`` if ``with_config`` is set.

    Raises:
        ValueError: ``model_args`` is neither ``"from_config"`` nor a dict.
        FileNotFoundError: The weights file is missing and ``ignore_weights``
            is false.
        AssertionError: A loaded weight tensor contains NaNs.
    """
    # Use a context manager so the config file handle is closed deterministically.
    with open(join("logs", checkpoint_id, "config.json")) as config_file:
        config = json.load(config_file)

    if model_args != "from_config" and not isinstance(model_args, dict):
        raise ValueError('model_args must either be "from_config" or a dictionary of values')

    model_cls = get_attribute(config["model"])

    # load model
    if model_args == "from_config":
        # Keep only the config entries that match the constructor's signature.
        _, model_args, _ = filter_args(config, inspect.signature(model_cls).parameters)
    model = model_cls(**model_args)

    weights_name = "weights.pth" if weights_file is None else weights_file
    weights_file = realpath(join("logs", checkpoint_id, weights_name))

    if not ignore_weights:
        if not isfile(weights_file):
            raise FileNotFoundError(f"model checkpoint {weights_file} was not found")

        # map_location="cpu" lets checkpoints saved on a GPU be read on hosts
        # without CUDA; load_state_dict copies the values into the model's own
        # parameters afterwards, so the resulting model is the same.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        weights = torch.load(weights_file, map_location="cpu")
        for w in weights.values():
            assert not torch.any(torch.isnan(w)), "weights contain NaNs"
        model.load_state_dict(weights, strict=strict)

    if with_config:
        return model, config
    return model
def read_pred_json(json_file_path, image_size=(256, 256), mask_format="bitmask"):
    """Load predicted annotations from a JSON file and convert them to instances.

    Args:
        json_file_path: Path of the JSON file holding the predicted annotations.
        image_size: Image size forwarded to ``annotations_to_instances``.
        mask_format: Mask format forwarded to ``annotations_to_instances``.

    Returns:
        Whatever ``annotations_to_instances`` builds from the annotations
        (called with ``no_boxes=True``).
    """
    with open(json_file_path, "r") as handle:
        annotations = json.load(handle)

    # Flatten every segmentation into a 1-D array and wrap it in a
    # single-element list before handing the annotations on.
    for annotation in annotations:
        flat_coords = np.array(annotation["segmentation"]).flatten()
        annotation["segmentation"] = [flat_coords]

    return annotations_to_instances(annotations, image_size, mask_format, no_boxes=True)
def compute_shift2(model, datasets, seed=123, repetitions=1):
    """Return the threshold with the best ``fgiou_scores`` over ``datasets``.

    The model is switched to eval mode, moved to CUDA and run without gradients
    on every dataset (batch size 1, no shuffling). All predictions and targets
    are collected and fed into ``FixedIntervalMetrics``; the threshold whose
    index maximises ``metric.value()["fgiou_scores"]`` is returned.

    Args:
        model: Callable invoked as ``model(data_x[0], data_x[1], data_x[2])``;
            it must return a one-element sequence holding the prediction.
        datasets: Iterable of datasets. Each must expose
            ``dataset.dataset.data_list`` (its length bounds the number of
            batches taken from that dataset).
        seed: Seed passed to Python's ``random`` module.
        repetitions: Multiplier applied to ``len(data_list)`` to obtain the
            per-dataset batch limit.

    Returns:
        The selected entry of ``np.linspace(0, 1, 25)[1:-1]``.
    """
    model.eval()
    model.cuda()

    import random

    random.seed(seed)

    preds, gts = [], []
    for dataset in datasets:
        loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, drop_last=False)
        max_iterations = int(repetitions * len(dataset.dataset.data_list))

        with torch.no_grad():
            # The batch counter must be an integer: the previous ``i = []``
            # made ``i += 1`` raise TypeError on the very first batch.
            for n_batches, (data_x, data_y) in enumerate(loader, start=1):
                data_x = [v.cuda(non_blocking=True) if v is not None else v for v in data_x]
                data_y = [v.cuda(non_blocking=True) if v is not None else v for v in data_y]

                (pred,) = model(data_x[0], data_x[1], data_x[2])
                preds.append(pred.detach())
                gts.append(data_y)

                # A limit of 0 (empty data_list or repetitions == 0) means "no limit".
                if max_iterations and n_batches >= max_iterations:
                    break

    from metrics import FixedIntervalMetrics

    n_values = 25  # 51
    # Interior grid points only; the end points 0 and 1 are dropped.
    # NOTE(review): assumes fgiou_scores is aligned with these interior thresholds - confirm.
    thresholds = np.linspace(0, 1, n_values)[1:-1]
    metric = FixedIntervalMetrics(resize_pred=True, sigmoid=True, n_values=n_values)

    for p, y in zip(preds, gts):
        metric.add(p.unsqueeze(1), y)

    best_idx = np.argmax(metric.value()["fgiou_scores"])
    best_thresh = thresholds[best_idx]
    return best_thresh
def get_cached_pascal_pfe(split, config):
    """Return the ``PFEPascalWrapper`` for ``split``, building it at most once.

    Datasets are memoised in ``DATASET_CACHE`` under the key
    ``(split, config.image_size, config.label_support, config.mask)``; a new
    wrapper (always with ``mode="val"``) is created only when that key has not
    been seen before.
    """
    from datasets.pfe_dataset import PFEPascalWrapper

    cache_key = (split, config.image_size, config.label_support, config.mask)
    if cache_key not in DATASET_CACHE:
        DATASET_CACHE[cache_key] = PFEPascalWrapper(
            mode="val", split=split, mask=config.mask, image_size=config.image_size, label_support=config.label_support
        )
    return DATASET_CACHE[cache_key]
def main():
    """Read the scoring configuration from the command line, score it and print the result."""
    cli_config, checkpoint_id = score_config_from_cli_args()
    # No training config is passed on to score().
    results = score(cli_config, checkpoint_id, None)
    print(results)
def _trivial_batch_collator(batch):
    """Return ``batch`` unchanged so each batch stays a plain list of dataset items.

    Defined at module level (not inside ``score``) so that DataLoader worker
    processes can pickle it when the multiprocessing start method is "spawn".
    """
    return batch


def score(config, train_checkpoint_id, train_config):
    """Score stored JSON predictions against the ground truth of the test set.

    Only ``config.test_dataset == "waffle"`` is handled. For every sample the
    ground-truth masks whose class id is 0 are merged into one binary mask, the
    predicted masks are read from
    ``<config.pred_json_root>/<image name up to its first dot>.json`` and merged
    the same way, and both are passed to the metric named by ``config.metric``.

    Args:
        config: Mapping with the evaluation options; wrapped in
            ``AttributeDict``. Reads ``test_dataset``, ``metric``,
            ``batch_size``, ``pred_json_root``, ``image_size``, ``mask_format``
            and ``max_iterations``, plus the optional ``threshold``,
            ``resize_to``, ``sigmoid``, ``custom_threshold``, ``shift`` and
            ``name``.
        train_checkpoint_id: Not used by this function.
        train_config: Not used by this function.

    Returns:
        ``{key_prefix: metric.scores()}`` where ``key_prefix`` is
        ``config["name"]`` if present and ``"coco"`` otherwise, or ``None`` when
        ``config.test_dataset`` is not ``"waffle"``.
    """
    config = AttributeDict(config)
    print(config)

    metric_args = dict()
    if "threshold" in config:
        # The threshold is only forwarded to SkLearnMetrics.
        if config.metric.split(".")[-1] == "SkLearnMetrics":
            metric_args["threshold"] = config.threshold
    for option in ("resize_to", "sigmoid", "custom_threshold"):
        if option in config:
            metric_args[option] = getattr(config, option)

    if config.test_dataset != "waffle":
        # No other dataset is implemented here; callers receive None.
        return None

    coco_dataset = build_dataset(image_set="test", args=config)
    # NOTE(review): result discarded - presumably a fail-fast check that the
    # first item can be loaded before worker processes start; confirm.
    coco_dataset[0]

    loader = DataLoader(
        coco_dataset,
        batch_size=config.batch_size,
        num_workers=2,
        shuffle=False,
        drop_last=False,
        collate_fn=_trivial_batch_collator,
    )
    metric = get_attribute(config.metric)(resize_pred=False, n_values=25, **metric_args)
    shift = config.shift if "shift" in config else 0
    pred_json_root = config.pred_json_root

    with torch.no_grad():
        for n_batches, batch_data in enumerate(tqdm(loader), start=1):
            # Score every sample of the batch; looking at batch_data[0] only
            # would silently skip samples whenever batch_size > 1.
            for sample in batch_data:
                image_path = sample["file_name"]
                instances = sample["instances"]

                data_y = instances.gt_masks.tensor[None, ...]
                gt_classes = instances.gt_classes[None, ...]
                # Keep only the masks whose class id is 0.
                interior_mask = gt_classes == 0
                data_y = data_y[interior_mask][None, ...]
                # Merge the selected masks into a single binary mask.
                data_y = torch.sum(data_y, dim=1, keepdim=True).clamp(0, 1)  # Shape: Bx1xHxW

                # NOTE(review): the stem is cut at the FIRST dot of the file
                # name, so "a.b.png" maps to "a.json"; kept as-is because the
                # prediction files are presumably written with the same rule.
                pred_json_name = os.path.basename(image_path).split(".")[0] + ".json"
                pred = read_pred_json(
                    os.path.join(pred_json_root, pred_json_name),
                    image_size=(config.image_size, config.image_size),
                    mask_format=config.mask_format,
                )
                if len(pred) == 0:
                    # No predicted instance: score an all-zero mask.
                    pred = torch.zeros_like(data_y)
                else:
                    pred = pred.gt_masks.tensor[None, ...]
                    pred = torch.sum(pred, dim=1, keepdim=True).clamp(0, 1)  # Shape: Bx1xHxW

                metric.add(pred + shift, data_y)

            # max_iterations limits the number of batches; a falsy value disables the limit.
            if config.max_iterations and n_batches >= config.max_iterations:
                break

    key_prefix = config["name"] if "name" in config else "coco"
    # Compute the scores once and reuse them for both the log line and the result.
    scores = metric.scores()
    print(scores)
    return {key_prefix: scores}
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()