import torch
import torchvision
from torch import nn
from torchvision import transforms


def create_vit_model(num_classes: int = 28, seed: int = 42):
    """Creates a ViT-B/16 feature extractor model and its transforms.

    The pretrained backbone is frozen; only the newly attached classifier
    head is trainable.

    Args:
        num_classes (int, optional): number of classes in the classifier head.
            Defaults to 28 (e.g. the Arabic alphabet).
        seed (int, optional): random seed used when initializing the new
            classifier head, for reproducibility. Defaults to 42.

    Returns:
        model (torch.nn.Module): ViT-B/16 feature extractor model.
        transform_pipeline (torchvision.transforms.Compose): image transforms
            matching the pretrained weights, preceded by a grayscale-to-RGB
            conversion.
    """
    # Pretrained ViT-B/16 weights and their matching preprocessing transforms
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vit_transforms = weights.transforms()

    # Prepend grayscale -> 3-channel conversion, since ViT is trained on RGB
    transform_pipeline = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        vit_transforms,
    ])

    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze all layers in the base model so only the new head trains
    for param in model.parameters():
        param.requires_grad = False

    # Seed before creating the head so its random init is reproducible
    # (fix: `seed` was previously accepted but never used)
    torch.manual_seed(seed)

    # Replace the classification head; 768 is the ViT-B/16 embedding dim
    # (fix: out_features was hard-coded to 28, ignoring `num_classes`)
    model.heads = torch.nn.Sequential(
        nn.Linear(in_features=768, out_features=num_classes, bias=True)
    )

    return model, transform_pipeline