import shutil
import tempfile
from pathlib import Path
from typing import Optional
from huggingface_hub import HfApi
from torch import nn

try:
    from base_models.supported_models import all_presets_with_paths
    from probes.hybrid_probe import HybridProbe
    from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from utils import print_message
except ImportError:
    from ..base_models.supported_models import all_presets_with_paths
    from .hybrid_probe import HybridProbe
    from .packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from ..utils import print_message


def _infer_probe_type(probe_model: nn.Module) -> str:
    """Map the trained probe's class name to the probe type string used by the packaged config."""
    probe_class_name = probe_model.__class__.__name__
    if probe_class_name == "LinearProbe":
        return "linear"
    if probe_class_name in ["TransformerForSequenceClassification", "TransformerForTokenClassification"]:
        return "transformer"
    if probe_class_name in ["RetrievalNetForSequenceClassification", "RetrievalNetForTokenClassification"]:
        return "retrievalnet"
    if probe_class_name in ["LyraForSequenceClassification", "LyraForTokenClassification"]:
        return "lyra"
    raise ValueError(f"Unsupported probe class for packaged export: {probe_class_name}")


def _is_supported_base_model(source_model_name: str) -> bool:
    """Return True if the base model preset is eligible for packaged export (a known preset, not random/onehot/vec2vec)."""
    if source_model_name not in all_presets_with_paths:
        return False
    model_name_l = source_model_name.lower()
    if "random" in model_name_l:
        return False
    if "onehot" in model_name_l:
        return False
    if "vec2vec" in model_name_l:
        return False
    return True


def _extract_sep_token_id(tokenizer) -> Optional[int]:
    """Return the tokenizer's SEP token id, falling back to the EOS token id, or None if neither is set."""
    try:
        # Unwrap a wrapped tokenizer if one is present.
        tokenizer_backend = tokenizer.tokenizer
    except AttributeError:
        tokenizer_backend = tokenizer
    if tokenizer_backend.sep_token_id is not None:
        return int(tokenizer_backend.sep_token_id)
    if tokenizer_backend.eos_token_id is not None:
        return int(tokenizer_backend.eos_token_id)
    return None


def _copy_runtime_code(export_dir: Path) -> None:
    """Copy the protify source package and the packaged-model module into the export directory."""
    repo_root = Path(__file__).resolve().parents[3]
    src_package_dir = repo_root / "src" / "protify"
    dst_package_dir = export_dir / "protify"
    for src_file in src_package_dir.rglob("*.py"):
        relative_path = src_file.relative_to(src_package_dir)
        dst_file = dst_package_dir / relative_path
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)
    packaged_model_file = Path(__file__).with_name("packaged_probe_model.py")
    shutil.copy2(packaged_model_file, export_dir / "packaged_probe_model.py")
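# Note: the export directory therefore ends up with packaged_probe_model.py at its
# top level (the module referenced by auto_map below) alongside a copy of the
# protify package, so the uploaded repo carries the code it needs for
# trust_remote_code loading.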


def _build_packaged_model(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    ppi: bool,
) -> PackagedProbeModel:
    """Wrap a trained probe (and its base model, when the probe is hybrid) into a PackagedProbeModel on CPU."""
    if isinstance(trained_model, HybridProbe):
        base_model = trained_model.model
        probe_model = trained_model.probe
    else:
        base_model = None
        probe_model = trained_model
    probe_type = _infer_probe_type(probe_model)
    probe_config_dict = probe_model.config.to_dict()
    sep_token_id = _extract_sep_token_id(tokenizer)
    packaged_config = PackagedProbeConfig(
        base_model_name=source_model_name,
        probe_type=probe_type,
        probe_config=probe_config_dict,
        tokenwise=probe_args.tokenwise,
        matrix_embed=embedding_args.matrix_embed,
        pooling_types=embedding_args.pooling_types,
        task_type=probe_args.task_type,
        num_labels=probe_args.num_labels,
        ppi=ppi,
        add_token_ids=probe_args.add_token_ids,
        sep_token_id=sep_token_id,
    )
    packaged_model = PackagedProbeModel(config=packaged_config, base_model=base_model, probe=probe_model)
    return packaged_model.cpu()


def export_packaged_model_to_hub(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    repo_id: str,
    model_card: str,
    ppi: bool = False,
    private: bool = True,
    hf_token: Optional[str] = None,
) -> tuple[bool, str]:
    """Package a trained probe with its tokenizer, model card, and runtime code, then upload it to the Hugging Face Hub.

    Returns a (success, message) tuple.
    """
    if not _is_supported_base_model(source_model_name):
        return False, f"Packaged export is not supported for base model: {source_model_name}"
    packaged_model = _build_packaged_model(
        trained_model=trained_model,
        source_model_name=source_model_name,
        probe_args=probe_args,
        embedding_args=embedding_args,
        tokenizer=tokenizer,
        ppi=ppi,
    )
    with tempfile.TemporaryDirectory(prefix="protify_packaged_model_") as temp_dir:
        export_dir = Path(temp_dir)
        # Register the packaged classes so AutoConfig/AutoModel can resolve them via trust_remote_code.
        packaged_model.config.auto_map = {
            "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
            "AutoModel": "packaged_probe_model.PackagedProbeModel",
        }
        packaged_model.config.architectures = ["PackagedProbeModel"]
        packaged_model.save_pretrained(str(export_dir), safe_serialization=True)
        tokenizer.save_pretrained(str(export_dir))
        _copy_runtime_code(export_dir)
        readme_path = export_dir / "README.md"
        readme_path.write_text(model_card, encoding="utf-8")
        if hf_token is None:
            api = HfApi()
        else:
            api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
        api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=str(export_dir),
            path_in_repo="",
        )
    print_message(f"Packaged model and tokenizer uploaded to Hugging Face Hub: {repo_id}")
    return True, f"Uploaded packaged model to {repo_id}"
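

# Illustrative usage sketch (not executed by this module). The trained_model,
# probe_args, embedding_args, and tokenizer objects are assumed to come from the
# training pipeline, and "username/my-protify-probe" is a hypothetical repo id:
#
#     ok, message = export_packaged_model_to_hub(
#         trained_model=trained_model,
#         source_model_name=source_model_name,  # must be a key in all_presets_with_paths
#         probe_args=probe_args,
#         embedding_args=embedding_args,
#         tokenizer=tokenizer,
#         repo_id="username/my-protify-probe",
#         model_card="# My Protify probe",
#         private=True,
#     )
#
# Because the exported config sets auto_map and the runtime code is copied into the
# repo, the uploaded model should then be loadable with custom code enabled, e.g.:
#
#     from transformers import AutoModel, AutoTokenizer
#     model = AutoModel.from_pretrained("username/my-protify-probe", trust_remote_code=True)
#     tokenizer = AutoTokenizer.from_pretrained("username/my-protify-probe")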