|
|
|
|
|
""" |
|
|
Export Aeon PyTorch Lightning checkpoint to pickle format for inference. |
|
|
|
|
|
This script converts a PyTorch Lightning checkpoint (.ckpt) file to a pickle |
|
|
(.pkl) file that can be used with the Mosaic inference pipeline. |
|
|
|
|
|
Usage: |
|
|
python export_aeon_checkpoint.py \ |
|
|
--checkpoint data/checkpoint.ckpt \ |
|
|
--output data/aeon_model.pkl \ |
|
|
--metadata-dir data/metadata |
|
|
|
|
|
Requirements: |
|
|
- paladin package from git repo (must have AeonLightningModule) |
|
|
- PyTorch Lightning |
|
|
- Access to metadata files (n_classes.txt, ontology_embedding_dim.txt, target_dict.tsv) |
|
|
""" |
|
|
|
|
|
import argparse
import ast
import json
import pickle
from pathlib import Path
|
|
|
|
|
|
|
|
def load_metadata(metadata_dir: Path):
    """Load metadata required for model initialization.

    Args:
        metadata_dir: Directory containing ``n_classes.txt``,
            ``ontology_embedding_dim.txt``, and ``target_dict.tsv``.

    Returns:
        SimpleMetadata object with n_classes, ontology_embedding_dim, and
        target_dicts (a one-element list wrapping the parsed target dict,
        the shape downstream consumers expect).

    Raises:
        FileNotFoundError: If any of the three metadata files is missing.
        ValueError / SyntaxError: If a file's contents cannot be parsed.
    """
    n_classes = int((metadata_dir / "n_classes.txt").read_text().strip())

    ontology_embedding_dim = int(
        (metadata_dir / "ontology_embedding_dim.txt").read_text().strip()
    )

    # target_dict.tsv may be serialized as JSON or as a Python dict repr
    # (single quotes). Try JSON first, then fall back to ast.literal_eval,
    # which parses single-quoted reprs correctly even when values contain
    # apostrophes — the previous replace("'", '"') hack corrupted those.
    raw = (metadata_dir / "target_dict.tsv").read_text().strip()
    try:
        target_dict = json.loads(raw)
    except json.JSONDecodeError:
        target_dict = ast.literal_eval(raw)

    class SimpleMetadata:
        """Minimal stand-in for the training-time metadata object."""

        def __init__(self, n_classes, ontology_embedding_dim, target_dict):
            self.n_classes = n_classes
            self.ontology_embedding_dim = ontology_embedding_dim
            # Consumers index target_dicts[0], so wrap in a list.
            self.target_dicts = [target_dict]

    return SimpleMetadata(n_classes, ontology_embedding_dim, target_dict)
|
|
|
|
|
|
|
|
def export_checkpoint(checkpoint_path: Path, output_path: Path, metadata_dir: Path):
    """Export PyTorch Lightning checkpoint to pickle format.

    Args:
        checkpoint_path: Path to .ckpt file
        output_path: Path to save .pkl file
        metadata_dir: Directory containing metadata files

    Raises:
        ImportError: If the paladin package (AeonLightningModule) is not
            installed; the original import failure is chained as the cause.
    """
    try:
        from paladin.pl_modules.aeon import AeonLightningModule
    except ImportError as e:
        # Chain the original exception so the underlying import failure
        # (e.g. a missing transitive dependency) remains visible.
        raise ImportError(
            "Failed to import AeonLightningModule. "
            "Make sure paladin is installed from the git repository:\n"
            " uv sync --upgrade-package paladin"
        ) from e

    print(f"Loading metadata from {metadata_dir}...")
    metadata = load_metadata(metadata_dir)

    print(f"Loading checkpoint from {checkpoint_path}...")
    pl_module = AeonLightningModule.load_from_checkpoint(
        str(checkpoint_path), metadata=metadata
    )

    # Pickle only the wrapped model attribute, not the Lightning module
    # itself, so the artifact is free of trainer/optimizer state.
    model = pl_module.model

    print(f"Saving model to {output_path}...")
    with open(output_path, "wb") as f:
        pickle.dump(model, f)

    print(f"✓ Successfully exported checkpoint to {output_path}")

    # Print a short summary so the user can sanity-check the artifact.
    file_size = output_path.stat().st_size / (1024 * 1024)
    print(f" Model size: {file_size:.1f} MB")
    print(f" Model class: {type(model).__name__}")
    print(f" Number of classes: {metadata.n_classes}")
    print(f" Ontology embedding dim: {metadata.ontology_embedding_dim}")
    print(f" Number of histologies: {len(metadata.target_dicts[0]['histologies'])}")
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, validate input paths, and run the export."""
    arg_parser = argparse.ArgumentParser(
        description="Export Aeon PyTorch Lightning checkpoint to pickle format"
    )
    arg_parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to PyTorch Lightning checkpoint (.ckpt)",
    )
    arg_parser.add_argument(
        "--output", type=Path, required=True, help="Path to save exported model (.pkl)"
    )
    arg_parser.add_argument(
        "--metadata-dir",
        type=Path,
        default=Path("data/metadata"),
        help="Directory containing metadata files (default: data/metadata)",
    )
    opts = arg_parser.parse_args()

    # Fail fast on missing inputs before touching the filesystem for output.
    required_paths = (
        (opts.checkpoint, f"Checkpoint not found: {opts.checkpoint}"),
        (opts.metadata_dir, f"Metadata directory not found: {opts.metadata_dir}"),
    )
    for path, error_message in required_paths:
        if not path.exists():
            raise FileNotFoundError(error_message)

    # Ensure the destination directory exists.
    opts.output.parent.mkdir(parents=True, exist_ok=True)

    export_checkpoint(opts.checkpoint, opts.output, opts.metadata_dir)
|
|
|
|
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":


    main()
|
|
|