File size: 4,297 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

from torch import Tensor, nn
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import DictConfig, OmegaConf

from .slide_transformer import VisionTransformer
from .slide_encoder_head import WSIEncoderHead


def _build_wsi_encoder(wsi_cfg: DictConfig) -> nn.Module:
    """Construct a WSI encoder composed of a VisionTransformer.

    This is a minimal, WSI-only factory equivalent to
    ``MultiModalMetaModel._build_wsi_encoder`` but without importing the
    full multimodal meta model.

    Args:
        wsi_cfg: Configuration mapping holding the transformer
            hyperparameters. Every key is optional; the defaults used when
            a key is absent are the literals below (e.g. ``embed_dim=768``,
            ``depth=12``, ``ffn_layer="swiglu128"``).

    Returns:
        A :class:`WSIEncoderHead` wrapping the configured
        :class:`VisionTransformer`.
    """

    embed_dim = int(wsi_cfg.get("embed_dim", 768))
    input_dim = int(wsi_cfg.get("input_dim", 768))

    # Explicit int()/float()/bool() casts normalize values that may arrive
    # as strings or OmegaConf node types from a YAML config.
    transformer_kwargs = {
        "input_dim": input_dim,
        "patch_size": int(wsi_cfg.get("patch_size", 256)),
        "embed_use_norm": bool(wsi_cfg.get("embed_use_norm", True)),
        "embed_dim": embed_dim,
        "depth": int(wsi_cfg.get("depth", 12)),
        "num_heads": int(wsi_cfg.get("num_heads", 12)),
        "ffn_ratio": float(wsi_cfg.get("ffn_ratio", 4.0)),
        "qkv_bias": bool(wsi_cfg.get("qkv_bias", True)),
        "norm_layer": wsi_cfg.get("norm_layer", "layernorm"),
        "ffn_layer": wsi_cfg.get("ffn_layer", "swiglu128"),
        "ffn_bias": bool(wsi_cfg.get("ffn_bias", True)),
        "proj_bias": bool(wsi_cfg.get("proj_bias", True)),
        "ffn_drop": float(wsi_cfg.get("ffn_drop", 0.0)),
        "attn_drop": float(wsi_cfg.get("attn_drop", 0.0)),
        "n_storage_tokens": int(wsi_cfg.get("n_storage_tokens", 0)),
        "nope_interval": int(wsi_cfg.get("nope_interval", 2)),
        # Rope / coords related
        "pos_embed_rope_base": wsi_cfg.get("pos_embed_rope_base", 10000.0),
        "pos_embed_rope_min_period": wsi_cfg.get("pos_embed_rope_min_period"),
        "pos_embed_rope_max_period": wsi_cfg.get("pos_embed_rope_max_period"),
        "pos_embed_rope_dtype": wsi_cfg.get("pos_embed_rope_dtype", "fp32"),
    }

    # Build transformer with internal WSI patch embedding
    transformer = VisionTransformer(**transformer_kwargs)
    # init_weights() is optional on the transformer implementation; call it
    # only when present so alternative backbones still work.
    if hasattr(transformer, "init_weights"):
        transformer.init_weights()

    wsi_encoder = WSIEncoderHead(
        transformer,
        input_dim,
        embed_dim,
    )

    return wsi_encoder


class WSIEncoder(nn.Module, PyTorchModelHubMixin):
    """WSI slide-level encoder wrapper with Hugging Face Hub support.

    Wraps the internal WSI feature encoder (ViT + aggregation) used in
    EXAONE-Path for slide-level feature extraction and exposes it as a
    Hub-compatible model through :class:`PyTorchModelHubMixin`.

    The minimal configuration (``wsi_cfg``) is kept on the instance so that
    it can be serialized to ``config.json`` when calling
    :meth:`save_pretrained`.
    """

    def __init__(
        self,
        *,
        wsi_cfg: Dict[str, Any],
    ) -> None:
        super().__init__()

        # Keep a plain-dict copy for painless JSON/YAML serialization.
        self.wsi_cfg: Dict[str, Any] = dict(wsi_cfg)

        # Materialize an OmegaConf view of the config and resolve any
        # interpolations before handing it to the factory (encoder is
        # built on CPU; callers move it to a device afterwards).
        resolved: DictConfig = OmegaConf.create(self.wsi_cfg)
        if isinstance(resolved, DictConfig):
            OmegaConf.resolve(resolved)

        self.wsi_encoder = _build_wsi_encoder(resolved)

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Delegate to the underlying WSI feature encoder.

        Args:
            patch_features: [B, N, C]
            patch_mask: [B, N] with 1 for valid tokens
            patch_coords: optional [B, N, 2] coords (for RoPE)
            patch_contour_index: optional [B, N] contour indices
        """

        return self.wsi_encoder(
            patch_features=patch_features,
            patch_mask=patch_mask,
            patch_coords=patch_coords,
            patch_contour_index=patch_contour_index,
        )

    # Optional: small helper to reconstruct from a minimal config dict.
    @classmethod
    def from_wsi_config(
        cls,
        wsi_cfg: Dict[str, Any],
    ) -> WSIEncoder:
        return cls(wsi_cfg=wsi_cfg)