| import torch |
|
|
| from typing import Tuple |
| from dataclasses import dataclass |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
| from .csd import CSD |
| from .config import CSDConfig |
|
|
|
|
@dataclass
class CSDOutput:
    """Bundle of the three embedding tensors returned by ``CSDModel.forward``.

    ``CSDModel.forward`` unpacks the ``CSD`` network's three return values in
    order and stores them here by name.
    """

    # First value returned by the ``CSD`` network.
    image_embeds: torch.Tensor
    # Second value returned by the ``CSD`` network.
    style_embeds: torch.Tensor
    # Third value returned by the ``CSD`` network.
    content_embeds: torch.Tensor
|
|
|
|
class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``CSD`` network.

    The ``vit_*`` hyper-parameters are read from a ``CSDConfig`` and passed to
    ``CSD``; ``forward`` repackages the network's three outputs into a
    ``CSDOutput``.
    """

    config_class = CSDConfig
    # Principal input name, looked up by transformers utilities that need the
    # model's primary input; the PreTrainedModel default ("input_ids") is not an
    # argument of this image model's forward().
    main_input_name = "pixel_values"

    def __init__(self, config: CSDConfig) -> None:
        """Build the underlying ``CSD`` network from ``config``.

        Args:
            config: Configuration supplying ``vit_input_resolution``,
                ``vit_patch_size``, ``vit_width``, ``vit_layers``,
                ``vit_heads`` and ``vit_output_dim``.
        """
        super().__init__(config)

        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    # NOTE: inference_mode disables autograd for every call, and the returned
    # tensors are inference tensors that cannot later participate in autograd;
    # this model therefore cannot be fine-tuned through forward().
    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor) -> CSDOutput:
        """Compute image, style and content embeddings for ``pixel_values``.

        Args:
            pixel_values: Image batch passed unchanged to ``CSD``; presumably
                (batch, channels, H, W) at ``config.vit_input_resolution`` —
                TODO confirm against ``CSD``.

        Returns:
            A ``CSDOutput`` holding the three tensors returned by ``CSD``, in
            order, as ``image_embeds``, ``style_embeds`` and
            ``content_embeds``.
        """
        image_embeds, style_embeds, content_embeds = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )
|
|