from __future__ import annotations

import json
from pathlib import Path

import numpy as np

from bio_llm.training.trainer import load_checkpoint

def export_checkpoint_to_npz(
    checkpoint_path: str | Path,
    output_path: str | Path,
    tokenizer_path: str | Path | None = None,
    max_seq_len: int | None = None,
) -> Path:
    """Export a checkpoint's weights, config and tokenizer to one ``.npz`` file.

    The model and tokenizer are obtained from ``load_checkpoint``. Every entry
    of ``model.state_dict()`` is stored under its name with ``"."`` replaced by
    ``"__"`` (dots are awkward in archive member names). Three metadata entries
    are stored alongside the weights:

    * ``__parameter_names__`` -- the original (dotted) parameter names, in
      ``state_dict`` order, as a unicode string array.
    * ``__config_json__`` -- ``model.config.to_dict()`` serialized as JSON.
    * ``__tokenizer_json__`` -- JSON with the tokenizer ``type`` (``"bpe"``
      when the tokenizer has a ``merges`` attribute, otherwise ``"simple"``),
      its ``vocab`` (``tokenizer.id_to_token``) and its ``merges``.

    Args:
        checkpoint_path: Checkpoint location, forwarded to ``load_checkpoint``.
        output_path: Destination of the archive. If it does not end in
            ``.npz`` the suffix is appended, matching what
            ``np.savez_compressed`` does with path arguments.
        tokenizer_path: Optional tokenizer location, forwarded to
            ``load_checkpoint``.
        max_seq_len: When given, passed to ``load_checkpoint`` as the
            ``max_seq_len`` config override.

    Returns:
        The path of the archive that was actually written.

    Raises:
        ValueError: If two parameters map to the same archive key, or a
            parameter's key clashes with one of the metadata keys.
    """
    config_overrides = {"max_seq_len": max_seq_len} if max_seq_len is not None else None
    model, tokenizer = load_checkpoint(
        checkpoint_path=checkpoint_path,
        tokenizer_path=tokenizer_path,
        config_overrides=config_overrides,
    )

    reserved_keys = {"__parameter_names__", "__config_json__", "__tokenizer_json__"}
    arrays: dict[str, np.ndarray] = {}
    parameter_names: list[str] = []
    for name, tensor in model.state_dict().items():
        safe_name = name.replace(".", "__")
        # The "." -> "__" mapping is not injective; refuse to silently
        # overwrite one weight (or a metadata entry) with another.
        if safe_name in arrays or safe_name in reserved_keys:
            raise ValueError(
                f"Parameter {name!r} maps to archive key {safe_name!r}, "
                "which is already in use"
            )
        arrays[safe_name] = tensor.detach().cpu().numpy()
        parameter_names.append(name)

    # A unicode array (rather than dtype=object) is stored without pickling,
    # so the archive can be read with np.load's default allow_pickle=False.
    arrays["__parameter_names__"] = np.array(parameter_names, dtype=str)
    arrays["__config_json__"] = np.array(
        json.dumps(model.config.to_dict(), ensure_ascii=True)
    )
    arrays["__tokenizer_json__"] = np.array(
        json.dumps(
            {
                "type": "bpe" if hasattr(tokenizer, "merges") else "simple",
                "vocab": tokenizer.id_to_token,
                "merges": [list(pair) for pair in getattr(tokenizer, "merges", [])],
            },
            ensure_ascii=True,
        )
    )

    resolved_output = Path(output_path)
    # np.savez_compressed appends ".npz" to path arguments lacking it; apply
    # the same rule here so the returned path is the file that really exists.
    if not str(resolved_output).endswith(".npz"):
        resolved_output = Path(f"{resolved_output}.npz")
    resolved_output.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(resolved_output, **arrays)
    return resolved_output