coco-demo / models /encoder_projection_t5.py
evanec's picture
Upload 5 files
94a0812 verified
raw
history blame contribute delete
715 Bytes
# models/image_projection.py
import torch
import torch.nn as nn
class ImageProjection(nn.Module):
    """Linear bridge from an image encoder's embedding space to T5's hidden space.

    A frozen (or fine-tuned) vision encoder such as CLIP ViT-L/14 emits
    1024-dimensional features, while T5-small consumes 512-dimensional
    hidden states; this module is the single learned linear map between
    the two (e.g. 1024 -> 512).

    Args:
        encoder_dim: Dimensionality of the incoming image embeddings.
        t5_hidden_size: Hidden size expected by the T5 model.

    Shape:
        - Input:  ``(B, encoder_dim)`` or ``(B, S, encoder_dim)``
        - Output: ``(B, t5_hidden_size)`` or ``(B, S, t5_hidden_size)``
          (``nn.Linear`` acts on the last dimension, so any number of
          leading batch dimensions is supported)
    """

    def __init__(self, encoder_dim: int, t5_hidden_size: int):
        super().__init__()
        # Single affine projection; kept as `self.proj` so existing
        # checkpoints (state_dict key "proj.*") load unchanged.
        self.proj = nn.Linear(in_features=encoder_dim, out_features=t5_hidden_size)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Project encoder embeddings into the T5 hidden dimension."""
        projected_embeds = self.proj(image_embeds)
        return projected_embeds