panda1835's picture
Update models.py
abc71eb
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from PIL import Image
class DinoVisionTransformerClassifier(nn.Module):
    """DINOv2 ``dinov2_vitb14_lc`` model with a replaceable MLP classification head.

    The pretrained model is fetched through ``torch.hub`` and its
    ``linear_head`` is swapped for a three-layer MLP
    (3840 -> 512 -> 256 -> ``num_classes``).  ``forward`` accepts an already
    prepared tensor, a file path, or a Pillow image.

    Attributes set in ``__init__``:
        device: CUDA device when available, otherwise CPU; the model and every
            image loaded by this class are placed on it.
        model: the hub model with the replaced ``linear_head``.
        transform_image: resize / to-tensor / normalize preprocessing pipeline.
        model_name: the fixed string ``"dinov2"``.
    """

    def __init__(self, num_classes):
        """Download the pretrained model and attach a fresh classification head.

        Args:
            num_classes: number of output units of the final linear layer.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Workaround to bypass HTTP Error 403 rate limit exceeded.
        # NOTE: this replaces a private torch.hub function process-wide; the
        # stub accepts any arguments so it keeps working if the private
        # signature changes between torch versions.
        torch.hub._validate_not_a_forked_repo = lambda *args, **kwargs: True
        self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_lc")

        # NOTE(review): 3840 is expected to equal the input width of the stock
        # head of dinov2_vitb14_lc ((1 + 4) * 768) -- confirm against the
        # DINOv2 hub code if the backbone variant is ever changed.
        self.model.linear_head = nn.Sequential(
            nn.Linear(3840, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, num_classes, bias=True),
        )
        self.model.to(self.device)

        self.transform_image = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            # Three-channel statistics: inputs must be RGB before this step.
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        self.model_name = "dinov2"

    def _pil_to_batch(self, img: Image.Image) -> torch.Tensor:
        """Turn a Pillow image into a ``(1, 3, 224, 224)`` tensor on ``self.device``.

        The image is converted to RGB first so that grayscale, palette and
        RGBA inputs all reach the three-channel ``Normalize`` step with exactly
        three channels (no channel slicing is needed afterwards).
        """
        rgb = img.convert('RGB')
        return self.transform_image(rgb).unsqueeze(0).to(self.device)

    def load_image_from_filepath(self, img: str) -> torch.Tensor:
        """
        Load an image from a file path and return a tensor that can be used as an input to the model.
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been read (convert() forces the load inside the block).
        with Image.open(img) as opened:
            return self._pil_to_batch(opened)

    def load_image_from_pillowimage(self, img: Image.Image) -> torch.Tensor:
        """
        Load a Pillow Image and return a tensor that can be used as an input to the model.

        Images in any Pillow mode are accepted; they are converted to RGB,
        matching what ``load_image_from_filepath`` does.
        """
        return self._pil_to_batch(img)

    def forward(self, x):
        """Run the model on ``x``.

        Args:
            x: a file path (``str``), a Pillow image, or anything else, which
                is handed to the underlying model unchanged (assumed to be an
                image batch tensor already on the right device -- it is not
                moved or preprocessed here).

        Returns:
            Whatever ``self.model`` returns for the prepared input
            (presumably logits with ``num_classes`` entries per sample).
        """
        if isinstance(x, str):
            x = self.load_image_from_filepath(x)
        elif isinstance(x, Image.Image):
            x = self.load_image_from_pillowimage(x)
        return self.model(x)