"""Remote-code modeling file for the EXAONE-Path patch encoder.

This file follows the same pattern as the slide encoder's modeling file:

- Keep this file small.
- At runtime, download a snapshot of the Hub repo and import the actual model
  code from `exaonepath/`, so the model definitions are not duplicated here.

As a consequence, the Hub repo must include the `exaonepath/` package.
"""

from __future__ import annotations

import importlib
import os
import sys
from typing import Any, Dict, Optional, Sequence

from huggingface_hub import snapshot_download
from torch import Tensor, nn
from transformers import PretrainedConfig, PreTrainedModel


class ExaonePathPatchEncoderConfig(PretrainedConfig):
    model_type = "exaonepath_patch_encoder"

    def __init__(
        self,
        image_encoder: str = "vitb",
        patch_size: int = 14,
        img_size: int | Sequence[int] = (224, 224),
        extra_kwargs: Dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.image_encoder = str(image_encoder)
        self.patch_size = int(patch_size)
        # Accept a single int for square inputs; always store a two-element list.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_size = [int(img_size[0]), int(img_size[1])]
        # Forwarded verbatim to `PatchEncoder` as keyword arguments.
        self.extra_kwargs = dict(extra_kwargs or {})
        super().__init__(**kwargs)


class ExaonePathPatchEncoderModel(PreTrainedModel):
    config_class = ExaonePathPatchEncoderConfig
    base_model_prefix = "patch_encoder"

    def __init__(self, config: ExaonePathPatchEncoderConfig):
        super().__init__(config)

        # Make the repo code (including `exaonepath/`) importable at runtime.
        # `name_or_path` is either a local directory, which is used as-is, or a
        # Hub repo id, for which a snapshot is downloaded. Passing a local path
        # to `snapshot_download` would fail or hit the Hub with a bogus repo id.
        repo_id = getattr(config, "_name_or_path", None) or getattr(config, "name_or_path", None)
        if isinstance(repo_id, str) and repo_id:
            local_root = repo_id if os.path.isdir(repo_id) else snapshot_download(repo_id)
            if local_root not in sys.path:
                sys.path.insert(0, local_root)

        PatchEncoder = importlib.import_module("exaonepath.models.patch_encoder_hf").PatchEncoder

        extra = getattr(config, "extra_kwargs", None) or {}
        self.patch_encoder: nn.Module = PatchEncoder(
            image_encoder=config.image_encoder,
            patch_size=int(config.patch_size),
            img_size=list(config.img_size),
            **extra,
        )

        self.post_init()

    def forward(
        self,
        x: Optional[Tensor] = None,
        *,
        pixel_values: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
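        """Encode a batch of patch images into embeddings.

        Args:
            x: Batch of patch images (positional).
            pixel_values: The same input under the Hugging Face keyword name;
                used only when `x` is not given.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape [B, C]: one embedding per input patch image.

        Note:
            The tensor is returned directly rather than wrapped in a
            `ModelOutput`, to keep the user-facing API as simple as possible.
        """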

        # Prefer the simple positional argument `x`, but also accept the
        # Hugging Face convention `pixel_values=` for compatibility.
        if x is None:
            x = pixel_values
        if x is None:
            raise ValueError("Missing input tensor. Provide `x` (positional) or `pixel_values=`.")

        return self.patch_encoder(x)


__all__ = ["ExaonePathPatchEncoderConfig", "ExaonePathPatchEncoderModel"]
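
# Example usage (a minimal sketch, not an official snippet). Assumptions: the
# Hub repo's config.json maps AutoModel to the classes in this file via
# `auto_map`; "<hub-repo-id>" is a placeholder for the actual repo id; and the
# random tensor stands in for preprocessed 3-channel patches at the configured
# `img_size` (224 x 224 by default). Real inputs need the repo's preprocessing.
#
#     import torch
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("<hub-repo-id>", trust_remote_code=True)
#     model.eval()
#     batch = torch.randn(2, 3, 224, 224)  # hypothetical input batch
#     with torch.no_grad():
#         emb = model(batch)                  # positional input
#         emb_kw = model(pixel_values=batch)  # Hugging Face keyword form
#     # Both calls return a plain tensor of shape [2, C].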