| import os |
| import sys |
| from typing import Any, Dict, Optional |
|
|
| import torch |
| from torch import nn |
| from transformers import AutoModel, PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput |
|
|
|
|
| try: |
| from protify.base_models.supported_models import all_presets_with_paths |
| from protify.pooler import Pooler |
| from protify.probes.get_probe import rebuild_probe_from_saved_config |
| except ImportError: |
| current_dir = os.path.dirname(os.path.abspath(__file__)) |
| candidate_paths = [ |
| current_dir, |
| os.path.dirname(current_dir), |
| os.path.dirname(os.path.dirname(current_dir)), |
| os.path.join(current_dir, "src"), |
| ] |
| for candidate in candidate_paths: |
| if os.path.isdir(candidate) and candidate not in sys.path: |
| sys.path.insert(0, candidate) |
| from protify.base_models.supported_models import all_presets_with_paths |
| from protify.pooler import Pooler |
| from protify.probes.get_probe import rebuild_probe_from_saved_config |
|
|
|
|
class PackagedProbeConfig(PretrainedConfig):
    """Configuration for a backbone-plus-probe bundle (``PackagedProbeModel``).

    Attributes stored on the config (all taken verbatim from the constructor):
        base_model_name: Preset key or model path used to load the backbone.
        probe_type: Name of the probe architecture; selects how the probe is
            rebuilt and how it is called in ``PackagedProbeModel.forward``.
        probe_config: Saved keyword configuration of the probe (``{}`` when
            not given).
        tokenwise: Whether the probe consumes per-token hidden states.
        matrix_embed: Whether the probe consumes the full hidden-state matrix
            instead of a pooled vector.
        pooling_types: Pooling strategies handed to ``Pooler`` (``["mean"]``
            when not given).
        task_type: Task label stored as-is; not interpreted by this class.
        num_labels: Number of output labels.
        ppi: Whether inputs hold two segments that are pooled separately.
        add_token_ids: Whether ``token_type_ids`` are forwarded to a
            transformer probe.
        sep_token_id: Separator token id used to split the two PPI segments
            when ``token_type_ids`` cannot be used.
    """

    model_type = "packaged_probe"

    def __init__(
        self,
        base_model_name: str = "",
        probe_type: str = "linear",
        probe_config: Optional[Dict[str, Any]] = None,
        tokenwise: bool = False,
        matrix_embed: bool = False,
        pooling_types: Optional[list[str]] = None,
        task_type: str = "singlelabel",
        num_labels: int = 2,
        ppi: bool = False,
        add_token_ids: bool = False,
        sep_token_id: Optional[int] = None,
        **kwargs,
    ):
        # Resolve the ``None`` placeholders first so the defaults are fresh
        # containers per instance rather than shared mutable objects.
        if probe_config is None:
            probe_config = {}
        if pooling_types is None:
            pooling_types = ["mean"]

        # Everything not named above is handled by the base configuration.
        super().__init__(**kwargs)

        # Backbone / probe identity.
        self.base_model_name = base_model_name
        self.probe_type = probe_type
        self.probe_config = probe_config

        # How hidden states are turned into probe inputs.
        self.tokenwise = tokenwise
        self.matrix_embed = matrix_embed
        self.pooling_types = pooling_types

        # Task description.
        self.task_type = task_type
        # NOTE(review): this runs after super().__init__; if the base class
        # restored ``id2label`` from kwargs, assigning the default here may
        # interact with it - confirm against PretrainedConfig.num_labels.
        self.num_labels = num_labels

        # Paired-input (PPI) handling.
        self.ppi = ppi
        self.add_token_ids = add_token_ids
        self.sep_token_id = sep_token_id
