import torch
import torch.nn as nn


class ImageProjection(nn.Module):
    """Project encoder image embeddings into the T5 hidden size.

    Wraps a single ``nn.Linear`` (default settings, so a bias term is
    included). The layer acts on the last dimension of its input only; any
    leading dimensions pass through unchanged.

    Example:
        - CLIP ViT-L/14 gives 1024-d embeddings
        - T5-small expects 512-d hidden states
        -> This linear layer maps 1024 -> 512

    Args:
        encoder_dim: Size of the last dimension of the incoming embeddings
            (``D_enc``).
        t5_hidden_size: Size of the last dimension of the projected output
            (``D_t5``).

    Forward:
        image_embeds: (B, D_enc) or (B, S, D_enc)
        returns:
            projected_embeds: (B, D_t5) or (B, S, D_t5)
    """

    def __init__(self, encoder_dim: int, t5_hidden_size: int) -> None:
        super().__init__()
        # Keep the attribute name ``proj``: it determines the state-dict keys
        # ("proj.weight", "proj.bias") that saved checkpoints rely on.
        self.proj = nn.Linear(encoder_dim, t5_hidden_size)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Return ``image_embeds`` projected from ``D_enc`` to ``D_t5``."""
        return self.proj(image_embeds)