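"""
Decode captions for a directory of images with a RedCaps-pretrained VirTex
model, optionally seeded with a subreddit prompt for subreddit-style captions.

Example invocation (a sketch: the paths are placeholders, and the exact
spelling of the config / worker / GPU flags is assumed to come from
``common_parser``):

    python <path to this script> \
        --config /path/to/pretrain_config.yaml \
        --checkpoint-path /path/to/checkpoint.pth \
        --images /path/to/image_dir \
        --output predictions.json \
        --subreddit-prompt itookapicture

Predictions are written as a JSON list of records of the form
``{"image_id": ..., "subreddit": ..., "caption": ...}``.
"""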
import argparse
import json
import os
from typing import Any, Dict, List

from loguru import logger
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import wordsegment as ws

from virtex.config import Config
from virtex.data import ImageDirectoryDataset
from virtex.factories import TokenizerFactory, PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
ws.load()


# fmt: off
parser = common_parser(
    description="Decode captions using a RedCaps-pretrained VirTex model."
)
parser.add_argument(
    "--images", required=True,
    help="Path to a directory containing image files to generate captions for."
)
parser.add_argument(
    "--checkpoint-path", required=True,
    help="Path to load checkpoint and run captioning evaluation."
)
parser.add_argument(
    "--output", required=True,
    help="Path to save predictions as a JSON file."
)
parser.add_argument(
    "--subreddit-prompt", default=None,
    help="Optional subreddit prompt for controllable subreddit-style captioning."
)
# fmt: on


def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)

    tokenizer = TokenizerFactory.from_config(_C)
    val_dataloader = DataLoader(
        ImageDirectoryDataset(_A.images),
        batch_size=_C.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )

    # Initialize model from a checkpoint.
    model = PretrainingModelFactory.from_config(_C).to(device)
    CheckpointManager(model=model).load(_A.checkpoint_path)
    model.eval()
    # Prepare subreddit prompt for the model if provided.
    if _A.subreddit_prompt is not None:
        # Remove "r/" if provided.
        _A.subreddit_prompt = _A.subreddit_prompt.replace("r/", "")

        # Word segmenting (e.g. "itookapicture" -> "i took a picture").
        _segments = " ".join(ws.segment(ws.clean(_A.subreddit_prompt)))
        subreddit_tokens = (
            [model.sos_index]
            + tokenizer.encode(_segments)
            + [tokenizer.token_to_id("[SEP]")]
        )
    else:
        # Just seed the model with [SOS].
        subreddit_tokens = [model.sos_index]

    # Shift the subreddit prompt to the appropriate device.
    subreddit_tokens = torch.tensor(subreddit_tokens, device=device).long()
    # Make a list of predictions to evaluate.
    predictions: List[Dict[str, Any]] = []

    for val_batch in tqdm(val_dataloader):
        val_batch["image"] = val_batch["image"].to(device)

        # Add the subreddit tokens as decoding prompt to batch.
        val_batch["decode_prompt"] = subreddit_tokens

        with torch.no_grad():
            output_dict = model(val_batch)

        for image_id, caption in zip(
            val_batch["image_id"], output_dict["predictions"]
        ):
            caption = caption.tolist()

            # Replace the [SEP] index with "::" temporarily so it gets decoded.
            if tokenizer.token_to_id("[SEP]") in caption:
                sep_index = caption.index(tokenizer.token_to_id("[SEP]"))
                caption[sep_index] = tokenizer.token_to_id("::")

            caption = tokenizer.decode(caption)

            # Separate out the subreddit from the rest of the caption.
            if "::" in caption:
                subreddit, rest_of_caption = caption.split("::", 1)
                subreddit = "".join(subreddit.split())
                rest_of_caption = rest_of_caption.strip()
            else:
                subreddit, rest_of_caption = "", caption

            predictions.append(
                {"image_id": image_id, "subreddit": subreddit, "caption": rest_of_caption}
            )
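    # Example prediction record (illustrative values, not from a real run):
    #   {"image_id": "0001", "subreddit": "itookapicture",
    #    "caption": "a picture of a sunset over the lake"}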
    logger.info("Displaying first 25 caption predictions:")
    for pred in predictions[:25]:
        logger.info(f"{pred['image_id']} - r/{pred['subreddit']}:: {pred['caption']}")

    # Save predictions as a JSON file. Guard against an empty dirname when the
    # output path has no directory component, and close the file after writing.
    os.makedirs(os.path.dirname(_A.output) or ".", exist_ok=True)
    with open(_A.output, "w") as f:
        json.dump(predictions, f)
    logger.info(f"Saved predictions to {_A.output}")


if __name__ == "__main__":
    _A = parser.parse_args()
    if _A.num_gpus_per_machine > 1:
        raise ValueError("Using multiple GPUs is not supported for this script.")

    # No distributed training here, just a single process.
    main(_A)