# NOTE: File originates from a Hugging Face Hub upload by user "st24hour"
# (commit message: "Upload folder using huggingface_hub", revision e101805, verified).
from __future__ import annotations
from typing import Any, Sequence, Union
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from .patch_transformer import vit_base
class PatchEncoder(nn.Module, PyTorchModelHubMixin):
    """EXAONE-Path image patch encoder with Hugging Face Hub support.

    Wraps the ViT backbone used for patch-level feature extraction and plugs
    into the Hub through :class:`huggingface_hub.PyTorchModelHubMixin`. The
    configuration (``image_encoder``, ``patch_size``, ``img_size``, and any
    extra keyword arguments) is serialized automatically to ``config.json``
    by ``save_pretrained``.
    """

    def __init__(
        self,
        image_encoder: str = "vitb",
        patch_size: int = 14,
        img_size: Union[int, Sequence[int]] = 224,
        **kwargs: Any,
    ) -> None:
        """Build the encoder.

        Args:
            image_encoder: Backbone identifier; only ``"vitb"`` is supported.
            patch_size: ViT patch size (coerced to ``int``).
            img_size: Input size, either a scalar (square) or an ``[H, W]`` pair.
            **kwargs: Extra keyword arguments forwarded to the backbone factory.

        Raises:
            ValueError: If *image_encoder* is not a supported backbone name.
        """
        super().__init__()

        # Normalize a scalar size into the [H, W] pair stored on the instance.
        size_pair = [img_size, img_size] if isinstance(img_size, int) else img_size

        self.image_encoder = image_encoder
        self.patch_size = int(patch_size)
        self.img_size = [int(size_pair[0]), int(size_pair[1])]
        self.extra_kwargs = dict(kwargs)

        # Guard clause: reject anything other than the ViT-Base backbone.
        if image_encoder != "vitb":
            raise ValueError(f"Unsupported image_encoder for PatchEncoder: {image_encoder}")

        backbone_kwargs = dict(self.extra_kwargs)
        backbone_kwargs["img_size"] = self.img_size  # VisionTransformer expects [H, W]
        self.backbone = vit_base(patch_size=self.patch_size, **backbone_kwargs)

    def forward(self, x):
        """Run the input batch through the ViT backbone and return its output."""
        return self.backbone(x)