# NOTE(review): removed "Spaces: Running on Zero" page-header text — a Hugging
# Face Spaces UI artifact captured during extraction, not part of this module.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn

from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler
class GenHead(nn.Module):
    """Generation head: resamples LLM hidden states into a target embedding space.

    Thin wrapper around :class:`Resampler`; all architectural choices come
    from ``proj_config``.

    Args:
        proj_config: dict with keys ``output_dim``, ``depth``, ``dim_head``,
            ``num_heads``, ``num_tokens``, ``ff_mult``. Must be provided —
            the ``None`` default exists only for signature compatibility and
            would raise if actually used.
        llm_hidden_size: hidden size of the source LLM features (default 4096).
    """

    def __init__(
        self,
        proj_config: dict = None,
        llm_hidden_size: int = 4096,
    ) -> None:
        super().__init__()
        cfg = proj_config
        # The resampler reads llm_hidden_size-dim inputs and emits
        # cfg["num_tokens"] query tokens of cfg["output_dim"] dims each.
        self.projector = Resampler(
            dim=cfg["output_dim"],
            output_dim=cfg["output_dim"],
            embedding_dim=llm_hidden_size,
            depth=cfg["depth"],
            heads=cfg["num_heads"],
            dim_head=cfg["dim_head"],
            num_queries=cfg["num_tokens"],
            ff_mult=cfg["ff_mult"],
        )

    def forward(self, llm_feats: torch.Tensor):
        """Project LLM features through the resampler and return the result."""
        return self.projector(llm_feats)
class TaskTokenGenHead(nn.Module):
    """Task-token generation head: like :class:`GenHead`, but conditions the
    resampling on externally supplied latent task tokens.

    Thin wrapper around :class:`TaskTokenResampler`; all architectural
    choices come from ``proj_config``.

    Args:
        proj_config: dict with keys ``output_dim``, ``depth``, ``dim_head``,
            ``num_heads``, ``num_tokens``, ``ff_mult``. Must be provided —
            the ``None`` default exists only for signature compatibility and
            would raise if actually used.
        llm_hidden_size: hidden size of the source LLM features (default 4096).
    """

    def __init__(
        self,
        proj_config: dict = None,
        llm_hidden_size: int = 4096,
    ) -> None:
        super().__init__()
        cfg = proj_config
        # Same geometry as GenHead, but the resampler variant additionally
        # accepts task latents in its forward pass.
        self.projector = TaskTokenResampler(
            dim=cfg["output_dim"],
            output_dim=cfg["output_dim"],
            embedding_dim=llm_hidden_size,
            depth=cfg["depth"],
            heads=cfg["num_heads"],
            dim_head=cfg["dim_head"],
            num_queries=cfg["num_tokens"],
            ff_mult=cfg["ff_mult"],
        )

    def forward(self, llm_feats: torch.Tensor, latents: torch.Tensor):
        """Project LLM features conditioned on ``latents`` task tokens."""
        return self.projector(llm_feats, latents)