Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| import torchvision.transforms as T | |
| from PIL import Image | |
class DinoVisionTransformerClassifier(nn.Module):
    """DINOv2 hub model whose linear head is replaced by a small MLP classifier.

    The backbone is loaded with ``torch.hub.load("facebookresearch/dinov2",
    "dinov2_vitb14_lc")`` and its ``linear_head`` attribute is swapped for a
    3840 -> 512 -> 256 -> ``num_classes`` stack of ``nn.Linear``/``nn.ReLU``
    layers. ``forward`` accepts a file path, a Pillow image, or anything the
    wrapped model itself accepts.
    """

    def __init__(self, num_classes: int):
        """Load the hub model, install the new head and build the preprocessing.

        Args:
            num_classes: Output size of the last linear layer of the head.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Workaround to bypass HTTP Error 403 rate limit exceeded.
        # NOTE(review): this replaces a private torch.hub function for the
        # whole process, not only for this instance.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_lc")

        # NOTE(review): 3840 presumably matches the feature width that
        # dinov2_vitb14_lc feeds into its linear head -- confirm before
        # switching to another backbone variant.
        self.model.linear_head = nn.Sequential(
            nn.Linear(3840, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, num_classes, bias=True),
        )
        self.model.to(self.device)

        # Resize to 224x224, convert to a tensor, normalise with the usual
        # ImageNet per-channel mean/std.
        self.transform_image = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        self.model_name = "dinov2"

    def _to_model_input(self, img) -> torch.Tensor:
        """Apply ``transform_image``, add a batch axis and move to ``self.device``."""
        tensor = self.transform_image(img)
        # Keep at most three channels, then add the leading batch dimension.
        return tensor[:3].unsqueeze(0).to(self.device)

    def load_image_from_filepath(self, img: str) -> torch.Tensor:
        """
        Load an image from a file path and return a tensor that can be used as an input to the model.
        """
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made instead of leaving it to garbage collection.
        with Image.open(img) as opened:
            rgb = opened.convert("RGB")
        return self._to_model_input(rgb)

    def load_image_from_pillowimage(self, img: "Image.Image") -> torch.Tensor:
        """
        Load a Pillow Image and return a tensor that can be used as an input to the model.
        """
        # Convert to RGB exactly like the file-path loader does: a grayscale
        # or palette image would otherwise reach the 3-channel Normalize step
        # with a single channel.
        return self._to_model_input(img.convert("RGB"))

    def forward(self, x):
        """Run the wrapped model on ``x``.

        Args:
            x: A file path (``str``), a Pillow image, or an input that is
                passed to the wrapped model unchanged.

        Returns:
            Whatever ``self.model`` returns for the prepared input.
        """
        if isinstance(x, str):
            x = self.load_image_from_filepath(x)
        elif isinstance(x, Image.Image):
            x = self.load_image_from_pillowimage(x)
        return self.model(x)