st24hour committed on
Commit
3bcc19b
·
verified ·
1 Parent(s): 91fa856

Upload modeling_exaonepath_patch_encoder.py with huggingface_hub

Browse files
modeling_exaonepath_patch_encoder.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Remote-code modeling file for the EXAONE-Path Patch Encoder.

Unified with the slide-encoder style:

- Keep this file small.
- At runtime, download the repo snapshot and import the actual model code
  from ``exaonepath/`` (so model definitions are not duplicated here).

This requires the Hub repo to include ``exaonepath/``.
"""

# The docstring must be the first statement to become the module's ``__doc__``;
# a docstring is one of the few things allowed before a future import.
from __future__ import annotations
12
+
13
+ from typing import Any, Dict, Optional
14
+ import importlib
15
+ import sys
16
+
17
+ from huggingface_hub import snapshot_download
18
+ from torch import Tensor, nn
19
+ from transformers import PretrainedConfig, PreTrainedModel
20
+
21
+
22
class ExaonePathPatchEncoderConfig(PretrainedConfig):
    """Configuration holder for the EXAONE-Path patch encoder.

    Attributes assigned by ``__init__``:
        image_encoder: ``image_encoder`` coerced to ``str``.
        patch_size: ``patch_size`` coerced to ``int``.
        img_size: two-element list of ints; a single int is used for both
            entries.
        extra_kwargs: shallow ``dict`` copy of the given mapping, or an empty
            dict when the argument is ``None``/empty.
    """

    model_type = "exaonepath_patch_encoder"

    def __init__(
        self,
        image_encoder: str = "vitb",
        patch_size: int = 14,
        img_size=(224, 224),
        extra_kwargs: Dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        """Store the encoder settings, then forward ``kwargs`` to the base class."""
        # A bare int is shorthand for the same value in both positions.
        size_pair = (img_size, img_size) if isinstance(img_size, int) else img_size

        self.image_encoder = str(image_encoder)
        self.patch_size = int(patch_size)
        self.img_size = [int(size_pair[idx]) for idx in (0, 1)]
        # Copy so later mutation of the caller's mapping does not leak in.
        self.extra_kwargs = dict(extra_kwargs) if extra_kwargs else {}

        super().__init__(**kwargs)
40
+
41
+
42
class ExaonePathPatchEncoderModel(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around ``exaonepath``'s ``PatchEncoder``.

    The real model definition is imported at construction time from
    ``exaonepath.models.patch_encoder_hf``. The directory that is expected to
    contain the ``exaonepath/`` package is resolved from the config's
    ``_name_or_path`` and placed at the front of ``sys.path``.
    """

    config_class = ExaonePathPatchEncoderConfig
    base_model_prefix = "patch_encoder"

    def __init__(self, config: ExaonePathPatchEncoderConfig):
        """Resolve the code location, import ``PatchEncoder`` and build it.

        Args:
            config: Configuration providing ``image_encoder``, ``patch_size``,
                ``img_size`` and optional ``extra_kwargs`` forwarded to
                ``PatchEncoder``.

        Raises:
            ModuleNotFoundError: If ``exaonepath.models.patch_encoder_hf``
                cannot be imported.
        """
        super().__init__(config)

        # Local import keeps this class self-contained; the module's import
        # header does not bring in ``os``.
        import os

        # Ensure the repo code (including `exaonepath/`) is available at runtime.
        source = getattr(config, "_name_or_path", None) or getattr(config, "name_or_path", None)
        if isinstance(source, str) and source:
            if os.path.isdir(source):
                # Loaded from a local checkout: use it directly. Passing a
                # filesystem path to `snapshot_download` would be treated as a
                # Hub repo id and fail (or fetch an unrelated repo).
                local_root = os.path.abspath(source)
            else:
                local_root = snapshot_download(source)
            if local_root not in sys.path:
                sys.path.insert(0, local_root)

        try:
            encoder_module = importlib.import_module("exaonepath.models.patch_encoder_hf")
        except ModuleNotFoundError as err:
            # Only rewrite the message when `exaonepath` itself is missing; a
            # missing transitive dependency should surface unchanged.
            if err.name is not None and err.name.split(".")[0] == "exaonepath":
                raise ModuleNotFoundError(
                    "Could not import `exaonepath.models.patch_encoder_hf`. Load the model "
                    "with `from_pretrained` from a Hub repo or local directory that contains "
                    "the `exaonepath/` package, or install `exaonepath`.",
                    name=err.name,
                ) from err
            raise
        PatchEncoder = getattr(encoder_module, "PatchEncoder")

        extra = getattr(config, "extra_kwargs", None) or {}
        self.patch_encoder: nn.Module = PatchEncoder(
            image_encoder=config.image_encoder,
            patch_size=int(config.patch_size),
            img_size=list(config.img_size),
            **extra,
        )

        self.post_init()

    def forward(
        self,
        x: Optional[Tensor] = None,
        *,
        pixel_values: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Tensor:
        """Return patch embedding as a tensor.

        Args:
            x: Input tensor (positional). Takes precedence over
                ``pixel_values`` when both are given.
            pixel_values: Keyword alternative following the Hugging Face
                naming convention; used only when ``x`` is ``None``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            patch_embedding: [B, C]

        Raises:
            ValueError: If neither ``x`` nor ``pixel_values`` is provided.

        Note:
            The patch encoder produces a single embedding per input patch image.
            We return the tensor directly for the simplest user-facing API.
        """
        # Prefer the simple positional argument `x`, but also accept the
        # Hugging Face convention `pixel_values=` for compatibility.
        if x is None:
            x = pixel_values
        if x is None:
            raise ValueError("Missing input tensor. Provide `x` (positional) or `pixel_values=`.")

        return self.patch_encoder(x)
96
+
97
+
98
# Public names exported by this remote-code module.
__all__ = [
    "ExaonePathPatchEncoderConfig",
    "ExaonePathPatchEncoderModel",
]