| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from ...configuration_utils import ConfigMixin, register_to_config |
| | from ...models.modeling_utils import ModelMixin |
| | from ...utils import BaseOutput |
| |
|
| |
|
@dataclass
class ReduxImageEncoderOutput(BaseOutput):
    """Return type of `ReduxImageEncoder.forward`.

    Args:
        image_embeds (`torch.Tensor`, *optional*):
            The tensor produced by the encoder's final linear projection.
            Defaults to `None` when no value is supplied.
    """

    image_embeds: Optional[torch.Tensor] = None
| |
|
| |
|
class ReduxImageEncoder(ModelMixin, ConfigMixin):
    """Two-layer projection head: linear up-projection, SiLU, linear down-projection.

    The input is first widened from `redux_dim` to `3 * txt_in_features`
    features, passed through a SiLU activation, and then narrowed to
    `txt_in_features` features.

    Args:
        redux_dim (`int`, defaults to 1152):
            Number of input features expected by the first linear layer.
        txt_in_features (`int`, defaults to 4096):
            Number of output features produced by the second linear layer. The
            intermediate width is three times this value.
    """

    @register_to_config
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
    ) -> None:
        super().__init__()

        # Both layers share the same intermediate width, so compute it once.
        hidden_features = 3 * txt_in_features

        # Layers are created in the same order as before (up, then down) so
        # that parameter registration and default initialisation are unchanged.
        self.redux_up = nn.Linear(redux_dim, hidden_features)
        self.redux_down = nn.Linear(hidden_features, txt_in_features)

    def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
        """Project `x` through the up/SiLU/down stack.

        Args:
            x (`torch.Tensor`):
                Input tensor. Its last dimension must equal `redux_dim`, as
                required by the first `nn.Linear` layer.

        Returns:
            `ReduxImageEncoderOutput`: the projected tensor, stored in
            `image_embeds`.
        """
        widened = self.redux_up(x)
        activated = nn.functional.silu(widened)
        image_embeds = self.redux_down(activated)

        return ReduxImageEncoderOutput(image_embeds=image_embeds)
| |
|