Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from leo.utils import get_mlp_head | |
class SequentialGroundHead(nn.Module):
    """Score each candidate object against a grounding embedding.

    Every object embedding is concatenated with the grounding embedding along
    the feature axis and passed through an MLP head that emits one logit per
    object. Optionally, masked-out objects receive a logit of ``-inf``.
    """

    def __init__(self, hidden_size: int = 4096, dropout: float = 0.1):
        """Build the grounding MLP head.

        Args:
            hidden_size: Used to size the MLP; its input width is
                ``hidden_size * 2`` because object and grounding embeddings
                are concatenated. Presumably each of those is ``hidden_size``
                wide -- TODO confirm against callers.
            dropout: Dropout rate forwarded to ``get_mlp_head``. Defaults to
                0.1, the value that was previously hard-coded.
        """
        super().__init__()
        # get_mlp_head is project-local (leo.utils); the positional arguments
        # look like (in_dim, hidden_dim, out_dim) -- TODO confirm.
        self.og3d_head = get_mlp_head(
            hidden_size * 2,
            hidden_size // 2,
            1,
            dropout=dropout,
        )

    def forward(
        self,
        obj_embeds: torch.Tensor,
        grd_embdes: torch.Tensor,  # (sic) name kept for keyword callers
        obj_masks: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return one grounding logit per object.

        Args:
            obj_embeds: Object embeddings; must be 3-D, indexed here as
                (batch, num_objects, feat).
            grd_embdes: Grounding embedding, either (batch, feat),
                (batch, 1, feat) or (batch, num_objects, feat). It is
                broadcast across the object axis.
            obj_masks: Optional tensor shaped like the logits
                (batch, num_objects); positions that are falsy (zero/False)
                are filled with ``-inf``.

        Returns:
            Logits of shape (batch, num_objects). Note that a row whose
            objects are all masked is entirely ``-inf``.
        """
        txt_embeds = grd_embdes
        if txt_embeds.dim() == 2:
            # (B, D) -> (B, 1, D) so it can be broadcast over the object axis.
            txt_embeds = txt_embeds.unsqueeze(1)
        num_objs = obj_embeds.shape[1]
        # expand() is a zero-copy view (repeat() materialises a copy) and
        # also accepts a grounding tensor that already has one row per object.
        txt_embeds = txt_embeds.expand(-1, num_objs, -1)
        pair_embeds = torch.cat((obj_embeds, txt_embeds), dim=2)
        og3d_logits = self.og3d_head(pair_embeds).squeeze(2)
        if obj_masks is not None:
            # Out-of-place on purpose: an in-place masked_fill_ on the head's
            # output breaks backward when the head's last op saves its output.
            og3d_logits = og3d_logits.masked_fill(
                obj_masks.logical_not(), float("-inf")
            )
        return og3d_logits