Spaces:
Sleeping
Sleeping
| from itertools import zip_longest | |
| from typing import Sequence, Dict, Union | |
| import torch | |
| from lightning_utilities.core.rank_zero import rank_zero_warn | |
| from torch import nn | |
class MultiEntityInteraction(nn.Module):
    """Apply one encoder per entity to a collection of inputs.

    Encoders may be given as a single module, a sequence of modules, or a
    dict of modules. They are stored in ``nn.ModuleList`` / ``nn.ModuleDict``
    so that their parameters are registered with this module (and therefore
    show up in ``parameters()``, ``state_dict()`` and follow ``.to()``).

    ``decoders`` are stored on the instance but are not used by ``forward``.
    """

    def __init__(
        self,
        encoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
        decoders: Union[nn.Module, Sequence[nn.Module], Dict[str, nn.Module]],
    ) -> None:
        """Validate and register the encoders, and store the decoders.

        Args:
            encoders: A single ``nn.Module``, a sequence of modules, or a
                dict mapping names to modules. An existing ``nn.ModuleList``
                or ``nn.ModuleDict`` is used as-is.
            decoders: Stored as ``self.decoders``. A plain sequence or dict
                made up entirely of modules is wrapped in a registering
                container; any other value is stored unchanged.

        Raises:
            ValueError: If ``encoders`` is not a module or a sequence/dict of
                modules, or if one of its elements is not an ``nn.Module``.
        """
        super().__init__()
        if isinstance(encoders, (nn.ModuleList, nn.ModuleDict)):
            # Already a registering container: treat it as the collection of
            # encoders rather than as one (non-callable) encoder.
            pass
        elif isinstance(encoders, nn.Module):
            # A lone encoder is handled as a one-element collection.
            encoders = [encoders]
        elif isinstance(encoders, Sequence):
            # Check all values are encoders
            for i, encoder in enumerate(encoders):
                if not isinstance(encoder, nn.Module):
                    raise ValueError(
                        f"Value {encoder} at index {i} is not an instance of `nn.Module`."
                    )
        elif isinstance(encoders, dict):
            # Check all values are encoders
            for k, encoder in encoders.items():
                if not isinstance(encoder, nn.Module):
                    raise ValueError(
                        f"Value {encoder} at key {k} is not an instance of `nn.Module`."
                    )
        else:
            raise ValueError(
                "Unknown input to MultiEntityInteraction. Expected, `nn.Module`, or `dict`/`sequence` of the"
                f" previous, but got {encoders}"
            )
        self.encoders = self._register(encoders)
        # A single decoder module is registered by plain attribute assignment;
        # only plain lists/dicts of modules need wrapping.
        self.decoders = self._register(decoders)

    @staticmethod
    def _register(modules):
        """Wrap a plain sequence/dict of modules in ``nn.ModuleList``/``nn.ModuleDict``.

        Anything else (a single module, an existing container, ``None``, a
        collection holding non-module values, ...) is returned unchanged.
        """
        if isinstance(modules, dict):
            if all(isinstance(module, nn.Module) for module in modules.values()):
                # NOTE: ``nn.ModuleDict`` expects string keys, matching the
                # ``Dict[str, nn.Module]`` annotation.
                return nn.ModuleDict(modules)
            return modules
        if isinstance(modules, Sequence) and all(
            isinstance(module, nn.Module) for module in modules
        ):
            return nn.ModuleList(modules)
        return modules

    def forward(self, inputs):
        """Run every encoder on its corresponding input.

        Args:
            inputs: An iterable of inputs paired with the encoders by
                position. When the encoders were given as a dict, ``inputs``
                may also be a dict, in which case inputs are matched to
                encoders by key.

        Returns:
            A list with one prediction per encoder, in encoder order. An
            encoder without a corresponding input is called with ``None``.

        Raises:
            ValueError: If there are more inputs than encoders, or a dict of
                inputs contains a key that has no encoder.
        """
        encoders = self.encoders
        if isinstance(encoders, nn.ModuleDict):
            if isinstance(inputs, dict):
                unknown = [key for key in inputs if key not in encoders]
                if unknown:
                    raise ValueError(
                        f"Received inputs for unknown encoder keys: {unknown}"
                    )
                return [encoder(inputs.get(key)) for key, encoder in encoders.items()]
            # Iterating a ModuleDict yields its keys; the modules are needed.
            encoders = encoders.values()

        # A private sentinel distinguishes "ran out of encoders" (an error)
        # from "ran out of inputs" (encoder is called with None).
        missing = object()
        preds = []
        for encoder, x in zip_longest(encoders, inputs, fillvalue=missing):
            if encoder is missing:
                raise ValueError(
                    f"Received more inputs than the {len(self.encoders)} available encoders."
                )
            preds.append(encoder(None if x is missing else x))
        return preds