File size: 4,630 Bytes
002bd9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys

# NOTE(review): adds the current working directory to the import path; presumably the
# script is launched from the repository root so the `src.*` imports below resolve — confirm.
sys.path.append(".")

import logging
import os
from typing import Optional, Dict

import hydra
import torch
from hydra.utils import instantiate
from datasets import DatasetDict, load_dataset, IterableDatasetDict
from omegaconf import DictConfig, OmegaConf
from src.data.transforms import SamCaptionerDataTransform
from src.data.collator import SamCaptionerDataCollator
from src.arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor

from transformers.trainer_utils import get_last_checkpoint
from transformers import set_seed, Trainer
from dataclasses import dataclass
import numpy as np
from functools import partial
import pandas as pd
import json
import tqdm
import yaml
from src.train import prepare_datasets

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
    """Post-process the inference JSON of every eval dataset built from ``args``.

    Logs the config, reads the optional ``no_sanity_check`` flag (default
    ``False``), seeds via ``set_seed``, builds the datasets with
    ``prepare_datasets`` and hands each eval dataset to
    ``replace_one_eval_dataset``.

    Raises:
        ValueError: If ``args.training.output_dir`` is ``None``.
    """
    # NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py
    logger.info(OmegaConf.to_yaml(args))

    # Remind the user how to toggle the sanity check from the command line.
    logger.info(f"Add command args: +no_sanity_check=False")
    skip_sanity_check = args.get("no_sanity_check", False)

    results_root = args.training.output_dir
    if results_root is None:
        raise ValueError("output_dir is None, which should not happen.")

    # Seed first, then build the datasets.
    set_seed(args.training.seed)

    # NOTE(xiaoke): We should only inference one eval dataset
    # The train split is returned as well but is not needed here.
    _train_split, eval_splits = prepare_datasets(args)

    for split_name, split_data in eval_splits.items():
        replace_one_eval_dataset(skip_sanity_check, results_root, split_name, split_data)


def replace_one_eval_dataset(no_sanity_check, output_dir, eval_dataset_name, eval_dataset):
    """Replace the ``references`` of one inference JSON with the dataset's phrases.

    Reads ``<output_dir>/infer/infer-<eval_dataset_name>.json``, walks
    ``eval_dataset`` sample by sample and region by region, and overwrites the
    ``"references"`` of the i-th JSON entry with the ``"phrases"`` of the i-th
    region encountered. The result is written to
    ``<output_dir>/infer-post_processed/infer-<eval_dataset_name>.json``.

    Args:
        no_sanity_check: The order check runs only when this is exactly ``False``.
        output_dir: Directory containing the ``infer`` sub-directory; the
            ``infer-post_processed`` sub-directory is created inside it.
        eval_dataset_name: Name used to build the ``infer-<name>.json`` file names.
        eval_dataset: Iterable of samples. Each sample is indexed with
            ``"regions"``; each region with ``"region_id"``, ``"image_id"`` and
            ``"phrases"``. It is iterated twice when the sanity check is enabled.

    Raises:
        ValueError: If the input JSON file does not exist.
        AssertionError: If the sanity check finds a region whose region id or
            image id differs from the JSON entry at the same position.
    """
    infer_json_dir = os.path.join(output_dir, "infer")
    json_path = os.path.join(infer_json_dir, f"infer-{eval_dataset_name}.json")
    if not os.path.exists(json_path):
        raise ValueError(f"json_path={json_path} does not exist, which should not happen.")
    infer_replace_gt_json_dir = os.path.join(output_dir, "infer-post_processed")
    os.makedirs(infer_replace_gt_json_dir, exist_ok=True)
    output_json_path = os.path.join(infer_replace_gt_json_dir, f"infer-{eval_dataset_name}.json")

    with open(json_path, "r") as f:
        json_data = json.load(f)
    # NOTE(review): a placeholder is written before the dataset is iterated; presumably
    # so an unwritable output path fails before the long loops below — confirm.
    with open(output_json_path, "w") as f:
        json.dump({}, f, indent=4)

    if no_sanity_check is False:
        # NOTE: Check the sanity. We want the region orders in both eval_dataset and json_data are the same.
        logger.info(f"Check the sanity. We want the region orders in both eval_dataset and json_data are the same.")
        json_data_region_cnt = 0
        for sample in tqdm.tqdm(eval_dataset):
            try:
                for region in sample["regions"]:
                    eval_dataset_region_id = region["region_id"]
                    eval_dataset_image_id = region["image_id"]
                    entry_metadata = json_data[json_data_region_cnt]["metadata"]
                    json_data_region_id = entry_metadata["metadata_region_id"]
                    json_data_image_id = entry_metadata["metadata_image_id"]
                    # Raise explicitly instead of using `assert`, so the check still runs
                    # under `python -O` and the failure says which entry is misaligned.
                    if eval_dataset_region_id != json_data_region_id:
                        raise AssertionError(
                            f"region_id mismatch at prediction index {json_data_region_cnt}: "
                            f"eval_dataset={eval_dataset_region_id!r}, json_data={json_data_region_id!r}"
                        )
                    if eval_dataset_image_id != json_data_image_id:
                        raise AssertionError(
                            f"image_id mismatch at prediction index {json_data_region_cnt}: "
                            f"eval_dataset={eval_dataset_image_id!r}, json_data={json_data_image_id!r}"
                        )
                    json_data_region_cnt += 1
            except IndexError as e:
                # The prediction file has fewer entries than the dataset has regions.
                logger.warning(f"Error: {e}. There are not enough samples in the predction, so we stop here.")
                break

    # NOTE: Now we start to replace the references in json_data with the references in eval_dataset
    json_data_region_cnt = 0
    for sample in tqdm.tqdm(eval_dataset):
        try:
            for region in sample["regions"]:
                json_data[json_data_region_cnt]["references"] = region["phrases"]
                json_data_region_cnt += 1
        except IndexError as e:
            # Ran past the end of json_data: keep what has been replaced so far.
            logger.warning(f"Error: {e}. There are not enough samples in the predction, so we stop here.")
            break

    # The opposite mismatch is otherwise silent: trailing prediction entries that no
    # dataset region reached keep the references they were loaded with.
    if json_data_region_cnt < len(json_data):
        logger.warning(
            "Only %d of %d entries in %s had their references replaced; "
            "the remaining entries keep their original references.",
            json_data_region_cnt,
            len(json_data),
            json_path,
        )

    logger.info(f"Save the new json_data to {output_json_path}")
    with open(output_json_path, "w") as f:
        json.dump(json_data, f, indent=4)


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()