# Zero-shot classification evaluation script for a pretrained VirTex model.
| import argparse | |
| import json | |
| import os | |
| import random | |
| from typing import Any, Dict, List | |
| from loguru import logger | |
| import torch | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from torch.nn.utils.rnn import pad_sequence | |
| from tqdm import tqdm | |
| import wordsegment as ws | |
| from virtex.config import Config | |
| from virtex.data import ZeroShotDataset | |
| from virtex.data.tokenizers import SentencePieceBPETokenizer | |
| from virtex.factories import TokenizerFactory, VisualBackboneFactory,TextualHeadFactory | |
| from virtex.utils.checkpointing import CheckpointManager | |
| from virtex.utils.common import common_parser | |
| from virtex.utils.metrics import TopkAccuracy | |
| import virtex.utils.distributed as dist | |
| #importing classifier | |
| from virtex.models.zero_shot_classification_eval import ZeroShotClassifier | |
# Load the word-segmentation corpus once at import time (used by dataset code).
ws.load()

# fmt: off
parser = common_parser(
    description="""Run image captioning inference on a pretrained model, and/or
    evaluate pretrained model on COCO Captions val2017 split."""
)

# Extra CLI flags for zero-shot evaluation, registered data-driven to keep the
# flag/help pairs in one place. Order and keyword arguments are identical to
# registering each flag with its own parser.add_argument(...) call.
_EXTRA_ARGUMENTS = [
    ("--data-root", dict(
        default=None,
        help="""Path to a directory containing image files to generate captions for imagenet.
    Default: COCO val2017 image directory as expected relative to project root.""",
    )),
    ("--checkpoint-path", dict(
        required=False,
        help="Path to load checkpoint and run captioning evaluation.",
    )),
    ("--output", dict(
        default=None,
        help="Path to save predictions as a JSON file.",
    )),
    ("--calc-metrics", dict(
        action="store_true",
        help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
    This flag should not be set when running inference on arbitrary images.""",
    )),
    ("--idx_label_dict", dict(
        default=None, required=False,
        help="""a dictionary that maps from lable index to label string for classification""",
    )),
    ("--is_redcaps", dict(
        default=None, required=False,
        help="""a dictionary that maps from lable index to label string for""",
    )),
    ("--prompt_cls_sos", dict(
        default=None, required=False,
        help="""a dictionary that maps from lable index to label string for""",
    )),
    ("--prompt_sos_eos", dict(
        default=None, required=False,
        help="""a dictionary that maps from lable index to label string for""",
    )),
]
for _flag, _kwargs in _EXTRA_ARGUMENTS:
    parser.add_argument(_flag, **_kwargs)
# fmt: on

# Print the working directory so relative dataset paths are easy to debug.
print("###########")
print(os.getcwd())
print("###########")

# Module-level tokenizer shared by main(); NOTE(review): the vocab path is
# hard-coded relative to the working directory — confirm it exists at runtime.
tokenizer = SentencePieceBPETokenizer("datasets_1/vocab/common_32k.model")
def main(_A: argparse.Namespace):
    """Evaluate a pretrained model on zero-shot image classification.

    Builds the zero-shot dataset and dataloader from CLI arguments, restores
    model weights from ``--checkpoint-path``, and logs running top-1 / top-5
    accuracy over the evaluation split.

    Args:
        _A: Parsed command-line arguments from the module-level ``parser``.
    """
    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)

    if _A.data_root is None:
        _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")

    # BUG FIX: argparse delivers ``--is_redcaps`` as a string (no type= given),
    # so the original ``_A.is_redcaps == 1`` was always False. Compare via
    # str() so both ``--is_redcaps 1`` and a programmatic int work.
    if str(_A.is_redcaps) == "1":
        model_dataset = 'redcaps'
    else:
        model_dataset = 'gcc or sbu'

    print(_A.idx_label_dict)

    # BUG FIX: the prompt flags default to None, so calling ``.replace`` on
    # them unconditionally raised AttributeError when they were omitted.
    # Underscores are used on the CLI as stand-ins for spaces.
    prompt_cls_sos = (
        _A.prompt_cls_sos.replace("_", " ") if _A.prompt_cls_sos else _A.prompt_cls_sos
    )
    prompt_sos_eos = (
        _A.prompt_sos_eos.replace("_", " ") if _A.prompt_sos_eos else _A.prompt_sos_eos
    )

    val_dataset = ZeroShotDataset(
        data_root=_A.data_root,
        split="test/",
        label_map=_A.idx_label_dict,
        tokenizer=tokenizer,
        prompt_cls_sos=prompt_cls_sos,
        prompt_sos_eos=prompt_sos_eos,
    )
    val_dataloader = DataLoader(
        val_dataset,
        # Per-process batch size when launched with multiple processes.
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )

    # Initialize model from a checkpoint.
    visual = VisualBackboneFactory.from_config(_C)
    textual = TextualHeadFactory.from_config(_C)
    model = ZeroShotClassifier(visual, textual)
    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    model.to(device).eval()

    # Wrap in DDP for multi-process evaluation.
    # BUG FIX: the original referenced ``nn``, which was never imported
    # (NameError in the multi-GPU branch); go through the imported ``torch``.
    if dist.get_world_size() > 1:
        dist.synchronize()
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    top_1 = TopkAccuracy(top_k=1)
    top_5 = TopkAccuracy(top_k=5)

    for val_iteration, val_batch in tqdm(enumerate(val_dataloader, start=1)):
        # Move every tensor in the batch to the evaluation device.
        for key in ("image", "caption_tokens", "noitpac_tokens",
                    "caption_lengths", "label"):
            val_batch[key] = val_batch[key].to(device)

        with torch.no_grad():
            # NOTE(review): despite the name, these appear to be per-class
            # scores that the accuracy metrics rank over — confirm against
            # ZeroShotClassifier.forward.
            classification_losses = model(val_batch)

        # Running (non-reset) accuracies, averaged across processes.
        top_1(classification_losses, val_batch["label"])
        top_1_acc = top_1.get_metric(reset=False)
        dist.average_across_processes(top_1_acc)

        top_5(classification_losses, val_batch["label"])
        top_5_acc = top_5.get_metric(reset=False)
        dist.average_across_processes(top_5_acc)

        logger.info(
            f"Iter: {val_iteration} | Top-1 accuracy: {top_1_acc} | Top-5 accuracy: {top_5_acc}"
        )
if __name__ == "__main__":
    # Single-process entry point: parse CLI flags and run evaluation directly
    # (no distributed launch wrapper here).
    main(parser.parse_args())