File size: 715 Bytes
94a0812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# models/image_projection.py

import torch
import torch.nn as nn


class ImageProjection(nn.Module):
    """
    Projects encoder image embeddings into the T5 hidden size.

    Example:
        - CLIP ViT-L/14 gives 1024-d embeddings
        - T5-small expects 512-d hidden states
        β†’ This linear layer maps 1024 β†’ 512

    Args:
        encoder_dim: dimensionality of the incoming image embeddings (D_enc).
        t5_hidden_size: target T5 hidden dimensionality (D_t5).
        bias: whether the linear projection has a bias term. Defaults to
            True (previous hard-coded behavior); set False when the output
            is immediately normalized (e.g. LayerNorm), where a bias is
            redundant.

    Forward:
        image_embeds: (B, D_enc) or (B, S, D_enc)
        returns:
            projected_embeds: (B, D_t5) or (B, S, D_t5)
    """

    def __init__(self, encoder_dim: int, t5_hidden_size: int, bias: bool = True):
        super().__init__()
        # Kept under the name `proj` so existing checkpoints/state_dicts
        # continue to load unchanged.
        self.proj = nn.Linear(encoder_dim, t5_hidden_size, bias=bias)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # nn.Linear applies over the last dim, so both (B, D_enc) and
        # (B, S, D_enc) inputs work without reshaping.
        return self.proj(image_embeds)