File size: 1,564 Bytes
424c56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

import json
from pathlib import Path

import numpy as np

from bio_llm.training.trainer import load_checkpoint


def export_checkpoint_to_npz(
    checkpoint_path: str | Path,
    output_path: str | Path,
    tokenizer_path: str | Path | None = None,
    max_seq_len: int | None = None,
) -> Path:
    """Export a checkpoint's weights, config and tokenizer to one ``.npz`` file.

    The model and tokenizer are obtained from ``load_checkpoint``. Every entry
    of ``model.state_dict()`` is stored as a NumPy array under its name with
    ``"."`` replaced by ``"__"``. Three metadata entries are added:

    * ``__parameter_names__`` -- the original (dotted) state-dict names, in
      state-dict order.
    * ``__config_json__`` -- ``model.config.to_dict()`` serialized as JSON.
    * ``__tokenizer_json__`` -- a JSON object with ``type`` (``"bpe"`` when the
      tokenizer has a ``merges`` attribute, otherwise ``"simple"``), ``vocab``
      (``tokenizer.id_to_token``) and ``merges`` (a list of lists, empty when
      the tokenizer has no ``merges``).

    Args:
        checkpoint_path: Checkpoint location, forwarded to ``load_checkpoint``.
        output_path: Destination of the archive. NumPy appends ``.npz`` when
            the path does not already end with it.
        tokenizer_path: Optional tokenizer location, forwarded to
            ``load_checkpoint``.
        max_seq_len: When not ``None``, passed to ``load_checkpoint`` as the
            config override ``{"max_seq_len": max_seq_len}``.

    Returns:
        The path of the file that was actually written, i.e. ``output_path``
        with ``.npz`` appended if it was missing.

    Raises:
        ValueError: If two state-dict names map to the same archive key, or a
            mapped name clashes with one of the reserved metadata keys.
    """
    config_overrides = {"max_seq_len": max_seq_len} if max_seq_len is not None else None
    model, tokenizer = load_checkpoint(
        checkpoint_path=checkpoint_path,
        tokenizer_path=tokenizer_path,
        config_overrides=config_overrides,
    )

    state = model.state_dict()
    parameter_names: list[str] = list(state)
    # Archive keys use "__" instead of "." as the separator.
    arrays: dict[str, np.ndarray] = {
        name.replace(".", "__"): tensor.detach().cpu().numpy()
        for name, tensor in state.items()
    }

    reserved_keys = ("__parameter_names__", "__config_json__", "__tokenizer_json__")
    # Two names such as "a.b" and "a__b" would share a key and one tensor would
    # be dropped silently; fail loudly rather than write an incomplete archive.
    if len(arrays) != len(parameter_names) or any(key in arrays for key in reserved_keys):
        raise ValueError(
            "state dict parameter names collide after replacing '.' with '__'"
        )

    # NOTE(review): an object-dtype array is stored via pickle, so readers must
    # call np.load(..., allow_pickle=True) to access this entry. Kept as-is to
    # leave the archive format unchanged.
    arrays["__parameter_names__"] = np.array(parameter_names, dtype=object)
    arrays["__config_json__"] = np.array(json.dumps(model.config.to_dict(), ensure_ascii=True))
    arrays["__tokenizer_json__"] = np.array(
        json.dumps(
            {
                "type": "bpe" if hasattr(tokenizer, "merges") else "simple",
                "vocab": tokenizer.id_to_token,
                "merges": [list(pair) for pair in getattr(tokenizer, "merges", [])],
            },
            ensure_ascii=True,
        )
    )

    resolved_output = Path(output_path)
    # np.savez_compressed appends ".npz" to filenames that lack it. Apply the
    # same rule here so the returned path names the file that really exists.
    if not str(resolved_output).endswith(".npz"):
        resolved_output = Path(f"{resolved_output}.npz")
    resolved_output.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(resolved_output, **arrays)
    return resolved_output