from transformers import ViTModel
from torchvision import transforms
import torch
import torch.nn as nn
import transformers

# Silence transformers' load-time warnings (e.g. about unused classifier weights).
transformers.logging.set_verbosity_error()
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # google/vit-base-patch16-224 was pretrained with 0.5/0.5 normalization
        # (see its ViTImageProcessor config), not the ImageNet statistics.
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def forward(self, images, device):
        # Accept a single PIL image or a list of images.
        if not isinstance(images, list):
            images = [images]
        pixel_values = torch.stack([self.image_transform(image) for image in images]).to(device)
        # The ViT is used as a frozen feature extractor, so skip gradient tracking.
        with torch.no_grad():
            outputs = self.vision_model(pixel_values=pixel_values)
        # last_hidden_state has shape (batch, 197, 768) for ViT-Base:
        # one CLS token plus 196 patch tokens.
        image_features = outputs.last_hidden_state
        return image_features
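
A minimal usage sketch under two assumptions not in the original: the input path `example.jpg` is hypothetical, and the encoder is moved to the same device that `forward` receives (the module itself is never moved internally, so the caller must do it). The output is one feature tensor per input image.

```python
import torch
from PIL import Image

# Hypothetical input path; any RGB image works.
image = Image.open("example.jpg").convert("RGB")

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = VisionEncoder().to(device)  # keep weights and inputs on the same device
encoder.eval()

features = encoder(image, device)        # single image
print(features.shape)                    # torch.Size([1, 197, 768])

batch = encoder([image, image], device)  # a list of images also works
print(batch.shape)                       # torch.Size([2, 197, 768])
```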