# IMRNNs/src/imrnns/checkpoints.py
# Uploaded by yashsaxena21: "Upload folder using huggingface_hub"
# Commit 14e3943 (verified)
from __future__ import annotations
import re
from pathlib import Path
from typing import Any
import torch
from .encoders import encoder_storage_key, normalize_encoder_name
from .model import IMRNN, ModelConfig
def default_checkpoint_name(encoder: str, dataset: str) -> str:
    """Return the default checkpoint filename for an encoder/dataset pair.

    The encoder name is first converted with ``encoder_storage_key`` and the
    result has the form ``imrnns-<storage key>-<dataset>.pt``.
    """
    storage_key = encoder_storage_key(encoder)
    return f"imrnns-{storage_key}-{dataset}.pt"
def sanitize_legacy_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *state_dict* with legacy key names normalized.

    Entries whose key begins with ``e5_model.`` or ``sbert.`` are dropped
    (presumably the backbone encoder weights -- confirm against the model
    definition). A leading ``e5_projector.`` or ``sbert_projector.`` is
    rewritten to ``projector.``; every other key is kept unchanged. The input
    mapping is not modified.
    """
    skipped_prefixes = ("e5_model.", "sbert.")
    return {
        # Anchored at the start so only a leading legacy prefix is renamed.
        re.sub(r"^(e5_projector|sbert_projector)\.", "projector.", name): tensor
        for name, tensor in state_dict.items()
        if not name.startswith(skipped_prefixes)
    }
def save_checkpoint(
    path: Path,
    model: IMRNN,
    metadata: dict[str, Any],
) -> None:
    """Write the model's state dict and metadata to *path* with ``torch.save``.

    The saved payload is a dict with two keys: ``"model_state"`` (the result
    of ``model.state_dict()``) and ``"metadata"``. The metadata starts with a
    ``"checkpoint_format"`` tag and is then updated with *metadata*, so a
    caller-supplied ``"checkpoint_format"`` replaces the default tag.

    Missing parent directories of *path* are created before saving.
    """
    merged_metadata: dict[str, Any] = {"checkpoint_format": "imrnns-adapter-only-v1"}
    merged_metadata.update(metadata)

    checkpoint = {
        "model_state": model.state_dict(),
        "metadata": merged_metadata,
    }

    out_dir = path.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, path)
def load_checkpoint(path: Path) -> tuple[dict[str, Any], dict[str, Any]]:
    """Load a checkpoint file and return ``(state_dict, metadata)``.

    The file is read with ``torch.load`` onto the CPU with
    ``weights_only=True``. Two dict layouts are accepted:

    * a dict containing ``"model_state"``: that entry is the state dict and
      ``"metadata"`` (default ``{}``) is the metadata;
    * any other dict: the whole dict is treated as the state dict and the
      metadata is empty.

    In both cases the state dict is passed through
    ``sanitize_legacy_state_dict`` before being returned.

    Raises:
        TypeError: if the loaded object is not a dict.
    """
    loaded = torch.load(path, map_location="cpu", weights_only=True)
    if not isinstance(loaded, dict):
        raise TypeError(f"Unsupported checkpoint format in {path}")

    if "model_state" in loaded:
        raw_state = loaded["model_state"]
        meta = loaded.get("metadata", {})
    else:
        # Bare state dict saved without the payload wrapper.
        raw_state = loaded
        meta = {}
    return sanitize_legacy_state_dict(raw_state), meta
def load_model(
    checkpoint_path: Path,
    model_config: ModelConfig,
    device: str,
) -> tuple[IMRNN, dict[str, Any], list[str], list[str]]:
    """Build an ``IMRNN`` from *model_config* and load checkpoint weights into it.

    The state dict comes from ``load_checkpoint`` and is applied with
    ``strict=False``, so key mismatches are returned to the caller rather
    than raised. The model is then moved to *device* and put in eval mode.

    Returns:
        ``(model, metadata, missing, unexpected)`` where the last two are the
        values returned by ``load_state_dict``.
    """
    weights, checkpoint_metadata = load_checkpoint(checkpoint_path)

    network = IMRNN(model_config)
    missing_keys, unexpected_keys = network.load_state_dict(weights, strict=False)

    network.to(device)
    network.eval()
    return network, checkpoint_metadata, missing_keys, unexpected_keys
def convert_legacy_checkpoint(
    source_path: Path,
    target_path: Path,
    metadata: dict[str, Any],
) -> None:
    """Re-save the checkpoint at *source_path* in the wrapped payload layout.

    The source is read with ``load_checkpoint`` and written to *target_path*
    as a dict with ``"model_state"`` and ``"metadata"`` keys. Metadata is
    merged in increasing precedence: the default ``"checkpoint_format"`` tag,
    then metadata already stored in the source, then the *metadata* argument.

    Missing parent directories of *target_path* are created before saving.
    """
    weights, stored_metadata = load_checkpoint(source_path)

    # Later updates win, giving the caller's metadata the final say.
    combined_metadata: dict[str, Any] = {"checkpoint_format": "imrnns-adapter-only-v1"}
    combined_metadata.update(stored_metadata)
    combined_metadata.update(metadata)

    converted = {
        "model_state": weights,
        "metadata": combined_metadata,
    }

    destination_dir = target_path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    torch.save(converted, target_path)