|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Sequence, Union |
|
|
|
|
|
import torch.nn as nn |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
from .patch_transformer import vit_base |
|
|
|
|
|
|
|
|
class PatchEncoder(nn.Module, PyTorchModelHubMixin): |
|
|
"""EXAONE-Path image patch encoder with Hugging Face Hub support. |
|
|
|
|
|
This class wraps the ViT backbone used for patch-level feature extraction and |
|
|
integrates with the Hub via :class:`huggingface_hub.PyTorchModelHubMixin`. |
|
|
|
|
|
The configuration (``image_encoder``, ``patch_size``, ``img_size``, and any |
|
|
extra keyword arguments) is automatically serialized to ``config.json`` when |
|
|
calling ``save_pretrained``. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
image_encoder: str = "vitb", |
|
|
patch_size: int = 14, |
|
|
img_size: Union[int, Sequence[int]] = 224, |
|
|
**kwargs: Any, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
if isinstance(img_size, int): |
|
|
img_size = [img_size, img_size] |
|
|
|
|
|
self.image_encoder = image_encoder |
|
|
self.patch_size = int(patch_size) |
|
|
self.img_size = [int(img_size[0]), int(img_size[1])] |
|
|
self.extra_kwargs = dict(kwargs) |
|
|
|
|
|
if image_encoder == "vitb": |
|
|
model_kwargs = dict(self.extra_kwargs) |
|
|
model_kwargs["img_size"] = self.img_size |
|
|
self.backbone = vit_base(patch_size=self.patch_size, **model_kwargs) |
|
|
else: |
|
|
raise ValueError(f"Unsupported image_encoder for PatchEncoder: {image_encoder}") |
|
|
|
|
|
def forward(self, x): |
|
|
return self.backbone(x) |
|
|
|