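"""Post-process inference JSON by swapping predicted references for ground-truth phrases.

For every eval dataset prepared by `src.train.prepare_datasets`, this script loads
`{output_dir}/infer/infer-{name}.json`, optionally sanity-checks that its region
order matches the eval dataset, replaces each record's "references" with the
ground-truth region "phrases", and writes the result to
`{output_dir}/infer-post_processed/infer-{name}.json`.
"""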
import sys

sys.path.append(".")  # Make the in-repo `src` package importable when run from the repo root.

import json
import logging
import os
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional

import hydra
import numpy as np
import pandas as pd
import torch
import tqdm
import yaml
from datasets import DatasetDict, IterableDatasetDict, load_dataset
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from transformers import Trainer, set_seed
from transformers.trainer_utils import get_last_checkpoint

from src.arguments import Arguments, SAMCaptionerModelArguments, SCAModelArguments, global_setup
from src.data.collator import SamCaptionerDataCollator
from src.data.transforms import SamCaptionerDataTransform
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
from src.train import prepare_datasets

logger = logging.getLogger(__name__)

@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
    # NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py
    logger.info(OmegaConf.to_yaml(args))
    logger.info("Optional command-line arg: +no_sanity_check=True skips the order sanity check (default: False).")
    no_sanity_check = args.get("no_sanity_check", False)

    output_dir = args.training.output_dir
    if output_dir is None:
        raise ValueError("output_dir is None, which should not happen.")

    # Set seed before initializing the model.
    set_seed(args.training.seed)

    # Initialize and prepare our datasets.
    # NOTE(xiaoke): We should only run inference on one eval dataset.
    train_dataset, eval_dataset = prepare_datasets(args)

    for eval_dataset_name, eval_dataset_split in eval_dataset.items():
        replace_one_eval_dataset(no_sanity_check, output_dir, eval_dataset_name, eval_dataset_split)

def replace_one_eval_dataset(no_sanity_check, output_dir, eval_dataset_name, eval_dataset):
    """Replace the "references" of each predicted region with the ground-truth phrases."""
    infer_json_dir = os.path.join(output_dir, "infer")
    json_path = os.path.join(infer_json_dir, f"infer-{eval_dataset_name}.json")
    if not os.path.exists(json_path):
        raise ValueError(f"json_path={json_path} does not exist, which should not happen.")

    infer_replace_gt_json_dir = os.path.join(output_dir, "infer-post_processed")
    os.makedirs(infer_replace_gt_json_dir, exist_ok=True)
    output_json_path = os.path.join(infer_replace_gt_json_dir, f"infer-{eval_dataset_name}.json")

    with open(json_path, "r") as f:
        json_data = json.load(f)

    # Write an empty placeholder first, so an unwritable output path fails before the long loops below.
    with open(output_json_path, "w") as f:
        json.dump({}, f, indent=4)
    if not no_sanity_check:
        # NOTE: Sanity check: the region order in eval_dataset and json_data must be identical,
        # because records are matched by position, not by id.
        logger.info("Sanity check: verifying that eval_dataset and json_data list regions in the same order.")
        json_data_region_cnt = 0
        for sample in tqdm.tqdm(eval_dataset):
            try:
                for region in sample["regions"]:
                    eval_dataset_region_id = region["region_id"]
                    eval_dataset_image_id = region["image_id"]
                    json_data_region_id = json_data[json_data_region_cnt]["metadata"]["metadata_region_id"]
                    json_data_image_id = json_data[json_data_region_cnt]["metadata"]["metadata_image_id"]
                    assert eval_dataset_region_id == json_data_region_id
                    assert eval_dataset_image_id == json_data_image_id
                    json_data_region_cnt += 1
            except IndexError as e:
                logger.warning(f"Error: {e}. There are not enough samples in the prediction, so we stop here.")
                break
    # NOTE: Now replace the references in json_data with the ground-truth references from eval_dataset.
    json_data_region_cnt = 0
    pbar = tqdm.tqdm(eval_dataset)
    for sample in pbar:
        try:
            for region in sample["regions"]:
                gt_phrases = region["phrases"]
                old_phrases = json_data[json_data_region_cnt]["references"]
                json_data[json_data_region_cnt]["references"] = gt_phrases
                # pbar.set_description(f"old: {old_phrases}, new: {gt_phrases}")
                json_data_region_cnt += 1
        except IndexError as e:
            logger.warning(f"Error: {e}. There are not enough samples in the prediction, so we stop here.")
            break

    logger.info(f"Save the new json_data to {output_json_path}")
    with open(output_json_path, "w") as f:
        json.dump(json_data, f, indent=4)
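
# Usage sketch (hydra CLI overrides; the script path below is hypothetical, adjust it to
# where this file actually lives in the repo):
#   python scripts/replace_references_with_gt.py training.output_dir=<run_dir> +no_sanity_check=False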

if __name__ == "__main__":
    main()