|
|
from __future__ import annotations |
|
|
|
|
|
"""Remote-code modeling file for EXAONE-Path Patch Encoder. |
|
|
|
|
|
Unified with slide-encoder style: |
|
|
- Keep this file small. |
|
|
- At runtime, download the repo snapshot and import the actual model code from |
|
|
`exaonepath/` (so we don't duplicate model definitions here). |
|
|
|
|
|
This requires the Hub repo to include `exaonepath/`. |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict, Optional |
|
|
import importlib |
|
|
import sys |
|
|
|
|
|
from huggingface_hub import snapshot_download |
|
|
from torch import Tensor, nn |
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
|
|
|
class ExaonePathPatchEncoderConfig(PretrainedConfig):
    """Configuration for the EXAONE-Path patch encoder.

    Stores the constructor arguments later forwarded to ``PatchEncoder``.

    Args:
        image_encoder: Identifier of the image encoder variant; stored as ``str``.
        patch_size: Patch size; stored as ``int``.
        img_size: Either a single int (expanded to a square ``[s, s]``) or an
            indexable pair; stored as a two-element list of ints.
        extra_kwargs: Optional extra keyword arguments for the encoder; a
            shallow copy is stored (an empty dict when omitted or falsy).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "exaonepath_patch_encoder"

    def __init__(
        self,
        image_encoder: str = "vitb",
        patch_size: int = 14,
        img_size=(224, 224),
        extra_kwargs: Dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.image_encoder = str(image_encoder)
        self.patch_size = int(patch_size)

        # A bare int means a square image; otherwise use the first two entries.
        dims = (img_size, img_size) if isinstance(img_size, int) else img_size
        first_dim = int(dims[0])
        second_dim = int(dims[1])
        # Stored as a list rather than a tuple.
        self.img_size = [first_dim, second_dim]

        # Copy so the config never aliases the caller's dict.
        self.extra_kwargs = dict(extra_kwargs) if extra_kwargs else {}

        # Own fields are assigned before delegating to the base initializer,
        # same ordering as before.
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
class ExaonePathPatchEncoderModel(PreTrainedModel):
    """Hugging Face wrapper around the EXAONE-Path ``PatchEncoder``.

    The encoder implementation is not defined in this file. It is imported at
    construction time from ``exaonepath.models.patch_encoder_hf``. When the
    config records where it was loaded from, that location is placed on
    ``sys.path`` first so the ``exaonepath/`` package shipped next to the
    checkpoint becomes importable.
    """

    config_class = ExaonePathPatchEncoderConfig
    base_model_prefix = "patch_encoder"

    @staticmethod
    def _resolve_code_root(config: ExaonePathPatchEncoderConfig) -> Optional[str]:
        """Return the directory expected to contain ``exaonepath/``, or ``None``.

        Resolution order:
          1. No usable ``_name_or_path`` / ``name_or_path`` on the config:
             return ``None`` (``exaonepath`` must already be importable).
          2. The recorded source is an existing local directory: use it as-is.
             Passing a filesystem path to ``snapshot_download`` fails repo-id
             validation, so local checkpoints must not go through the Hub.
          3. Otherwise treat the value as a Hub repo id and fetch its snapshot.
        """
        import os  # local import: keeps this block self-contained

        source = getattr(config, "_name_or_path", None) or getattr(config, "name_or_path", None)
        if not (isinstance(source, str) and source):
            return None
        if os.path.isdir(source):
            # Absolute path so the sys.path entry survives a later chdir.
            return os.path.abspath(source)
        return snapshot_download(source)

    def __init__(self, config: ExaonePathPatchEncoderConfig):
        """Build the wrapped patch encoder from ``config``.

        Side effects: may download the Hub snapshot and prepends the resolved
        code root to ``sys.path`` (process-wide) if it is not already there.

        Raises:
            ImportError: if ``exaonepath.models.patch_encoder_hf`` cannot be
                imported (e.g. the repo/directory does not ship ``exaonepath/``).
        """
        super().__init__(config)

        code_root = self._resolve_code_root(config)
        if code_root is not None and code_root not in sys.path:
            sys.path.insert(0, code_root)

        # Imported lazily: the package only becomes importable after the
        # sys.path manipulation above.
        PatchEncoder = getattr(
            importlib.import_module("exaonepath.models.patch_encoder_hf"),
            "PatchEncoder",
        )

        extra = getattr(config, "extra_kwargs", None) or {}
        self.patch_encoder: nn.Module = PatchEncoder(
            image_encoder=config.image_encoder,
            patch_size=int(config.patch_size),
            img_size=list(config.img_size),
            **extra,
        )

        self.post_init()

    def forward(
        self,
        x: Optional[Tensor] = None,
        *,
        pixel_values: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Return patch embedding as a tensor.

        Args:
            x: Input tensor (positional). Takes precedence over ``pixel_values``.
            pixel_values: Keyword alternative to ``x``; used only when ``x`` is
                ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            patch_embedding: [B, C]

        Raises:
            ValueError: if neither ``x`` nor ``pixel_values`` is provided.

        Note:
            The patch encoder produces a single embedding per input patch image.
            We return the tensor directly for the simplest user-facing API.
        """
        if x is None:
            x = pixel_values
        if x is None:
            raise ValueError("Missing input tensor. Provide `x` (positional) or `pixel_values=`.")

        return self.patch_encoder(x)
|
|
|
|
|
|
|
|
# Public API of this module (the names exported by a star-import).
__all__ = [
    "ExaonePathPatchEncoderConfig",
    "ExaonePathPatchEncoderModel",
]
|
|
|