Spaces:
Sleeping
Sleeping
File size: 4,512 Bytes
0506a57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
#!/usr/bin/env python
"""
Export Aeon PyTorch Lightning checkpoint to pickle format for inference.
This script converts a PyTorch Lightning checkpoint (.ckpt) file to a pickle
(.pkl) file that can be used with the Mosaic inference pipeline.
Usage:
python export_aeon_checkpoint.py \
--checkpoint data/checkpoint.ckpt \
--output data/aeon_model.pkl \
--metadata-dir data/metadata
Requirements:
- paladin package from git repo (must have AeonLightningModule)
- PyTorch Lightning
- Access to metadata files (n_classes.txt, ontology_embedding_dim.txt, target_dict.tsv)
"""
import argparse
import json
import pickle
from pathlib import Path
def load_metadata(metadata_dir: Path):
"""Load metadata required for model initialization.
Args:
metadata_dir: Directory containing metadata files
Returns:
SimpleMetadata object with n_classes, ontology_embedding_dim, and target_dicts
"""
# Read n_classes
with open(metadata_dir / "n_classes.txt") as f:
n_classes = int(f.read().strip())
# Read ontology_embedding_dim
with open(metadata_dir / "ontology_embedding_dim.txt") as f:
ontology_embedding_dim = int(f.read().strip())
# Read target_dict (JSON format with single quotes)
with open(metadata_dir / "target_dict.tsv") as f:
target_dict_str = f.read().strip().replace("'", '"')
target_dict = json.loads(target_dict_str)
# Create simple metadata object
class SimpleMetadata:
def __init__(self, n_classes, ontology_embedding_dim, target_dict):
self.n_classes = n_classes
self.ontology_embedding_dim = ontology_embedding_dim
self.target_dicts = [target_dict]
return SimpleMetadata(n_classes, ontology_embedding_dim, target_dict)
def export_checkpoint(checkpoint_path: Path, output_path: Path, metadata_dir: Path):
"""Export PyTorch Lightning checkpoint to pickle format.
Args:
checkpoint_path: Path to .ckpt file
output_path: Path to save .pkl file
metadata_dir: Directory containing metadata files
"""
try:
from paladin.pl_modules.aeon import AeonLightningModule
except ImportError:
raise ImportError(
"Failed to import AeonLightningModule. "
"Make sure paladin is installed from the git repository:\n"
" uv sync --upgrade-package paladin"
)
print(f"Loading metadata from {metadata_dir}...")
metadata = load_metadata(metadata_dir)
print(f"Loading checkpoint from {checkpoint_path}...")
pl_module = AeonLightningModule.load_from_checkpoint(
str(checkpoint_path),
metadata=metadata
)
# Extract the model
model = pl_module.model
print(f"Saving model to {output_path}...")
with open(output_path, "wb") as f:
pickle.dump(model, f)
print(f"✓ Successfully exported checkpoint to {output_path}")
# Print model info
file_size = output_path.stat().st_size / (1024 * 1024) # MB
print(f" Model size: {file_size:.1f} MB")
print(f" Model class: {type(model).__name__}")
print(f" Number of classes: {metadata.n_classes}")
print(f" Ontology embedding dim: {metadata.ontology_embedding_dim}")
print(f" Number of histologies: {len(metadata.target_dicts[0]['histologies'])}")
def main():
parser = argparse.ArgumentParser(
description="Export Aeon PyTorch Lightning checkpoint to pickle format"
)
parser.add_argument(
"--checkpoint",
type=Path,
required=True,
help="Path to PyTorch Lightning checkpoint (.ckpt)"
)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Path to save exported model (.pkl)"
)
parser.add_argument(
"--metadata-dir",
type=Path,
default=Path("data/metadata"),
help="Directory containing metadata files (default: data/metadata)"
)
args = parser.parse_args()
# Validate inputs
if not args.checkpoint.exists():
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
if not args.metadata_dir.exists():
raise FileNotFoundError(f"Metadata directory not found: {args.metadata_dir}")
# Create output directory if needed
args.output.parent.mkdir(parents=True, exist_ok=True)
# Export checkpoint
export_checkpoint(args.checkpoint, args.output, args.metadata_dir)
if __name__ == "__main__":
main()
|