from __future__ import annotations

from typing import Any, Dict, Optional

from torch import Tensor, nn
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import DictConfig, OmegaConf

from .slide_transformer import VisionTransformer
from .slide_encoder_head import WSIEncoderHead


def _build_wsi_encoder(wsi_cfg: DictConfig) -> nn.Module:
    """Construct a WSIFeatureEncoder composed of a VisionTransformer.

    This is a minimal, WSI-only factory equivalent to
    ``MultiModalMetaModel._build_wsi_encoder`` but without importing the
    full multimodal meta model.
    """
    embed_dim = int(wsi_cfg.get("embed_dim", 768))
    input_dim = int(wsi_cfg.get("input_dim", 768))

    transformer_kwargs = {
        "input_dim": input_dim,
        "patch_size": int(wsi_cfg.get("patch_size", 256)),
        "embed_use_norm": bool(wsi_cfg.get("embed_use_norm", True)),
        "embed_dim": embed_dim,
        "depth": int(wsi_cfg.get("depth", 12)),
        "num_heads": int(wsi_cfg.get("num_heads", 12)),
        "ffn_ratio": float(wsi_cfg.get("ffn_ratio", 4.0)),
        "qkv_bias": bool(wsi_cfg.get("qkv_bias", True)),
        "norm_layer": wsi_cfg.get("norm_layer", "layernorm"),
        "ffn_layer": wsi_cfg.get("ffn_layer", "swiglu128"),
        "ffn_bias": bool(wsi_cfg.get("ffn_bias", True)),
        "proj_bias": bool(wsi_cfg.get("proj_bias", True)),
        "ffn_drop": float(wsi_cfg.get("ffn_drop", 0.0)),
        "attn_drop": float(wsi_cfg.get("attn_drop", 0.0)),
        "n_storage_tokens": int(wsi_cfg.get("n_storage_tokens", 0)),
        "nope_interval": int(wsi_cfg.get("nope_interval", 2)),
        # Rotary position embedding (RoPE) options.
        "pos_embed_rope_base": wsi_cfg.get("pos_embed_rope_base", 10000.0),
        "pos_embed_rope_min_period": wsi_cfg.get("pos_embed_rope_min_period"),
        "pos_embed_rope_max_period": wsi_cfg.get("pos_embed_rope_max_period"),
        "pos_embed_rope_dtype": wsi_cfg.get("pos_embed_rope_dtype", "fp32"),
    }

    transformer = VisionTransformer(**transformer_kwargs)
    if hasattr(transformer, "init_weights"):
        transformer.init_weights()

    wsi_encoder = WSIEncoderHead(
        transformer,
        input_dim,
        embed_dim,
    )

    return wsi_encoder
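

# Illustrative only: a minimal ``wsi_cfg`` accepted by ``_build_wsi_encoder``.
# Every key read above has a default, so an empty mapping also works; the
# values below are assumptions for the sketch, not a recommended setting.
#
#     cfg = OmegaConf.create({"input_dim": 768, "embed_dim": 768, "depth": 12})
#     encoder = _build_wsi_encoder(cfg)   # returns a WSIEncoderHead module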


class WSIEncoder(nn.Module, PyTorchModelHubMixin):
    """WSI slide-level encoder wrapper with Hugging Face Hub support.

    This wraps the internal :class:`WSIFeatureEncoder` (ViT + aggregation)
    used in EXAONE-Path for slide-level feature extraction, and exposes it
    as a Hub-compatible model via :class:`PyTorchModelHubMixin`.

    The minimal configuration (``wsi_cfg``) is stored on the instance so
    that it can be serialized to ``config.json`` when calling
    :meth:`save_pretrained`.
    """

    def __init__(
        self,
        *,
        wsi_cfg: Dict[str, Any],
    ) -> None:
        super().__init__()

        # Keep the config as a plain dict so it can be serialized to
        # ``config.json`` by ``save_pretrained``.
        self.wsi_cfg: Dict[str, Any] = dict(wsi_cfg)

        cfg_obj: DictConfig = OmegaConf.create(self.wsi_cfg)
        if isinstance(cfg_obj, DictConfig):
            OmegaConf.resolve(cfg_obj)

        self.wsi_encoder = _build_wsi_encoder(cfg_obj)

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Forward to the underlying WSIFeatureEncoder.

        Args:
            patch_features: [B, N, C] patch feature tensor.
            patch_mask: [B, N] mask, 1 for valid tokens.
            patch_coords: optional [B, N, 2] patch coordinates (for RoPE).
            patch_contour_index: optional [B, N] contour indices.
        """
        return self.wsi_encoder(
            patch_features=patch_features,
            patch_mask=patch_mask,
            patch_coords=patch_coords,
            patch_contour_index=patch_contour_index,
        )

    @classmethod
    def from_wsi_config(
        cls,
        wsi_cfg: Dict[str, Any],
    ) -> WSIEncoder:
        """Build a :class:`WSIEncoder` directly from a ``wsi_cfg`` mapping."""
        return cls(wsi_cfg=wsi_cfg)
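

# Illustrative usage sketch (assumptions: tensor shapes follow the ``forward``
# docstring above; the keys of the returned dict depend on ``WSIEncoderHead``
# and are not shown here). ``save_pretrained`` / ``from_pretrained`` are
# provided by ``PyTorchModelHubMixin``.
#
#     import torch
#
#     encoder = WSIEncoder(wsi_cfg={"input_dim": 768, "embed_dim": 768})
#     feats = torch.randn(1, 4096, 768)     # [B, N, C] patch features
#     mask = torch.ones(1, 4096)            # [B, N], 1 = valid token
#     coords = torch.zeros(1, 4096, 2)      # [B, N, 2] patch coordinates
#     out = encoder(
#         patch_features=feats,
#         patch_mask=mask,
#         patch_coords=coords,
#     )
#
#     encoder.save_pretrained("./wsi-encoder-checkpoint")  # example local path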