"""Replace the ground-truth references in dumped inference JSON files.

For each evaluation split, read ``infer-<split>.json`` from the run's output
directory, swap every region's ``references`` for the ground-truth phrases
from the dataset, and write the result to
``infer-post_processed/infer-<split>.json``.
"""
import sys

sys.path.append(".")

import json
import logging
import os

import hydra
import tqdm
from omegaconf import DictConfig, OmegaConf
from transformers import set_seed

from src.train import prepare_datasets

logger = logging.getLogger(__name__)


@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
    logger.info(OmegaConf.to_yaml(args))

    logger.info("To skip the sanity check, add the command-line override `+no_sanity_check=True`.")
    no_sanity_check = args.get("no_sanity_check", False)
    output_dir = args.training.output_dir
    if output_dir is None:
        raise ValueError("output_dir is None, which should not happen.")

    set_seed(args.training.seed)

    # Only the evaluation splits are needed here; the train split is discarded.
    _, eval_dataset = prepare_datasets(args)

    for eval_dataset_name, one_eval_dataset in eval_dataset.items():
        replace_one_eval_dataset(no_sanity_check, output_dir, eval_dataset_name, one_eval_dataset)


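# For reference, each entry in the dumped ``infer-*.json`` list is expected to
# look roughly like the sketch below. This layout is inferred from the field
# accesses in ``replace_one_eval_dataset``; the concrete values are
# illustrative, not taken from a real dump.
#
#     {
#         "metadata": {
#             "metadata_image_id": 1,
#             "metadata_region_id": 42,
#         },
#         "references": ["a ground-truth phrase", "another phrase"],
#     }

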
def replace_one_eval_dataset(no_sanity_check, output_dir, eval_dataset_name, eval_dataset):
    """Overwrite the ``references`` in one inference JSON with ground-truth phrases."""
    infer_json_dir = os.path.join(output_dir, "infer")
    json_path = os.path.join(infer_json_dir, f"infer-{eval_dataset_name}.json")
    if not os.path.exists(json_path):
        raise ValueError(f"json_path={json_path} does not exist, which should not happen.")
    infer_replace_gt_json_dir = os.path.join(output_dir, "infer-post_processed")
    os.makedirs(infer_replace_gt_json_dir, exist_ok=True)
    output_json_path = os.path.join(infer_replace_gt_json_dir, f"infer-{eval_dataset_name}.json")

    with open(json_path, "r") as f:
        json_data = json.load(f)
    # Create the output file up front; it is overwritten with the full data at the end.
    with open(output_json_path, "w") as f:
        json.dump({}, f, indent=4)

    if not no_sanity_check:
        logger.info("Sanity check: the region order in eval_dataset and json_data must match.")
        json_data_region_cnt = 0
        for sample in tqdm.tqdm(eval_dataset):
            try:
                for region in sample["regions"]:
                    eval_dataset_region_id = region["region_id"]
                    eval_dataset_image_id = region["image_id"]
                    json_data_region_id = json_data[json_data_region_cnt]["metadata"]["metadata_region_id"]
                    json_data_image_id = json_data[json_data_region_cnt]["metadata"]["metadata_image_id"]
                    assert eval_dataset_region_id == json_data_region_id
                    assert eval_dataset_image_id == json_data_image_id
                    json_data_region_cnt += 1
            except IndexError as e:
                logger.warning(f"Error: {e}. There are not enough samples in the prediction, so we stop here.")
                break

    # Walk the regions in dataset order and overwrite each prediction's
    # references with the ground-truth phrases.
    json_data_region_cnt = 0
    for sample in tqdm.tqdm(eval_dataset):
        try:
            for region in sample["regions"]:
                json_data[json_data_region_cnt]["references"] = region["phrases"]
                json_data_region_cnt += 1
        except IndexError as e:
            logger.warning(f"Error: {e}. There are not enough samples in the prediction, so we stop here.")
            break

    logger.info(f"Save the new json_data to {output_json_path}")
    with open(output_json_path, "w") as f:
        json.dump(json_data, f, indent=4)


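# Example invocation (a sketch; the script path is an assumption, not taken
# from the repo, and the output_dir value is illustrative). The Hydra
# overrides mirror the config keys read in ``main``:
#
#     python scripts/tools/replace_gt_references.py \
#         training.output_dir=path/to/run_output_dir \
#         +no_sanity_check=True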
if __name__ == "__main__":
    main()