|
|
|
|
class PackagedProbeModel(PreTrainedModel):
    """A backbone encoder and a probe head packaged as one ``PreTrainedModel``.

    ``forward`` runs the backbone, converts its hidden states into probe
    inputs (pooled vectors, or the unpooled hidden states when
    ``config.matrix_embed`` / ``config.tokenwise`` is set) and calls the probe
    with the keyword arguments expected for ``config.probe_type``.
    """

    config_class = PackagedProbeConfig
    base_model_prefix = "backbone"
    all_tied_weights_keys = {}

    def __init__(
        self,
        config: PackagedProbeConfig,
        base_model: Optional[nn.Module] = None,
        probe: Optional[nn.Module] = None,
    ):
        """Assemble the model from ``config``.

        Args:
            config: Packaged probe configuration.
            base_model: Pre-built backbone; loaded from
                ``config.base_model_name`` when omitted.
            probe: Pre-built probe; rebuilt from the saved probe
                configuration when omitted.
        """
        super().__init__(config)
        self.config = config
        self.backbone = self._load_base_model() if base_model is None else base_model
        self.probe = self._load_probe() if probe is None else probe
        self.pooler = Pooler(self.config.pooling_types)

    def _load_base_model(self) -> nn.Module:
        """Load the backbone named by ``config.base_model_name`` in eval mode.

        A name found in ``all_presets_with_paths`` is translated to its
        registered path; any other name is passed to ``AutoModel`` unchanged.
        """
        if self.config.base_model_name in all_presets_with_paths:
            model_path = all_presets_with_paths[self.config.base_model_name]
        else:
            model_path = self.config.base_model_name
        # SECURITY: trust_remote_code=True allows the referenced repository to
        # run its own Python code at load time. Only use base_model_name
        # values that come from a trusted source.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        return model

    def _load_probe(self) -> nn.Module:
        """Rebuild the probe module from the configuration saved in ``config``."""
        return rebuild_probe_from_saved_config(
            probe_type=self.config.probe_type,
            tokenwise=self.config.tokenwise,
            probe_config=self.config.probe_config,
        )

    @staticmethod
    def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
        """Return the hidden-state tensor from a backbone output.

        Accepts a tuple (first element is used), an object exposing
        ``last_hidden_state``, or a bare tensor.

        Raises:
            ValueError: If the output matches none of those formats.
        """
        if isinstance(backbone_output, tuple):
            return backbone_output[0]
        if hasattr(backbone_output, "last_hidden_state"):
            return backbone_output.last_hidden_state
        if isinstance(backbone_output, torch.Tensor):
            return backbone_output
        raise ValueError("Unsupported backbone output format for packaged probe model")

    @staticmethod
    def _extract_attentions(backbone_output: Any) -> Optional[Any]:
        """Return ``backbone_output.attentions`` if that attribute exists, else ``None``.

        The value is passed through untouched; its exact structure depends on
        the backbone.
        """
        return getattr(backbone_output, "attentions", None)

    @staticmethod
    def _require_non_empty_segments(
        mask_a: torch.Tensor,
        mask_b: torch.Tensor,
        source: str,
    ) -> None:
        """Raise if any row of either segment mask selects no token.

        An explicit ``raise`` is used instead of ``assert`` so the check also
        runs under ``python -O``; the exception type and messages are the
        same ones the ``assert`` statements produced.
        """
        if not torch.all(mask_a.sum(dim=1) > 0):
            raise AssertionError(f"{source} produced empty segment A")
        if not torch.all(mask_b.sum(dim=1) > 0):
            raise AssertionError(f"{source} produced empty segment B")

    def _build_ppi_segment_masks(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split every sequence of a paired input into segment masks A and B.

        When ``token_type_ids`` contains at least one ``1``, segment A is the
        attended tokens of type 0 and segment B the attended tokens of type 1.
        Otherwise the split falls back to ``config.sep_token_id``:

        * two or more separators: A runs up to and including the first
          separator, B from there up to and including the second;
        * exactly one separator: A runs up to and including it, B covers the
          rest up to the last attended token;
        * no separator: the attended tokens are split in half.

        Args:
            input_ids: Token ids, two-dimensional (batch, sequence).
            attention_mask: Mask with ``1`` at attended positions.
            token_type_ids: Optional segment ids.

        Returns:
            ``(mask_a, mask_b)`` as ``long`` tensors shaped like ``input_ids``.

        Raises:
            AssertionError: If ``sep_token_id`` is needed but unset, or if any
                row ends up with an empty segment.
        """
        attended = attention_mask == 1

        if token_type_ids is not None and torch.any(token_type_ids == 1):
            mask_a = ((token_type_ids == 0) & attended).long()
            mask_b = ((token_type_ids == 1) & attended).long()
            self._require_non_empty_segments(mask_a, mask_b, "PPI token_type_ids")
            return mask_a, mask_b

        if self.config.sep_token_id is None:
            raise AssertionError("sep_token_id is required for PPI fallback segmentation")

        batch_size, seq_len = input_ids.shape
        mask_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        mask_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        # Only separators at attended positions count as segment boundaries.
        separator_hits = (input_ids == self.config.sep_token_id) & attended

        for row in range(batch_size):
            valid_positions = torch.where(attended[row])[0]
            if valid_positions.numel() == 0:
                # Nothing to segment; the emptiness check below reports it.
                continue

            sep_positions = torch.where(separator_hits[row])[0]
            if sep_positions.numel() >= 2:
                first_sep = int(sep_positions[0].item())
                second_sep = int(sep_positions[1].item())
                mask_a[row, :first_sep + 1] = 1
                mask_b[row, first_sep + 1:second_sep + 1] = 1
            elif sep_positions.numel() == 1:
                first_sep = int(sep_positions[0].item())
                last_valid = int(valid_positions[-1].item())
                mask_a[row, :first_sep + 1] = 1
                mask_b[row, first_sep + 1:last_valid + 1] = 1
            else:
                midpoint = valid_positions.numel() // 2
                mask_a[row, valid_positions[:midpoint]] = 1
                mask_b[row, valid_positions[midpoint:]] = 1

        # The slices above start at position 0 and would otherwise include
        # padding that precedes the sequence (left-padded batches). Keeping
        # only attended positions matches the token_type_ids branch and is a
        # no-op for right-padded input.
        attended_long = attended.long()
        mask_a = mask_a * attended_long
        mask_b = mask_b * attended_long

        self._require_non_empty_segments(mask_a, mask_b, "PPI fallback segmentation")
        return mask_a, mask_b

    def _build_probe_inputs(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
        attentions: Optional[Any],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Turn backbone hidden states into ``(embeddings, attention_mask)`` for the probe.

        * PPI with pooling (neither ``matrix_embed`` nor ``tokenwise``): each
          segment is pooled separately and the two results are concatenated
          along the last dimension; no mask is returned.
        * ``matrix_embed`` or ``tokenwise``: hidden states and the attention
          mask are returned unchanged.
        * otherwise: the whole sequence is pooled; no mask is returned.
        """
        if self.config.ppi and (not self.config.matrix_embed) and (not self.config.tokenwise):
            mask_a, mask_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
            vec_a = self.pooler(hidden_states, attention_mask=mask_a, attentions=attentions)
            vec_b = self.pooler(hidden_states, attention_mask=mask_b, attentions=attentions)
            return torch.cat((vec_a, vec_b), dim=-1), None

        if self.config.matrix_embed or self.config.tokenwise:
            return hidden_states, attention_mask

        pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
        return pooled, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput | TokenClassifierOutput:
        """Run backbone and probe on a batch and return the probe's output.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Attention mask; defaults to all ones.
            token_type_ids: Optional segment ids. They are not passed to the
                backbone; they are used for PPI segmentation and, when
                ``config.add_token_ids`` is set, forwarded to a transformer
                probe.
            labels: Optional targets forwarded to the probe.

        Raises:
            AssertionError: If ``"parti"`` pooling is active but the backbone
                returned no attentions.
            ValueError: If ``config.probe_type`` is not supported.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # Attention maps are only requested when "parti" pooling will
        # actually run, i.e. when the hidden states are going to be pooled.
        requires_attentions = (
            "parti" in self.config.pooling_types
            and (not self.config.matrix_embed)
            and (not self.config.tokenwise)
        )
        backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
        if requires_attentions:
            backbone_kwargs["output_attentions"] = True
        backbone_output = self.backbone(**backbone_kwargs)

        hidden_states = self._extract_hidden_states(backbone_output)
        attentions = self._extract_attentions(backbone_output)
        # Explicit raise (not assert) so the check is kept under ``python -O``.
        if requires_attentions and attentions is None:
            raise AssertionError("parti pooling requires base model attentions")

        probe_embeddings, probe_attention_mask = self._build_probe_inputs(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            attentions=attentions,
        )

        if self.config.probe_type == "linear":
            return self.probe(embeddings=probe_embeddings, labels=labels)

        if self.config.probe_type == "transformer":
            forward_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
            if probe_attention_mask is not None:
                forward_kwargs["attention_mask"] = probe_attention_mask
            # token_type_ids are only forwarded together with unpooled inputs
            # (a mask is present exactly when hidden states were not pooled).
            if self.config.add_token_ids and token_type_ids is not None and probe_attention_mask is not None:
                forward_kwargs["token_type_ids"] = token_type_ids
            return self.probe(**forward_kwargs)

        if self.config.probe_type in ["retrievalnet", "lyra"]:
            return self.probe(embeddings=probe_embeddings, attention_mask=probe_attention_mask, labels=labels)

        raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")
|
|