Spaces:
Runtime error
Runtime error
| import torch | |
class VitLaionPreProcess(torch.nn.Module):
    """Wrap an image processor so it can be used as a tensor transform.

    The wrapped ``processor`` is called with ``images=img`` and
    ``return_tensors="pt"``. The ``pixel_values`` entry of its output is
    returned with the leading batch axis removed when that axis has size 1.
    """

    def __init__(self, processor):
        """Store *processor*.

        Args:
            processor: callable accepting ``images=`` and ``return_tensors=``
                keyword arguments whose result exposes a ``data`` mapping
                containing ``'pixel_values'``.
        """
        super().__init__()
        self.processor = processor

    def forward(self, img):
        """Return the processed pixel tensor for *img*.

        For a single image the processor presumably yields a batch of one
        (e.g. ``(1, C, H, W)`` -- TODO confirm for the processor in use), and
        that leading axis is dropped. Only axis 0 is squeezed. A bare
        ``squeeze()`` would also collapse any other size-1 axis (a single
        channel, a 1-pixel dimension) and silently change the tensor's rank.
        """
        out = self.processor(images=img, return_tensors="pt")
        # squeeze(0) is a no-op when the batch holds more than one image.
        return out.data['pixel_values'].squeeze(0)
class VitLaionFeatureExtractor(torch.nn.Module):
    """Encode a pair of pixel tensors with one shared vision model.

    ``transforms`` holds a ``VitLaionPreProcess`` built from ``processor``.
    It is not used by ``forward``, which expects already-processed tensors.
    """

    def __init__(self, model, processor):
        """Keep *model* as ``vit_model`` and wrap *processor* as ``transforms``."""
        super().__init__()
        # Assignment order fixes submodule registration (and state_dict) order.
        self.vit_model = model
        self.transforms = VitLaionPreProcess(processor)

    def forward(self, x):
        """Return ``(features_a, features_b)`` for the two inputs in *x*.

        *x* must unpack into exactly two ``pixel_values`` tensors. Each one
        is passed separately, first then second, to
        ``self.vit_model.get_image_features``.
        """
        first, second = x
        return tuple(
            self.vit_model.get_image_features(pixel_values=pixels)
            for pixels in (first, second)
        )