# FoodVision_Big / model.py
# Committed by mulasagg in cd8171c ("add all files"); original file size 944 bytes.
import torch
import torchvision
from torch import nn
from timm import create_model #type: ignore
from torchvision import transforms
def create_swin_transformer(
    *, num_classes: int = 101, pretrained: bool = True
) -> "tuple[nn.Module, transforms.Compose]":
    """Build a Swin-B classifier and its training-time image transform.

    The model is timm's ``swin_base_patch4_window7_224`` with a classification
    head sized to ``num_classes`` outputs (101 by default -- presumably the
    Food-101 classes for this project; confirm against the dataset used).

    Args:
        num_classes: Number of output logits of the classification head.
            Defaults to 101, the value this function previously hard-coded.
        pretrained: Whether timm should load the pretrained checkpoint weights
            (timm may download them on first use). Pass ``False`` to build a
            randomly initialised model, e.g. when working offline.

    Returns:
        A ``(model, train_transforms)`` tuple. ``model`` is the module returned
        by ``timm.create_model``. ``train_transforms`` is a
        ``transforms.Compose`` that randomly augments an input image, converts
        it to a tensor with ``ToTensor`` and normalises it with the standard
        ImageNet channel mean/std.
    """
    # Load the Swin Transformer backbone; timm replaces the head so that it
    # emits `num_classes` logits.
    model = create_model(
        "swin_base_patch4_window7_224",
        pretrained=pretrained,
        num_classes=num_classes,
    )

    train_transforms = transforms.Compose([
        # Random crop (random scale/aspect ratio) resized to 224x224, the input
        # resolution this checkpoint is named for.
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        # NOTE(review): RandomRotation(15) and RandomAffine(degrees=15) each
        # sample an independent rotation in [-15, 15] degrees, so the combined
        # rotation can reach about +/-30 degrees. Kept as-is to preserve the
        # original augmentation; consider dropping one of the two.
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=15, scale=(0.8, 1.2)),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics, which the pretrained weights
        # presumably expect -- confirm against the checkpoint's timm config.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return model, train_transforms