Spaces:
Runtime error
Runtime error
| from .image_classification import CIFAR10DataModule, ImageDataModule, MNISTDataModule | |
| from .transformations import UnNest | |
| from .visual_qa import CIFAR10QADataModule, ToyQADataModule | |
| from argparse import Namespace | |
| from transformers import ConvNextFeatureExtractor, ViTFeatureExtractor | |
def get_configs(args: Namespace) -> tuple[dict, dict]:
    """Get the model and feature extractor configs from the command line args.

    Args:
        args (Namespace): the argparse Namespace object. ``args.dataset`` is
            always read; ``args.grid_size`` is read only when the dataset is
            "CIFAR10_QA" or "toy".

    Returns:
        a tuple containing the model and feature extractor configs

    Raises:
        ValueError: if ``args.dataset`` is not "MNIST", "CIFAR10",
            "CIFAR10_QA" or "toy".
    """
    if args.dataset == "MNIST":
        # MNIST images are configured as 112x112 with 1 channel (grayscale)
        # and 10 classes (0-9). Normalization uses a mean of 0.5 and a
        # standard deviation of 0.5.
        model_cfg_args = {
            "image_size": 112,
            "num_channels": 1,
            "num_labels": 10,
        }
        fe_cfg_args = {
            "image_mean": [0.5],
            "image_std": [0.5],
        }
    elif args.dataset.startswith("CIFAR10"):
        # Reject look-alike names (e.g. "CIFAR100") that share the prefix.
        if args.dataset not in ("CIFAR10", "CIFAR10_QA"):
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")

        # CIFAR10 images are configured as 224x224 with 3 channels (RGB).
        # The plain dataset has 10 classes (0-9); the QA variant has
        # (grid_size)^2 + 1 classes. Normalization uses a mean of 0.5 and a
        # standard deviation of 0.5 per channel.
        if args.dataset == "CIFAR10_QA":
            num_labels = (args.grid_size**2) + 1
        else:
            num_labels = 10
        model_cfg_args = {
            "image_size": 224,  # fixed to 224 because pretrained models have that size
            "num_channels": 3,
            "num_labels": num_labels,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    elif args.dataset == "toy":
        # The image size is chosen so that each patch contains a single
        # color, with 3 channels (RGB) and (grid_size)^2 + 1 classes.
        # Normalization uses a mean of 0.5 and a standard deviation of 0.5
        # per channel.
        model_cfg_args = {
            "image_size": args.grid_size * 16,
            "num_channels": 3,
            "num_labels": (args.grid_size**2) + 1,
        }
        fe_cfg_args = {
            "image_mean": [0.5, 0.5, 0.5],
            "image_std": [0.5, 0.5, 0.5],
        }
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Set the feature extractor's size attribute to be the same as the model's image size
    fe_cfg_args["size"] = model_cfg_args["image_size"]
    # Set the tensors' return type to PyTorch tensors
    fe_cfg_args["return_tensors"] = "pt"

    return model_cfg_args, fe_cfg_args
def datamodule_factory(args: Namespace) -> ImageDataModule:
    """A factory method for creating a datamodule based on the command line args.

    Args:
        args (Namespace): the argparse Namespace object. The attributes read
            here are ``base_model``, ``from_pretrained``, ``dataset``,
            ``batch_size``, ``add_noise``, ``add_rotation``, ``add_blur`` and
            ``num_workers``; ``class_idx`` and ``grid_size`` are read only
            for the "CIFAR10_QA" and "toy" datasets.

    Returns:
        an ImageDataModule instance

    Raises:
        ValueError: if ``args.base_model`` or ``args.dataset`` is not one of
            the supported values.
    """
    # Get the model and feature extractor configs (the model config is not
    # needed to build the datamodule, only the feature extractor config is)
    _, fe_cfg_args = get_configs(args)

    # Set the feature extractor class based on the provided base model name
    if args.base_model == "ViT":
        fe_class = ViTFeatureExtractor
    elif args.base_model == "ConvNeXt":
        fe_class = ConvNextFeatureExtractor
    else:
        raise ValueError(f"Unknown base model: {args.base_model}")

    # Create the feature extractor instance; a truthy ``from_pretrained``
    # value is forwarded to the class's ``from_pretrained`` constructor,
    # otherwise the extractor is built directly from the config.
    if args.from_pretrained:
        feature_extractor = fe_class.from_pretrained(
            args.from_pretrained, **fe_cfg_args
        )
    else:
        feature_extractor = fe_class(**fe_cfg_args)

    # Un-nest the feature extractor's output
    feature_extractor = UnNest(feature_extractor)

    # Define the datamodule's configuration
    dm_cfg = {
        "feature_extractor": feature_extractor,
        "batch_size": args.batch_size,
        "add_noise": args.add_noise,
        "add_rotation": args.add_rotation,
        "add_blur": args.add_blur,
        "num_workers": args.num_workers,
    }

    # Determine the dataset class based on the provided dataset name
    if args.dataset.startswith("CIFAR10"):
        if args.dataset == "CIFAR10":
            dm_class = CIFAR10DataModule
        elif args.dataset == "CIFAR10_QA":
            # The QA datamodules take two extra configuration entries
            dm_cfg["class_idx"] = args.class_idx
            dm_cfg["grid_size"] = args.grid_size
            dm_class = CIFAR10QADataModule
        else:
            raise ValueError(f"Unknown CIFAR10 variant: {args.dataset}")
    elif args.dataset == "MNIST":
        dm_class = MNISTDataModule
    elif args.dataset == "toy":
        dm_cfg["class_idx"] = args.class_idx
        dm_cfg["grid_size"] = args.grid_size
        dm_class = ToyQADataModule
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    return dm_class(**dm_cfg)