File size: 715 Bytes
94a0812 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# models/image_projection.py
import torch
import torch.nn as nn
class ImageProjection(nn.Module):
    """Linear projection from an image encoder's embedding space to T5's.

    Example:
        - CLIP ViT-L/14 produces 1024-d embeddings
        - T5-small expects 512-d hidden states
        -> this single linear layer maps 1024 -> 512

    Forward:
        image_embeds: (B, D_enc) or (B, S, D_enc)
    Returns:
        projected_embeds: (B, D_t5) or (B, S, D_t5)
    """

    def __init__(self, encoder_dim: int, t5_hidden_size: int):
        super().__init__()
        # nn.Linear acts on the last dimension only, so both 2-D (B, D_enc)
        # and 3-D (B, S, D_enc) inputs work without any reshaping.
        self.proj = nn.Linear(encoder_dim, t5_hidden_size)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Project image embeddings (..., encoder_dim) -> (..., t5_hidden_size)."""
        return self.proj(image_embeds)
|