Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,113 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""
Base Information Sharing Class for UniCeption
"""
from dataclasses import dataclass
from typing import List, Optional
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torch.utils.checkpoint import checkpoint
@dataclass
class InfoSharingInput:
    """Base input container for info-sharing models; subclasses declare the actual fields."""

    pass
@dataclass
class InfoSharingOutput:
    """Base output container for info-sharing models; subclasses declare the actual fields."""

    pass
class UniCeptionInfoSharingBase(nn.Module):
    """Common base class for UniCeption information-sharing models."""

    def __init__(
        self,
        name: str,
        size: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the base information-sharing model.

        Args:
            name (str): Identifier for this model instance.
            size (Optional[str]): Optional size/variant tag for the model.
        """
        super().__init__(*args, **kwargs)
        # Record identifying metadata on the instance.
        self.name: str = name
        self.size: Optional[str] = size

    def forward(self, model_input: InfoSharingInput) -> InfoSharingOutput:
        """
        Forward interface for the UniCeption information sharing models.

        Subclasses must override this; the base implementation is abstract.

        Args:
            model_input (InfoSharingInput): Input to the model. Concrete
                subclasses extend this with whatever fields they require.

        Returns:
            InfoSharingOutput: Output of the model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
        """
        Patch ``module`` in place so its forward pass runs under activation
        (gradient) checkpointing, trading compute for memory.

        Args:
            module (nn.Module): Module to wrap; it is modified in place.

        Returns:
            nn.Module: The same ``module`` instance, now checkpointed.
        """
        base_cls = module.__class__

        class _CheckpointingWrapper(base_cls):
            # Keep a handle to the original class so the wrap can be undone.
            _restore_cls = base_cls

            def forward(self, *args, **kwargs):
                # use_reentrant=False is the recommended non-reentrant mode;
                # activations are recomputed during backward instead of stored.
                return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

        # Swap the instance's class rather than wrapping it, so state_dict
        # keys and attribute access are unchanged.
        module.__class__ = _CheckpointingWrapper
        return module
@dataclass
class MultiViewTransformerInput(InfoSharingInput):
    """
    Input class for Multi-View Transformer.
    """

    # One spatial feature map per view, each (batch, input_embed_dim, feat_height, feat_width).
    features: List[Float[Tensor, "batch input_embed_dim feat_height feat_width"]]
    # Optional extra tokens of shape (batch, input_embed_dim, num_additional_tokens);
    # presumably appended to the transformer's token sequence — TODO confirm against the model.
    additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiViewTransformerOutput(InfoSharingOutput):
    """
    Output class for Multi-View Transformer.
    """

    # One spatial feature map per view, each (batch, transformer_embed_dim, feat_height, feat_width).
    features: List[Float[Tensor, "batch transformer_embed_dim feat_height feat_width"]]
    # Features for the additional tokens, (batch, transformer_embed_dim, num_additional_tokens),
    # when additional input tokens were provided.
    additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiSetTransformerInput(InfoSharingInput):
    """
    Input class for Multi-Set Transformer.
    """

    # One token set per input set, each (batch, input_embed_dim, num_tokens) —
    # flat token sequences rather than spatial maps.
    features: List[Float[Tensor, "batch input_embed_dim num_tokens"]]
    # Optional extra tokens of shape (batch, input_embed_dim, num_additional_tokens);
    # presumably appended to the transformer's token sequence — TODO confirm against the model.
    additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
@dataclass
class MultiSetTransformerOutput(InfoSharingOutput):
    """
    Output class for Multi-Set Transformer.
    """

    # One token set per input set, each (batch, transformer_embed_dim, num_tokens).
    features: List[Float[Tensor, "batch transformer_embed_dim num_tokens"]]
    # Features for the additional tokens, (batch, transformer_embed_dim, num_additional_tokens),
    # when additional input tokens were provided.
    additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
if __name__ == "__main__":
    # Smoke test: the abstract base should construct without error.
    base_model = UniCeptionInfoSharingBase(name="dummy")
    print("Dummy Base InfoSharing model created successfully!")
|