import os
import sys
from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
# Import Protify helpers; when the package is not installed (e.g. this file
# ships inside an exported model repository), extend sys.path with likely
# checkout locations and retry the imports.
try:
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config
except ImportError:
    # Candidate roots: this file's directory, its two parents, and a
    # conventional "src" layout next to this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    candidate_paths = [
        current_dir,
        os.path.dirname(current_dir),
        os.path.dirname(os.path.dirname(current_dir)),
        os.path.join(current_dir, "src"),
    ]
    for candidate in candidate_paths:
        # Only add real directories, and never duplicate sys.path entries.
        if os.path.isdir(candidate) and candidate not in sys.path:
            sys.path.insert(0, candidate)
    # Retry; if protify still cannot be found this raises ImportError,
    # which is the desired failure mode.
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config
class PackagedProbeConfig(PretrainedConfig):
    """HF config describing how to rebuild a Protify probe on top of a base model.

    Records the backbone identifier, the probe architecture and its saved
    hyper-parameters, plus the pooling / task metadata the packaged model
    needs at load time (task type, label count, PPI mode, separator id).
    """

    model_type = "packaged_probe"

    def __init__(
        self,
        base_model_name: str = "",
        probe_type: str = "linear",
        probe_config: Optional[Dict[str, Any]] = None,
        tokenwise: bool = False,
        matrix_embed: bool = False,
        pooling_types: Optional[list[str]] = None,
        task_type: str = "singlelabel",
        num_labels: int = 2,
        ppi: bool = False,
        add_token_ids: bool = False,
        sep_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Resolve mutable defaults up front so distinct config instances
        # never share the same dict/list object.
        if probe_config is None:
            probe_config = {}
        if pooling_types is None:
            pooling_types = ["mean"]
        self.base_model_name = base_model_name
        self.probe_type = probe_type
        self.probe_config = probe_config
        self.tokenwise = tokenwise
        self.matrix_embed = matrix_embed
        self.pooling_types = pooling_types
        self.task_type = task_type
        self.num_labels = num_labels
        self.ppi = ppi
        self.add_token_ids = add_token_ids
        self.sep_token_id = sep_token_id
class PackagedProbeModel(PreTrainedModel):
    """Bundles a pretrained backbone with a rebuilt Protify probe head.

    The backbone produces token-level hidden states. Depending on the config,
    the states are mean/attention-pooled (optionally per PPI segment) or
    passed through token-wise, then fed to the probe, which is expected to
    return a classifier output (logits and, when labels are given, a loss).
    """

    config_class = PackagedProbeConfig
    base_model_prefix = "backbone"
    # No tied weights between backbone and probe.
    all_tied_weights_keys = {}

    def __init__(
        self,
        config: PackagedProbeConfig,
        base_model: Optional[nn.Module] = None,
        probe: Optional[nn.Module] = None,
    ):
        """Build from ``config``; pre-built ``base_model``/``probe`` instances
        override the lazy loading paths."""
        super().__init__(config)
        self.config = config
        # Load backbone and probe from the config unless injected by caller.
        self.backbone = self._load_base_model() if base_model is None else base_model
        self.probe = self._load_probe() if probe is None else probe
        self.pooler = Pooler(self.config.pooling_types)

    def _load_base_model(self) -> nn.Module:
        """Resolve the configured backbone name (Protify preset alias or a
        plain HF model path) and load it in eval mode."""
        if self.config.base_model_name in all_presets_with_paths:
            model_path = all_presets_with_paths[self.config.base_model_name]
        else:
            # Not a known preset: treat the name as a HF hub id / local path.
            model_path = self.config.base_model_name
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        # NOTE(review): eval() only disables dropout/batchnorm updates; the
        # backbone weights are NOT frozen here — confirm that is intended.
        model.eval()
        return model

    def _load_probe(self) -> nn.Module:
        """Reconstruct the probe head from the hyper-parameters saved in the config."""
        return rebuild_probe_from_saved_config(
            probe_type=self.config.probe_type,
            tokenwise=self.config.tokenwise,
            probe_config=self.config.probe_config,
        )

    @staticmethod
    def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
        """Normalize the backbone's return value to a hidden-state tensor.

        Accepts a tuple (first element), a ModelOutput with
        ``last_hidden_state``, or a bare tensor; raises otherwise.
        """
        if isinstance(backbone_output, tuple):
            return backbone_output[0]
        if hasattr(backbone_output, "last_hidden_state"):
            return backbone_output.last_hidden_state
        if isinstance(backbone_output, torch.Tensor):
            return backbone_output
        raise ValueError("Unsupported backbone output format for packaged probe model")

    @staticmethod
    def _extract_attentions(backbone_output: Any) -> Optional[torch.Tensor]:
        """Return the backbone's attention maps if present, else None."""
        if hasattr(backbone_output, "attentions"):
            return backbone_output.attentions
        return None

    def _build_ppi_segment_masks(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split each paired-protein sequence into two per-segment masks.

        Strategy, in priority order:
        1. ``token_type_ids`` (0 = segment A, 1 = segment B), when provided
           and actually used;
        2. separator-token positions (``sep_token_id``): two separators bound
           segment A and B; a single separator splits at that point;
        3. last resort: split the valid tokens at their midpoint.

        Returns ``(mask_a, mask_b)`` as 0/1 long tensors of shape
        (batch, seq_len). Asserts both segments are non-empty.
        """
        if token_type_ids is not None and torch.any(token_type_ids == 1):
            # Segment masks restricted to real (attended) tokens only.
            mask_a = ((token_type_ids == 0) & (attention_mask == 1)).long()
            mask_b = ((token_type_ids == 1) & (attention_mask == 1)).long()
            assert torch.all(mask_a.sum(dim=1) > 0), "PPI token_type_ids produced empty segment A"
            assert torch.all(mask_b.sum(dim=1) > 0), "PPI token_type_ids produced empty segment B"
            return mask_a, mask_b
        assert self.config.sep_token_id is not None, "sep_token_id is required for PPI fallback segmentation"
        batch_size, seq_len = input_ids.shape
        mask_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        mask_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        for batch_idx in range(batch_size):
            valid_positions = torch.where(attention_mask[batch_idx] == 1)[0]
            # Separator tokens inside the attended region.
            sep_positions = torch.where((input_ids[batch_idx] == self.config.sep_token_id) & (attention_mask[batch_idx] == 1))[0]
            if len(valid_positions) == 0:
                # Fully padded row: leave both masks empty (the final asserts
                # below will reject this batch).
                continue
            if len(sep_positions) >= 2:
                # [ segment A ... sep ] [ segment B ... sep ]; both separators
                # are included in their segment.
                first_sep = int(sep_positions[0].item())
                second_sep = int(sep_positions[1].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1:second_sep + 1] = 1
            elif len(sep_positions) == 1:
                # Single separator: segment B runs to the last attended token.
                first_sep = int(sep_positions[0].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1: int(valid_positions[-1].item()) + 1] = 1
            else:
                # No separators at all: split the valid tokens down the middle.
                midpoint = len(valid_positions) // 2
                mask_a[batch_idx, valid_positions[:midpoint]] = 1
                mask_b[batch_idx, valid_positions[midpoint:]] = 1
        assert torch.all(mask_a.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment A"
        assert torch.all(mask_b.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment B"
        return mask_a, mask_b

    def _build_probe_inputs(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
        attentions: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Shape the backbone states into what the probe expects.

        Returns ``(embeddings, attention_mask_or_None)``:
        - PPI + pooled mode: the two segments are pooled separately and
          concatenated on the feature dim; no mask is needed downstream.
        - matrix/tokenwise mode: raw hidden states with the attention mask.
        - otherwise: a single pooled vector, no mask.
        """
        if self.config.ppi and (not self.config.matrix_embed) and (not self.config.tokenwise):
            mask_a, mask_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
            vec_a = self.pooler(hidden_states, attention_mask=mask_a, attentions=attentions)
            vec_b = self.pooler(hidden_states, attention_mask=mask_b, attentions=attentions)
            return torch.cat((vec_a, vec_b), dim=-1), None
        if self.config.matrix_embed or self.config.tokenwise:
            return hidden_states, attention_mask
        pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
        return pooled, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput | TokenClassifierOutput:
        """Run backbone → pooling → probe and return the probe's output.

        ``attention_mask`` defaults to all-ones when omitted. ``labels`` is
        forwarded to the probe, which is assumed to compute the loss itself.
        Raises ValueError for unknown ``probe_type`` values.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        # "parti" pooling consumes backbone attention maps, which are only
        # meaningful (and only requested) in pooled — not tokenwise/matrix — mode.
        requires_attentions = "parti" in self.config.pooling_types and (not self.config.matrix_embed) and (not self.config.tokenwise)
        backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
        if requires_attentions:
            backbone_kwargs["output_attentions"] = True
        backbone_output = self.backbone(**backbone_kwargs)
        hidden_states = self._extract_hidden_states(backbone_output)
        attentions = self._extract_attentions(backbone_output)
        if requires_attentions:
            assert attentions is not None, "parti pooling requires base model attentions"
        probe_embeddings, probe_attention_mask = self._build_probe_inputs(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            attentions=attentions,
        )
        # Dispatch on probe type; each probe's forward signature differs.
        if self.config.probe_type == "linear":
            return self.probe(embeddings=probe_embeddings, labels=labels)
        if self.config.probe_type == "transformer":
            forward_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
            if probe_attention_mask is not None:
                forward_kwargs["attention_mask"] = probe_attention_mask
            # token ids are only forwarded in tokenwise/matrix mode (where a
            # mask exists) and when the config opts in.
            if self.config.add_token_ids and token_type_ids is not None and probe_attention_mask is not None:
                forward_kwargs["token_type_ids"] = token_type_ids
            return self.probe(**forward_kwargs)
        if self.config.probe_type in ["retrievalnet", "lyra"]:
            return self.probe(embeddings=probe_embeddings, attention_mask=probe_attention_mask, labels=labels)
        raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")