Spaces:
Build error
Build error
| import torch | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
| from transformers import ViTForImageClassification | |
| from transformers import ViTImageProcessor | |
| from typing import List | |
# Default device for models built by this module: use a CUDA GPU when one is
# visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def create_vit(output_shape: int, classes: List, device: torch.device = device,
               *, model_name: str = "google/vit-base-patch16-224"):
    """Create a HuggingFace ViT classifier with a frozen backbone and a fresh head.

    The pretrained checkpoint is loaded and all of its parameters are frozen.
    The classification head is then replaced by a new ``nn.Linear``, which is
    the only trainable part of the returned model. Image transforms are built
    from the checkpoint's own ``ViTImageProcessor``: its resize/crop size and
    its normalization mean/std.

    Args:
        output_shape: Number of logits produced by the new classification head.
            Normally ``len(classes)``; a warning is issued if it differs.
        classes: Class labels. The label at index ``i`` is registered as label
            id ``i`` in the model config's ``id2label`` / ``label2id`` maps.
        device: Device the returned model is moved to. Defaults to the
            module-level ``device`` ("cuda" if available, else "cpu").
        model_name: HuggingFace Hub checkpoint (or local path) used for both
            the model and its image processor.

    Returns:
        A tuple ``(model, train_transforms, val_transforms, test_transforms)``.
        ``test_transforms`` is the same object as ``val_transforms``.
    """
    import warnings

    if output_shape != len(classes):
        # num_labels/id2label below are sized from ``classes``, but the head is
        # sized from ``output_shape``. The model config would then disagree
        # with the logits the head produces.
        warnings.warn(
            f"output_shape ({output_shape}) != len(classes) ({len(classes)}); "
            "the model config's num_labels/id2label will not match the head.",
            stacklevel=2,
        )

    id2label = {idx: label for idx, label in enumerate(classes)}
    label2id = {label: idx for idx, label in id2label.items()}

    model = ViTForImageClassification.from_pretrained(
        model_name,
        num_labels=len(classes),
        id2label=id2label,
        label2id=label2id,
        # Allow a head sized for ``classes`` even when the checkpoint's own
        # head has a different number of outputs.
        ignore_mismatched_sizes=True,
    )

    # Freeze every pretrained weight (feature extraction). This must happen
    # BEFORE the head is replaced, so the new head stays trainable.
    for param in model.parameters():
        param.requires_grad_(False)

    # Size the head from the config instead of hard-coding 768, so checkpoints
    # with a different hidden size also work. Dropout could be added here.
    model.classifier = nn.Linear(in_features=model.config.hidden_size,
                                 out_features=output_shape)

    # Preprocessing follows NielsRogge's ViT fine-tuning tutorial:
    # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb
    processor = ViTImageProcessor.from_pretrained(model_name)
    # Only the height is used, so this assumes a square input resolution.
    size = processor.size["height"]
    normalize = transforms.Normalize(mean=processor.image_mean, std=processor.image_std)

    train_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),  # the only training-time augmentation
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation on the test split uses exactly the validation preprocessing.
    test_transforms = val_transforms

    return model.to(device), train_transforms, val_transforms, test_transforms