from transformers import PreTrainedModel, PretrainedConfig
from .module import ConditionalViT
class CondViTConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of a Conditional ViT.

    Defaults correspond to a ViT-B/16-sized backbone (224px input,
    16px patches, width 768, 12 layers, 12 heads) with a 512-d
    embedding and 10 conditioning categories.
    """

    # Identifier used by the transformers Auto* machinery.
    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        **kwargs
    ):
        # Record every backbone hyperparameter on the config instance.
        for attr, value in (
            ("input_resolution", input_resolution),
            ("patch_size", patch_size),
            ("width", width),
            ("layers", layers),
            ("heads", heads),
            ("output_dim", output_dim),
            ("n_categories", n_categories),
        ):
            setattr(self, attr, value)
        # Hand the remaining generic options to the base config last,
        # as in the original implementation.
        super().__init__(**kwargs)
class CondViTForEmbedding(PreTrainedModel):
    """transformers-compatible wrapper that embeds images with a
    :class:`ConditionalViT`, optionally conditioned on a category.
    """

    # Config class consulted by from_pretrained / save_pretrained.
    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)
        # Forward the backbone hyperparameters stored on the config
        # to the ConditionalViT constructor, one keyword per field.
        fields = (
            "input_resolution",
            "patch_size",
            "width",
            "layers",
            "heads",
            "output_dim",
            "n_categories",
        )
        self.model = ConditionalViT(**{name: getattr(config, name) for name in fields})

    def forward(self, pixel_values, category_indices=None):
        """Return embeddings for *pixel_values*; *category_indices*
        (may be ``None``) is passed through as the conditioning input.
        """
        return self.model(imgs=pixel_values, c=category_indices)