File size: 4,512 Bytes
0506a57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python
"""
Export Aeon PyTorch Lightning checkpoint to pickle format for inference.

This script converts a PyTorch Lightning checkpoint (.ckpt) file to a pickle
(.pkl) file that can be used with the Mosaic inference pipeline.

Usage:
    python export_aeon_checkpoint.py \
        --checkpoint data/checkpoint.ckpt \
        --output data/aeon_model.pkl \
        --metadata-dir data/metadata

Requirements:
    - paladin package from git repo (must have AeonLightningModule)
    - PyTorch Lightning
    - Access to metadata files (n_classes.txt, ontology_embedding_dim.txt, target_dict.tsv)
"""

import argparse
import ast
import json
import pickle
from pathlib import Path


def load_metadata(metadata_dir: Path):
    """Load metadata required for model initialization.

    Args:
        metadata_dir: Directory containing n_classes.txt,
            ontology_embedding_dim.txt, and target_dict.tsv

    Returns:
        SimpleMetadata object with n_classes, ontology_embedding_dim,
        and target_dicts (a one-element list wrapping the parsed dict)

    Raises:
        FileNotFoundError: If any metadata file is missing.
        ValueError: If a metadata file cannot be parsed.
    """
    # Read n_classes (file holds a single integer)
    with open(metadata_dir / "n_classes.txt") as f:
        n_classes = int(f.read().strip())

    # Read ontology_embedding_dim (file holds a single integer)
    with open(metadata_dir / "ontology_embedding_dim.txt") as f:
        ontology_embedding_dim = int(f.read().strip())

    # Read target_dict. The file stores a Python-dict repr (single-quoted).
    # Parse with ast.literal_eval instead of replace("'", '"') + json.loads:
    # a blanket quote replacement corrupts any key/value that itself
    # contains an apostrophe (e.g. "Ewing's sarcoma").
    with open(metadata_dir / "target_dict.tsv") as f:
        target_dict = ast.literal_eval(f.read().strip())

    # Lightweight container mirroring the attributes the Lightning module
    # expects on its `metadata` argument.
    class SimpleMetadata:
        def __init__(self, n_classes, ontology_embedding_dim, target_dict):
            self.n_classes = n_classes
            self.ontology_embedding_dim = ontology_embedding_dim
            self.target_dicts = [target_dict]

    return SimpleMetadata(n_classes, ontology_embedding_dim, target_dict)


def export_checkpoint(checkpoint_path: Path, output_path: Path, metadata_dir: Path):
    """Export PyTorch Lightning checkpoint to pickle format.

    Loads metadata, restores the Lightning module from the checkpoint,
    and pickles only the wrapped model for the inference pipeline.

    Args:
        checkpoint_path: Path to .ckpt file
        output_path: Path to save .pkl file
        metadata_dir: Directory containing metadata files

    Raises:
        ImportError: If the optional paladin package is not installed.
    """
    try:
        # Deferred import so the CLI can emit an actionable message when
        # the optional paladin dependency is missing.
        from paladin.pl_modules.aeon import AeonLightningModule
    except ImportError as err:
        # Chain the original exception so the underlying import failure
        # (e.g. a missing transitive dependency) remains visible.
        raise ImportError(
            "Failed to import AeonLightningModule. "
            "Make sure paladin is installed from the git repository:\n"
            "  uv sync --upgrade-package paladin"
        ) from err

    print(f"Loading metadata from {metadata_dir}...")
    metadata = load_metadata(metadata_dir)

    print(f"Loading checkpoint from {checkpoint_path}...")
    pl_module = AeonLightningModule.load_from_checkpoint(
        str(checkpoint_path),
        metadata=metadata
    )

    # The Lightning module wraps the bare model; only the model is exported.
    model = pl_module.model

    print(f"Saving model to {output_path}...")
    with open(output_path, "wb") as f:
        pickle.dump(model, f)

    print(f"✓ Successfully exported checkpoint to {output_path}")

    # Print model info
    file_size = output_path.stat().st_size / (1024 * 1024)  # MB
    print(f"  Model size: {file_size:.1f} MB")
    print(f"  Model class: {type(model).__name__}")
    print(f"  Number of classes: {metadata.n_classes}")
    print(f"  Ontology embedding dim: {metadata.ontology_embedding_dim}")
    print(f"  Number of histologies: {len(metadata.target_dicts[0]['histologies'])}")


def main():
    """Parse CLI arguments, validate input paths, and run the export."""
    arg_parser = argparse.ArgumentParser(
        description="Export Aeon PyTorch Lightning checkpoint to pickle format"
    )
    arg_parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to PyTorch Lightning checkpoint (.ckpt)",
    )
    arg_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Path to save exported model (.pkl)",
    )
    arg_parser.add_argument(
        "--metadata-dir",
        type=Path,
        default=Path("data/metadata"),
        help="Directory containing metadata files (default: data/metadata)",
    )
    opts = arg_parser.parse_args()

    # Fail fast on missing inputs before any heavy loading happens.
    if not opts.checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {opts.checkpoint}")
    if not opts.metadata_dir.exists():
        raise FileNotFoundError(f"Metadata directory not found: {opts.metadata_dir}")

    # Make sure the destination directory is in place for the .pkl file.
    opts.output.parent.mkdir(parents=True, exist_ok=True)

    export_checkpoint(opts.checkpoint, opts.output, opts.metadata_dir)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()