st24hour committed on
Commit
6f4e480
·
verified ·
1 Parent(s): c1b26ff

Upload modeling_exaonepath_slide_encoder.py with huggingface_hub

Browse files
modeling_exaonepath_slide_encoder.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Remote-code modeling file for EXAONE-Path Slide/WSI encoder.

This file is imported by Transformers when using `trust_remote_code=True`.

Important:
- This file acts as a *thin AutoModel entrypoint*.
- The actual implementation lives in `exaonepath.models.slide_encoder_hf`.
- At runtime, the repository snapshot is downloaded via `snapshot_download`
  and added to `sys.path` so that `exaonepath/` can be imported.
- Do NOT import sibling modules like `configuration_exaonepath_slide_encoder` here.
  Transformers' remote-code dependency checker treats those imports as missing
  third-party packages (e.g. it suggests `pip install configuration_exaonepath_slide_encoder`).
"""

# The docstring must be the first statement for it to become the module's
# `__doc__`; a `from __future__` import is still legal directly after it.
from __future__ import annotations
16
+
17
+ from typing import Any, Dict, Optional
18
+ import importlib
19
+ import sys
20
+
21
+ from huggingface_hub import snapshot_download
22
+ from torch import Tensor, nn
23
+ from transformers import PretrainedConfig, PreTrainedModel
24
+
25
+
26
class ExaonePathSlideEncoderConfig(PretrainedConfig):
    """Transformers configuration for the EXAONE-Path Slide/WSI encoder.

    The class is defined inside the modeling file on purpose, so that no
    separate `configuration_exaonepath_slide_encoder.py` has to be shipped
    on the Hub.

    Args:
        wsi_cfg: Settings mapping for the slide encoder. A shallow copy is
            stored on the config; ``None`` or an empty mapping yields ``{}``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "exaonepath_slide_encoder"

    def __init__(self, wsi_cfg: Dict[str, Any] | None = None, **kwargs: Any):
        # Store our own dict rather than the caller's object.
        encoder_settings = dict(wsi_cfg) if wsi_cfg else {}
        self.wsi_cfg = encoder_settings
        super().__init__(**kwargs)
38
+
39
+
40
class ExaonePathSlideEncoderModel(PreTrainedModel):
    """Thin `PreTrainedModel` wrapper around the EXAONE-Path WSI encoder.

    The encoder itself is implemented in `exaonepath.models.slide_encoder_hf`
    (class `WSIEncoder`), which is imported lazily at construction time after
    the repository root has been made importable.
    """

    config_class = ExaonePathSlideEncoderConfig
    base_model_prefix = "slide_encoder"

    def __init__(self, config: ExaonePathSlideEncoderConfig):
        """Build the wrapped slide encoder from ``config.wsi_cfg``.

        Args:
            config: Model configuration. Its ``_name_or_path`` (falling back
                to ``name_or_path``) is used to locate the repository code.

        Raises:
            ImportError: If `exaonepath.models.slide_encoder_hf` cannot be
                imported after the repository root lookup.
        """
        super().__init__(config)

        # Ensure the repo code (including `exaonepath/`) is available at runtime.
        self._ensure_repo_on_sys_path(config)

        encoder_module = importlib.import_module("exaonepath.models.slide_encoder_hf")
        WSIEncoder = encoder_module.WSIEncoder
        self.slide_encoder: nn.Module = WSIEncoder.from_wsi_config(wsi_cfg=config.wsi_cfg)

        self.post_init()

    @staticmethod
    def _ensure_repo_on_sys_path(config: ExaonePathSlideEncoderConfig) -> None:
        """Put the repository root that contains `exaonepath/` on ``sys.path``.

        Does nothing when the config carries no usable name or path.
        """
        # Local import: `os` is not among the names this file imports at top level.
        import os

        # NOTE: config._name_or_path is usually the repo id when loaded from Hub.
        repo_id = getattr(config, "_name_or_path", None) or getattr(config, "name_or_path", None)
        if not (isinstance(repo_id, str) and repo_id):
            return

        if os.path.isdir(repo_id):
            # Loaded from a local checkout/directory: `snapshot_download` expects
            # a Hub repo id, not a filesystem path, so use the directory as-is.
            local_root = os.path.abspath(repo_id)
        else:
            local_root = snapshot_download(repo_id)

        if local_root not in sys.path:
            sys.path.insert(0, local_root)

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Tensor]:
        """Return patch- and slide-level embeddings.

        All four named arguments are handed to the wrapped slide encoder by
        keyword, unchanged. Extra ``**kwargs`` are accepted but are NOT
        forwarded to the encoder.

        Returns a dict with exactly two keys:
        - "patch_embedding": [B, N, C_in + D]
        - "slide_embedding": [B, C_in + D]

        Note: We intentionally return a plain dict (instead of a ModelOutput)
        to make the remote-code API explicit and easy to use.
        """
        out: Dict[str, Tensor] = self.slide_encoder(
            patch_features=patch_features,
            patch_mask=patch_mask,
            patch_coords=patch_coords,
            patch_contour_index=patch_contour_index,
        )
        return out
85
+
86
+
87
+ __all__ = ["ExaonePathSlideEncoderModel"]