from __future__ import annotations

from typing import Any, Sequence, Union

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from .patch_transformer import vit_base


class PatchEncoder(nn.Module, PyTorchModelHubMixin):
"""EXAONE-Path image patch encoder with Hugging Face Hub support.
This class wraps the ViT backbone used for patch-level feature extraction and
integrates with the Hub via :class:`huggingface_hub.PyTorchModelHubMixin`.
The configuration (``image_encoder``, ``patch_size``, ``img_size``, and any
extra keyword arguments) is automatically serialized to ``config.json`` when
calling ``save_pretrained``.
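
    Example (a minimal usage sketch; the local directory name is a
    placeholder)::

        encoder = PatchEncoder(image_encoder="vitb", patch_size=14, img_size=224)
        encoder.save_pretrained("patch-encoder")  # writes config.json and weights
        restored = PatchEncoder.from_pretrained("patch-encoder")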
"""

    def __init__(
        self,
        image_encoder: str = "vitb",
        patch_size: int = 14,
        img_size: Union[int, Sequence[int]] = 224,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Normalize a square size given as a single int into [H, W].
        if isinstance(img_size, int):
            img_size = [img_size, img_size]
        self.image_encoder = image_encoder
        self.patch_size = int(patch_size)
        self.img_size = [int(img_size[0]), int(img_size[1])]
        # Extra keyword arguments are stored so they are serialized into the
        # Hub config and forwarded to the backbone constructor.
        self.extra_kwargs = dict(kwargs)

        if image_encoder == "vitb":
            model_kwargs = dict(self.extra_kwargs)
            model_kwargs["img_size"] = self.img_size  # VisionTransformer expects [H, W]
            self.backbone = vit_base(patch_size=self.patch_size, **model_kwargs)
        else:
            raise ValueError(f"Unsupported image_encoder for PatchEncoder: {image_encoder}")

    def forward(self, x):
        # Run the ViT backbone on a batch of patch images shaped (B, 3, H, W).
        return self.backbone(x)
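

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the released API). It assumes
    # ``vit_base`` can build a randomly initialized backbone without pretrained
    # weights and that the backbone's forward returns a feature tensor.
    import torch

    encoder = PatchEncoder(image_encoder="vitb", patch_size=14, img_size=224)
    dummy = torch.randn(1, 3, 224, 224)  # (batch, channels, H, W); 224 / 14 = 16 patches per side
    features = encoder(dummy)
    print(features.shape)  # exact shape depends on the backbone head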