Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torchvision | |
| from torch import nn | |
| from torchvision import transforms | |
| from PIL import Image | |
class CFG:
    """Static configuration constants for the model and inference code.

    All values are evaluated once, when the class body is executed.
    """

    # --- Runtime / hardware -------------------------------------------------
    # Device string used as the default placement for the models below.
    DEVICE = 'cpu'
    # Number of CUDA devices reported by torch, and CPU count reported by os.
    NUM_DEVICES = torch.cuda.device_count()
    NUM_WORKERS = os.cpu_count()

    # --- Model / run settings -----------------------------------------------
    # NUM_CLASSES sets the output width of the classifier head; SEED is passed
    # to torch's manual seeding. The remaining values are not referenced in the
    # visible code — presumably training settings; confirm against the trainer.
    NUM_CLASSES = 4
    EPOCHS = 16
    BATCH_SIZE = 32
    LR = 1e-3
    APPLY_SHUFFLE = True
    SEED = 768

    # --- Input geometry -----------------------------------------------------
    HEIGHT = 224
    WIDTH = 224
    CHANNELS = 3
    # (height, width, channels) tuple assembled from the three values above.
    IMAGE_SIZE = (HEIGHT, WIDTH, CHANNELS)
class VisionTransformerModel(nn.Module):
    """A backbone network followed by a small fully connected classifier head.

    The backbone output is flattened and passed through
    Dropout -> Linear(1000, 256) -> GELU -> Dropout -> Linear(256, num_classes).
    """

    def __init__(self, backbone_model, name='vision-transformer',
                 num_classes=CFG.NUM_CLASSES, device=CFG.DEVICE):
        """Store the backbone and build the classifier head on ``device``.

        Args:
            backbone_model: Module applied to the input first; the head's
                first linear layer takes 1000 input features, so the
                flattened backbone output must have that width.
            name: Label stored on the instance as ``self.name``.
            num_classes: Output width of the final linear layer.
            device: Device the classifier head is moved to. The backbone is
                stored as given and is not moved here.
        """
        super().__init__()

        self.backbone_model = backbone_model
        self.device = device
        self.num_classes = num_classes
        self.name = name

        # The layer order fixes the nn.Sequential indices, and therefore the
        # parameter names in the state dict (classifier.2.*, classifier.5.*).
        head_layers = [
            nn.Flatten(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1000, out_features=256, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=256, out_features=num_classes, bias=False),
        ]
        self.classifier = nn.Sequential(*head_layers)
        # Module.to() moves parameters in place.
        self.classifier.to(device)

    def forward(self, image):
        """Run ``image`` through the backbone, then the classifier head."""
        backbone_features = self.backbone_model(image)
        head_output = self.classifier(backbone_features)
        return head_output
def get_vit_b32_model(
        device: torch.device = CFG.DEVICE) -> nn.Module:
    """Build a frozen torchvision ViT-L/32 model with its default weights.

    NOTE: despite the ``b32`` in the function name, this loads
    ``torchvision.models.vit_l_32`` with ``ViT_L_32_Weights.DEFAULT``. The
    name is kept unchanged so existing callers continue to work.

    Args:
        device: Device the model is moved to (a ``torch.device`` or a device
            string such as ``'cpu'``). Defaults to ``CFG.DEVICE``. The
            previous default was ``CFG.NUM_CLASSES`` — an integer class
            count, which ``.to()`` would have treated as a device index
            rather than the configured device.

    Returns:
        The model on ``device`` with ``requires_grad`` set to ``False`` on
        every parameter.
    """
    # Seed torch's CPU and CUDA generators before the model is constructed.
    torch.manual_seed(CFG.SEED)
    torch.cuda.manual_seed(CFG.SEED)

    # Get model weights
    model_weights = torchvision.models.ViT_L_32_Weights.DEFAULT

    # Get model and push to device
    model = torchvision.models.vit_l_32(weights=model_weights).to(device)

    # Freeze every parameter so the model is not updated by an optimizer.
    for param in model.parameters():
        param.requires_grad = False

    return model
# Build the frozen ViT backbone on the configured device.
vit_backbone = get_vit_b32_model(CFG.DEVICE)

vit_params = {
    'backbone_model': vit_backbone,
    'name': 'ViT-L-B32',
    'device': CFG.DEVICE,
}

# Wrap the backbone with the classifier head and restore the saved weights.
vit_model = VisionTransformerModel(**vit_params)
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here; consider weights_only=True where the installed torch
# version supports it.
vit_model.load_state_dict(
    torch.load('vit_model.pth', map_location=torch.device('cpu'))
)
# Switch to evaluation mode: without this the Dropout layers in the
# classifier head stay active and inference results are randomized.
vit_model.eval()

# Image preprocessing: resize to the configured input size, then convert the
# PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((CFG.HEIGHT, CFG.WIDTH)),
    transforms.ToTensor(),
])
def predict(image_path):
    """Run the loaded model on a single image and return its raw output.

    Args:
        image_path: Path (or any other source accepted by ``PIL.Image.open``)
            of the image to classify.

    Returns:
        The tensor produced by ``vit_model`` for a batch containing only this
        image, computed without gradient tracking.
    """
    # The context manager closes the underlying file once the pixels have
    # been read. Converting to RGB guarantees three channels, so grayscale,
    # RGBA and palette images no longer produce a tensor with the wrong
    # channel count for the model.
    with Image.open(image_path) as image:
        input_tensor = transform(image.convert('RGB'))

    # Add the batch dimension and move the input to the model's device.
    input_batch = input_tensor.unsqueeze(0).to(CFG.DEVICE)

    # Perform inference without building an autograd graph. The output is
    # already on the model's device, so no further transfer is needed.
    with torch.no_grad():
        output = vit_model(input_batch)

    return output