| import shutil |
| import tempfile |
| import gc |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModel, BertConfig, BertModel, BertTokenizerFast |
|
|
| try: |
| from probes.linear_probe import LinearProbe, LinearProbeConfig |
| from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel |
| from probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig |
| except ImportError: |
| from ..probes.linear_probe import LinearProbe, LinearProbeConfig |
| from ..probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel |
| from ..probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig |
|
|
|
|
| def _copy_runtime_code(save_dir: Path) -> None: |
| repo_root = Path(__file__).resolve().parents[3] |
| src_package_dir = repo_root / "src" / "protify" |
| dst_package_dir = save_dir / "protify" |
| for src_file in src_package_dir.rglob("*.py"): |
| relative_path = src_file.relative_to(src_package_dir) |
| dst_file = dst_package_dir / relative_path |
| dst_file.parent.mkdir(parents=True, exist_ok=True) |
| shutil.copy2(src_file, dst_file) |
| packaged_model_file = repo_root / "src" / "protify" / "probes" / "packaged_probe_model.py" |
| shutil.copy2(packaged_model_file, save_dir / "packaged_probe_model.py") |
|
|
|
|
| def _create_tiny_backbone(backbone_dir: Path) -> tuple[BertModel, BertTokenizerFast]: |
| vocab_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "A", "B", "C", "D"] |
| vocab_path = backbone_dir / "vocab.txt" |
| vocab_path.write_text("\n".join(vocab_tokens), encoding="utf-8") |
| tokenizer = BertTokenizerFast(vocab_file=str(vocab_path), do_lower_case=False) |
| config = BertConfig( |
| vocab_size=len(vocab_tokens), |
| hidden_size=16, |
| num_hidden_layers=2, |
| num_attention_heads=2, |
| intermediate_size=32, |
| ) |
| model = BertModel(config).eval() |
| model.save_pretrained(str(backbone_dir)) |
| tokenizer.save_pretrained(str(backbone_dir)) |
| return model, tokenizer |
|
|
|
|
| def _save_and_load_with_automodel( |
| packaged_model: PackagedProbeModel, |
| tokenizer: BertTokenizerFast, |
| model_dir: Path, |
| ) -> AutoModel: |
| packaged_model.config.auto_map = { |
| "AutoConfig": "packaged_probe_model.PackagedProbeConfig", |
| "AutoModel": "packaged_probe_model.PackagedProbeModel", |
| } |
| packaged_model.config.architectures = ["PackagedProbeModel"] |
| packaged_model.save_pretrained(str(model_dir), safe_serialization=True) |
| tokenizer.save_pretrained(str(model_dir)) |
| _copy_runtime_code(model_dir) |
| return AutoModel.from_pretrained(str(model_dir), trust_remote_code=True) |
|
|
|
|
| def test_linear_packaged_roundtrip() -> None: |
| with tempfile.TemporaryDirectory(prefix="protify_linear_packaged_test_", ignore_cleanup_errors=True) as temp_dir: |
| temp_path = Path(temp_dir) |
| backbone_dir = temp_path / "backbone" |
| model_dir = temp_path / "linear_packaged_model" |
| backbone_dir.mkdir(parents=True, exist_ok=True) |
| model_dir.mkdir(parents=True, exist_ok=True) |
|
|
| backbone, tokenizer = _create_tiny_backbone(backbone_dir) |
| probe_config = LinearProbeConfig( |
| input_size=16, |
| hidden_size=32, |
| dropout=0.1, |
| num_labels=3, |
| n_layers=1, |
| task_type="singlelabel", |
| ) |
| probe = LinearProbe(probe_config).eval() |
| packaged_config = PackagedProbeConfig( |
| base_model_name=str(backbone_dir), |
| probe_type="linear", |
| probe_config=probe.config.to_dict(), |
| tokenwise=False, |
| matrix_embed=False, |
| pooling_types=["mean"], |
| task_type="singlelabel", |
| num_labels=3, |
| ppi=False, |
| add_token_ids=False, |
| sep_token_id=tokenizer.sep_token_id, |
| ) |
| packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval() |
| loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir) |
|
|
| batch = tokenizer(["A B C A", "B C D A"], padding="longest", return_tensors="pt") |
| outputs = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) |
| assert outputs.logits.shape == (2, 3), f"Unexpected linear packaged logits shape: {outputs.logits.shape}" |
| del loaded_model |
| gc.collect() |
|
|
|
|
| def test_transformer_packaged_roundtrip() -> None: |
| with tempfile.TemporaryDirectory(prefix="protify_transformer_packaged_test_", ignore_cleanup_errors=True) as temp_dir: |
| temp_path = Path(temp_dir) |
| backbone_dir = temp_path / "backbone" |
| model_dir = temp_path / "transformer_packaged_model" |
| backbone_dir.mkdir(parents=True, exist_ok=True) |
| model_dir.mkdir(parents=True, exist_ok=True) |
|
|
| backbone, tokenizer = _create_tiny_backbone(backbone_dir) |
| probe_config = TransformerProbeConfig( |
| input_size=16, |
| hidden_size=16, |
| classifier_size=24, |
| transformer_dropout=0.1, |
| classifier_dropout=0.1, |
| num_labels=2, |
| n_layers=1, |
| token_attention=False, |
| n_heads=2, |
| task_type="singlelabel", |
| rotary=False, |
| pre_ln=True, |
| probe_pooling_types=["mean"], |
| use_bias=False, |
| add_token_ids=False, |
| ) |
| probe = TransformerForSequenceClassification(probe_config).eval() |
| packaged_config = PackagedProbeConfig( |
| base_model_name=str(backbone_dir), |
| probe_type="transformer", |
| probe_config=probe.config.to_dict(), |
| tokenwise=False, |
| matrix_embed=True, |
| pooling_types=["mean"], |
| task_type="singlelabel", |
| num_labels=2, |
| ppi=False, |
| add_token_ids=False, |
| sep_token_id=tokenizer.sep_token_id, |
| ) |
| packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval() |
| loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir) |
|
|
| batch = tokenizer(["A B C D", "D C B A"], padding="longest", return_tensors="pt") |
| outputs = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) |
| assert outputs.logits.shape == (2, 2), f"Unexpected transformer packaged logits shape: {outputs.logits.shape}" |
| del loaded_model |
| gc.collect() |
|
|
|
|
| def test_ppi_packaged_inference_with_and_without_token_type_ids() -> None: |
| with tempfile.TemporaryDirectory(prefix="protify_ppi_packaged_test_", ignore_cleanup_errors=True) as temp_dir: |
| temp_path = Path(temp_dir) |
| backbone_dir = temp_path / "backbone" |
| model_dir = temp_path / "ppi_packaged_model" |
| backbone_dir.mkdir(parents=True, exist_ok=True) |
| model_dir.mkdir(parents=True, exist_ok=True) |
|
|
| backbone, tokenizer = _create_tiny_backbone(backbone_dir) |
| probe_config = LinearProbeConfig( |
| input_size=32, |
| hidden_size=24, |
| dropout=0.1, |
| num_labels=2, |
| n_layers=1, |
| task_type="singlelabel", |
| ) |
| probe = LinearProbe(probe_config).eval() |
| packaged_config = PackagedProbeConfig( |
| base_model_name=str(backbone_dir), |
| probe_type="linear", |
| probe_config=probe.config.to_dict(), |
| tokenwise=False, |
| matrix_embed=False, |
| pooling_types=["mean"], |
| task_type="singlelabel", |
| num_labels=2, |
| ppi=True, |
| add_token_ids=False, |
| sep_token_id=tokenizer.sep_token_id, |
| ) |
| packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval() |
| loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir) |
|
|
| pair_batch = tokenizer( |
| ["A B C", "B C D"], |
| ["D C B", "A C B"], |
| padding="longest", |
| return_tensors="pt", |
| ) |
|
|
| outputs_with_token_types = loaded_model( |
| input_ids=pair_batch["input_ids"], |
| attention_mask=pair_batch["attention_mask"], |
| token_type_ids=pair_batch["token_type_ids"], |
| ) |
| assert outputs_with_token_types.logits.shape == (2, 2), "PPI logits shape mismatch with token_type_ids" |
|
|
| outputs_without_token_types = loaded_model( |
| input_ids=pair_batch["input_ids"], |
| attention_mask=pair_batch["attention_mask"], |
| ) |
| assert outputs_without_token_types.logits.shape == (2, 2), "PPI logits shape mismatch without token_type_ids" |
| del loaded_model |
| gc.collect() |
|
|
|
|
| def main() -> None: |
| torch.manual_seed(0) |
| test_linear_packaged_roundtrip() |
| test_transformer_packaged_roundtrip() |
| test_ppi_packaged_inference_with_and_without_token_type_ids() |
| print("Packaged probe model smoke tests passed.") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|