| from __future__ import annotations |
|
|
| import re |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
|
|
| from .encoders import encoder_storage_key, normalize_encoder_name |
| from .model import IMRNN, ModelConfig |
|
|
|
|
def default_checkpoint_name(encoder: str, dataset: str) -> str:
    """Return the default checkpoint file name for an encoder/dataset pair.

    Args:
        encoder: Encoder name; it is passed through ``encoder_storage_key``
            before being embedded in the file name.
        dataset: Dataset identifier, embedded in the file name verbatim.

    Returns:
        A file name of the form ``imrnns-<storage key>-<dataset>.pt``.
    """
    storage_key = encoder_storage_key(encoder)
    return f"imrnns-{storage_key}-{dataset}.pt"
|
|
|
|
def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *state_dict* with legacy keys dropped or renamed.

    Entries whose key starts with ``e5_model.`` or ``sbert.`` are omitted.
    A leading ``e5_projector.`` or ``sbert_projector.`` is rewritten to
    ``projector.``. All other keys are carried over unchanged. The input
    mapping is not modified.

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        A new dict in the input's iteration order. If two input keys map to
        the same output key, the later value wins.
    """
    dropped_prefixes = ("e5_model.", "sbert.")
    legacy_projector_prefixes = ("e5_projector.", "sbert_projector.")

    def rename(name: str) -> str:
        # Only a prefix at the very start of the key is rewritten.
        for legacy_prefix in legacy_projector_prefixes:
            if name.startswith(legacy_prefix):
                return "projector." + name[len(legacy_prefix):]
        return name

    return {
        rename(name): tensor
        for name, tensor in state_dict.items()
        if not name.startswith(dropped_prefixes)
    }
|
|
|
|
def save_checkpoint(
    path: Path,
    model: IMRNN,
    metadata: dict[str, Any],
) -> None:
    """Write the model's state dict and metadata to *path*.

    The saved payload is a dict with two keys: ``"model_state"`` (the result
    of ``model.state_dict()``) and ``"metadata"``.

    Args:
        path: Destination file. Missing parent directories are created.
        model: Model whose ``state_dict()`` is stored.
        metadata: Extra entries merged into the stored metadata. They are
            applied after the default ``"checkpoint_format"`` entry, so a
            caller-supplied ``"checkpoint_format"`` replaces the default.
    """
    weights = model.state_dict()
    stored_metadata = {
        "checkpoint_format": "imrnns-adapter-only-v1",
        **metadata,
    }
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"model_state": weights, "metadata": stored_metadata}, path)
|
|
|
|
def load_checkpoint(path: Path) -> tuple[dict[str, Any], dict[str, Any]]:
    """Load a checkpoint file and return ``(state_dict, metadata)``.

    Two on-disk layouts are accepted:

    * a dict containing a ``"model_state"`` key, optionally with
      ``"metadata"`` alongside it;
    * any other dict, which is treated as a bare state dict with no
      metadata.

    In both cases the state dict is passed through
    ``sanitize_legacy_state_dict`` before being returned.

    Args:
        path: Checkpoint file to read.

    Returns:
        The sanitized state dict and the stored metadata (``{}`` when the
        file carries none).

    Raises:
        TypeError: If the loaded object is not a dict.
    """
    # Tensors are mapped to CPU at load time; weights_only=True is passed to
    # torch.load to restrict what the unpickler may construct.
    loaded = torch.load(path, map_location="cpu", weights_only=True)
    if not isinstance(loaded, dict):
        raise TypeError(f"Unsupported checkpoint format in {path}")
    if "model_state" not in loaded:
        # Bare state dict: nothing to report as metadata.
        return sanitize_legacy_state_dict(loaded), {}
    weights = sanitize_legacy_state_dict(loaded["model_state"])
    return weights, loaded.get("metadata", {})
|
|
|
|
def load_model(
    checkpoint_path: Path,
    model_config: ModelConfig,
    device: str,
) -> tuple[IMRNN, dict[str, Any], list[str], list[str]]:
    """Build an ``IMRNN`` from *model_config* and load checkpoint weights.

    Args:
        checkpoint_path: Checkpoint file, read via ``load_checkpoint``.
        model_config: Configuration passed to the ``IMRNN`` constructor.
        device: Device the model is moved to after loading.

    Returns:
        A 4-tuple ``(model, metadata, missing, unexpected)`` where
        ``metadata`` comes from the checkpoint and ``missing`` /
        ``unexpected`` are the two values returned by
        ``load_state_dict(..., strict=False)``.
    """
    weights, metadata = load_checkpoint(checkpoint_path)
    network = IMRNN(model_config)
    # strict=False: key mismatches are handed back to the caller instead of
    # raising (assumes IMRNN follows torch.nn.Module semantics — confirm).
    missing_keys, unexpected_keys = network.load_state_dict(weights, strict=False)
    network.to(device)
    network.eval()
    return network, metadata, missing_keys, unexpected_keys
|
|
|
|
def convert_legacy_checkpoint(
    source_path: Path,
    target_path: Path,
    metadata: dict[str, Any],
) -> None:
    """Re-save the checkpoint at *source_path* in the current payload layout.

    The source is read with ``load_checkpoint`` and written to *target_path*
    as a dict with ``"model_state"`` and ``"metadata"`` keys.

    Args:
        source_path: Existing checkpoint to read.
        target_path: Destination file. Missing parent directories are
            created.
        metadata: Entries merged into the stored metadata. Precedence, lowest
            to highest: the default ``"checkpoint_format"`` entry, metadata
            already present in the source, then *metadata*.
    """
    weights, source_metadata = load_checkpoint(source_path)
    merged_metadata = {
        "checkpoint_format": "imrnns-adapter-only-v1",
        **source_metadata,
        **metadata,
    }
    target_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({"model_state": weights, "metadata": merged_metadata}, target_path)
|
|