File size: 464 Bytes
94a0812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# models/base_encoder.py

import torch
import torch.nn as nn
from abc import ABC, abstractmethod


class BaseVisionEncoder(nn.Module, ABC):
    """Abstract interface shared by all vision encoders.

    Concrete subclasses must override both :meth:`forward` and
    :meth:`get_output_dim`. Because both are marked ``@abstractmethod``,
    the class (or any subclass missing either override) cannot be
    instantiated.

    Args:
        embed_dim: Embedding size passed in by the subclass; stored as-is
            on ``self.embed_dim``.
    """

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor):
        """Encode ``pixel_values``; must be overridden by subclasses.

        NOTE(review): the expected tensor layout is not fixed by this base
        class — presumably (batch, channels, height, width); confirm against
        the concrete encoders.
        """

    @abstractmethod
    def get_output_dim(self) -> int:
        """Return the dimensionality of the encoder output embedding."""