File size: 789 Bytes
56ef371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveAssign(nn.Module):
    """Language-image contrastive assignment.

    Scores every query feature against every reference feature with a plain
    dot product, i.e. the similarity between the text and the image features.
    """

    def __init__(
        self,
        cal_bias: Optional[nn.Module] = None,
    ) -> None:
        """Initialize the module.

        Args:
            cal_bias (nn.Module, optional): The bias used to calculate the
                similarity. Defaults to None.
        """
        super().__init__()
        # NOTE(review): ``cal_bias`` is stored but never read by ``forward`` in
        # this class; presumably a caller or subclass consumes it. Confirm
        # before assuming it affects the similarity output.
        self.cal_bias = cal_bias

    def forward(self, x: torch.Tensor, ref_dict: Dict) -> torch.Tensor:
        """Compute dot-product similarity between ``x`` and the reference features.

        Args:
            x (torch.Tensor): Query features. Its last dimension must match the
                last dimension of ``ref_dict["encoded_ref_feature"]``.
            ref_dict (dict): Mapping that must contain the key
                ``"encoded_ref_feature"`` holding the reference features
                (at least 2-D, since its last two axes are swapped).

        Returns:
            torch.Tensor: ``x @ ref.transpose(-1, -2)``. For ``x`` of shape
            ``(..., N, D)`` and reference features of shape ``(..., M, D)``
            the result has shape ``(..., N, M)``.

        Raises:
            KeyError: If ``"encoded_ref_feature"`` is missing from ``ref_dict``.
        """
        y = ref_dict["encoded_ref_feature"]
        # Swap the last two axes so the matmul contracts over the shared
        # feature dimension D.
        return x @ y.transpose(-1, -2)