Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Base Information Sharing Class for UniCeption | |
| """ | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| import torch.nn as nn | |
| from jaxtyping import Float | |
| from torch import Tensor | |
| from torch.utils.checkpoint import checkpoint | |
@dataclass
class InfoSharingInput:
    """
    Base container for inputs to information sharing models.

    Declares no fields of its own; subclasses add the fields their model needs.
    """
@dataclass
class InfoSharingOutput:
    """
    Base container for outputs of information sharing models.

    Declares no fields of its own; subclasses add the fields their model returns.
    """
class UniCeptionInfoSharingBase(nn.Module):
    """
    Information Sharing Base Class for UniCeption.

    Stores a model name and an optional size label, declares the ``forward``
    interface that concrete models must implement, and offers a helper that
    enables gradient checkpointing on an arbitrary submodule.
    """

    def __init__(
        self,
        name: str,
        size: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the base information sharing module.

        Args:
            name (str): Name of the model, stored as ``self.name``.
            size (Optional[str]): Optional size label, stored as ``self.size``. Defaults to None.
            *args: Extra positional arguments forwarded to ``nn.Module.__init__``.
            **kwargs: Extra keyword arguments forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.name: str = name
        self.size: Optional[str] = size

    def forward(
        self,
        model_input: "InfoSharingInput",
    ) -> "InfoSharingOutput":
        """
        Forward interface for the UniCeption information sharing models.

        Args:
            model_input (InfoSharingInput): Input to the model.
                This also includes the other fields that are required by the
                specific implementation of the model.

        Returns:
            InfoSharingOutput: Output of the model.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
        """
        Wrapper for Gradient Checkpointing.

        The module is modified in place: its class is swapped for a dynamically
        created subclass whose ``forward`` runs the original ``forward`` through
        ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False``.
        The original class is kept on the wrapper class as ``_restore_cls``.

        Wrapping is idempotent: a module that already carries the wrapper is
        returned untouched, so ``_restore_cls`` keeps pointing at the original
        class and checkpoint calls are never nested.

        Args:
            module (nn.Module): Module whose forward pass should be checkpointed.

        Returns:
            nn.Module: The same ``module`` object, now using the wrapper class.
        """
        # The marker is defined directly on the wrapper class, so look only at
        # the class's own namespace rather than at inherited attributes.
        if "_restore_cls" in vars(type(module)):
            return module

        class _CheckpointingWrapper(module.__class__):
            _restore_cls = module.__class__

            def forward(self, *args, **kwargs):
                return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

        module.__class__ = _CheckpointingWrapper
        return module
@dataclass
class MultiViewTransformerInput(InfoSharingInput):
    """
    Input class for Multi-View Transformer.

    Attributes:
        features: List of feature tensors, each annotated as
            (batch, input_embed_dim, feat_height, feat_width).
        additional_input_tokens: Optional extra tokens annotated as
            (batch, input_embed_dim, num_additional_tokens). Defaults to None.
    """

    features: List[Float[Tensor, "batch input_embed_dim feat_height feat_width"]]
    additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiViewTransformerOutput(InfoSharingOutput):
    """
    Output class for Multi-View Transformer.

    Attributes:
        features: List of feature tensors, each annotated as
            (batch, transformer_embed_dim, feat_height, feat_width).
        additional_token_features: Optional extra token features annotated as
            (batch, transformer_embed_dim, num_additional_tokens). Defaults to None.
    """

    features: List[Float[Tensor, "batch transformer_embed_dim feat_height feat_width"]]
    additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiSetTransformerInput(InfoSharingInput):
    """
    Input class for Multi-Set Transformer.

    Attributes:
        features: List of token tensors, each annotated as
            (batch, input_embed_dim, num_tokens).
        additional_input_tokens: Optional extra tokens annotated as
            (batch, input_embed_dim, num_additional_tokens). Defaults to None.
    """

    features: List[Float[Tensor, "batch input_embed_dim num_tokens"]]
    additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiSetTransformerOutput(InfoSharingOutput):
    """
    Output class for Multi-Set Transformer.

    Attributes:
        features: List of token tensors, each annotated as
            (batch, transformer_embed_dim, num_tokens).
        additional_token_features: Optional extra token features annotated as
            (batch, transformer_embed_dim, num_additional_tokens). Defaults to None.
    """

    features: List[Float[Tensor, "batch transformer_embed_dim num_tokens"]]
    additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
if __name__ == "__main__":
    # Smoke check when run as a script: the base class can be constructed
    # with just a name.
    dummy_model = UniCeptionInfoSharingBase(
        name="dummy",
    )
    print("Dummy Base InfoSharing model created successfully!")