# Source: mosaic-zero/scripts/export_aeon_checkpoint.py
# Provenance (from the original file page, kept as a comment so the script stays runnable):
#   author: raylim — commit 0506a57 (unverified):
#   "Add Aeon model test suite and reproducibility scripts"
#!/usr/bin/env python
"""
Export Aeon PyTorch Lightning checkpoint to pickle format for inference.
This script converts a PyTorch Lightning checkpoint (.ckpt) file to a pickle
(.pkl) file that can be used with the Mosaic inference pipeline.
Usage:
python export_aeon_checkpoint.py \
--checkpoint data/checkpoint.ckpt \
--output data/aeon_model.pkl \
--metadata-dir data/metadata
Requirements:
- paladin package from git repo (must have AeonLightningModule)
- PyTorch Lightning
- Access to metadata files (n_classes.txt, ontology_embedding_dim.txt, target_dict.tsv)
"""
import argparse
import ast
import json
import pickle
from pathlib import Path
def load_metadata(metadata_dir: Path):
    """Load metadata required for model initialization.

    Args:
        metadata_dir: Directory containing the metadata files
            ``n_classes.txt``, ``ontology_embedding_dim.txt`` and
            ``target_dict.tsv``.

    Returns:
        SimpleMetadata object with ``n_classes``, ``ontology_embedding_dim``,
        and ``target_dicts`` (a one-element list wrapping the parsed dict —
        the shape the Lightning module expects).

    Raises:
        FileNotFoundError: If any metadata file is missing.
        ValueError: If a numeric file does not hold an integer, or the
            target dict cannot be parsed as a Python/JSON literal.
    """
    # Read n_classes
    n_classes = int((metadata_dir / "n_classes.txt").read_text().strip())

    # Read ontology_embedding_dim
    ontology_embedding_dim = int(
        (metadata_dir / "ontology_embedding_dim.txt").read_text().strip()
    )

    # target_dict.tsv actually holds a Python-literal dict (single-quoted
    # strings). ast.literal_eval parses that directly and — unlike the
    # previous replace-single-quotes-then-json.loads hack — does not corrupt
    # values that contain apostrophes, and also accepts True/False/None.
    # (Plain double-quoted JSON dicts of strings/numbers parse fine too.)
    target_dict = ast.literal_eval(
        (metadata_dir / "target_dict.tsv").read_text().strip()
    )

    # Minimal stand-in for the paladin metadata object the checkpoint loader
    # expects: three attributes, nothing else.
    class SimpleMetadata:
        def __init__(self, n_classes, ontology_embedding_dim, target_dict):
            self.n_classes = n_classes
            self.ontology_embedding_dim = ontology_embedding_dim
            # Downstream code indexes target_dicts[0], hence the list wrapper.
            self.target_dicts = [target_dict]

    return SimpleMetadata(n_classes, ontology_embedding_dim, target_dict)
def export_checkpoint(checkpoint_path: Path, output_path: Path, metadata_dir: Path):
    """Export PyTorch Lightning checkpoint to pickle format.

    Loads the Aeon Lightning module from ``checkpoint_path`` (re-attaching
    the metadata read from ``metadata_dir``), extracts the underlying model,
    and pickles it to ``output_path``.

    Args:
        checkpoint_path: Path to .ckpt file
        output_path: Path to save .pkl file
        metadata_dir: Directory containing metadata files

    Raises:
        ImportError: If the project-only ``paladin`` package is unavailable.
    """
    # Fail fast with an actionable message if paladin is not installed.
    try:
        from paladin.pl_modules.aeon import AeonLightningModule
    except ImportError:
        raise ImportError(
            "Failed to import AeonLightningModule. "
            "Make sure paladin is installed from the git repository:\n"
            " uv sync --upgrade-package paladin"
        )

    print(f"Loading metadata from {metadata_dir}...")
    meta = load_metadata(metadata_dir)

    print(f"Loading checkpoint from {checkpoint_path}...")
    lightning_module = AeonLightningModule.load_from_checkpoint(
        str(checkpoint_path), metadata=meta
    )

    # Only the bare model is pickled, not the Lightning wrapper.
    exported_model = lightning_module.model

    print(f"Saving model to {output_path}...")
    with output_path.open("wb") as handle:
        pickle.dump(exported_model, handle)
    print(f"✓ Successfully exported checkpoint to {output_path}")

    # Summarize what was written so the run is auditable from the console.
    size_mb = output_path.stat().st_size / (1024 * 1024)  # MB
    print(f" Model size: {size_mb:.1f} MB")
    print(f" Model class: {type(exported_model).__name__}")
    print(f" Number of classes: {meta.n_classes}")
    print(f" Ontology embedding dim: {meta.ontology_embedding_dim}")
    print(f" Number of histologies: {len(meta.target_dicts[0]['histologies'])}")
def main():
    """CLI entry point: parse arguments, validate paths, run the export."""
    cli = argparse.ArgumentParser(
        description="Export Aeon PyTorch Lightning checkpoint to pickle format"
    )
    cli.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to PyTorch Lightning checkpoint (.ckpt)",
    )
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Path to save exported model (.pkl)",
    )
    cli.add_argument(
        "--metadata-dir",
        type=Path,
        default=Path("data/metadata"),
        help="Directory containing metadata files (default: data/metadata)",
    )
    opts = cli.parse_args()

    # Guard clauses: surface missing inputs before any work happens.
    if not opts.checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {opts.checkpoint}")
    if not opts.metadata_dir.exists():
        raise FileNotFoundError(f"Metadata directory not found: {opts.metadata_dir}")

    # Ensure the destination directory exists before writing.
    opts.output.parent.mkdir(parents=True, exist_ok=True)

    export_checkpoint(opts.checkpoint, opts.output, opts.metadata_dir)


if __name__ == "__main__":
    main()