Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Export Aeon PyTorch Lightning checkpoint to pickle format for inference. | |
| This script converts a PyTorch Lightning checkpoint (.ckpt) file to a pickle | |
| (.pkl) file that can be used with the Mosaic inference pipeline. | |
| Usage: | |
| python export_aeon_checkpoint.py \ | |
| --checkpoint data/checkpoint.ckpt \ | |
| --output data/aeon_model.pkl \ | |
| --metadata-dir data/metadata | |
| Requirements: | |
| - paladin package from git repo (must have AeonLightningModule) | |
| - PyTorch Lightning | |
| - Access to metadata files (n_classes.txt, ontology_embedding_dim.txt, target_dict.tsv) | |
| """ | |
import argparse
import ast
import json
import pickle
from pathlib import Path
def load_metadata(metadata_dir: Path):
    """Load metadata required for model initialization.

    Args:
        metadata_dir: Directory containing n_classes.txt,
            ontology_embedding_dim.txt, and target_dict.tsv.

    Returns:
        An object with ``n_classes`` (int), ``ontology_embedding_dim`` (int),
        and ``target_dicts`` (single-element list wrapping the target dict),
        mimicking the training-time metadata interface.

    Raises:
        FileNotFoundError: If any metadata file is missing.
        ValueError: If a numeric file does not contain an integer or the
            target dict cannot be parsed.
    """
    with open(metadata_dir / "n_classes.txt", encoding="utf-8") as f:
        n_classes = int(f.read().strip())

    with open(metadata_dir / "ontology_embedding_dim.txt", encoding="utf-8") as f:
        ontology_embedding_dim = int(f.read().strip())

    # The file stores a Python-repr dict (single-quoted). ast.literal_eval
    # parses it directly; the previous replace("'", '"') + json.loads hack
    # corrupted any value that itself contained a quote character.
    with open(metadata_dir / "target_dict.tsv", encoding="utf-8") as f:
        target_dict = ast.literal_eval(f.read().strip())

    class SimpleMetadata:
        """Minimal stand-in for the training-time metadata object."""

        def __init__(self, n_classes, ontology_embedding_dim, target_dict):
            self.n_classes = n_classes
            self.ontology_embedding_dim = ontology_embedding_dim
            # Downstream code indexes target_dicts[0].
            self.target_dicts = [target_dict]

    return SimpleMetadata(n_classes, ontology_embedding_dim, target_dict)
def export_checkpoint(checkpoint_path: Path, output_path: Path, metadata_dir: Path):
    """Export a PyTorch Lightning checkpoint to pickle format.

    Loads the AeonLightningModule from ``checkpoint_path`` using metadata
    read from ``metadata_dir``, extracts the inner model, and pickles it
    to ``output_path``. Prints a short summary of the exported model.

    Args:
        checkpoint_path: Path to the .ckpt file.
        output_path: Path to save the .pkl file.
        metadata_dir: Directory containing the metadata files.

    Raises:
        ImportError: If the paladin package (AeonLightningModule) is not
            installed; the original import error is chained as the cause.
    """
    try:
        from paladin.pl_modules.aeon import AeonLightningModule
    except ImportError as err:
        # Chain the original exception so the underlying import failure
        # (e.g. a missing transitive dependency) stays visible in tracebacks.
        raise ImportError(
            "Failed to import AeonLightningModule. "
            "Make sure paladin is installed from the git repository:\n"
            " uv sync --upgrade-package paladin"
        ) from err

    print(f"Loading metadata from {metadata_dir}...")
    metadata = load_metadata(metadata_dir)

    print(f"Loading checkpoint from {checkpoint_path}...")
    pl_module = AeonLightningModule.load_from_checkpoint(
        str(checkpoint_path),
        metadata=metadata
    )

    # Extract the bare model; the Lightning wrapper is not needed at
    # inference time.
    model = pl_module.model

    print(f"Saving model to {output_path}...")
    with open(output_path, "wb") as f:
        pickle.dump(model, f)

    print(f"✓ Successfully exported checkpoint to {output_path}")

    # Print model info for a quick sanity check of the export.
    file_size = output_path.stat().st_size / (1024 * 1024)  # MB
    print(f"  Model size: {file_size:.1f} MB")
    print(f"  Model class: {type(model).__name__}")
    print(f"  Number of classes: {metadata.n_classes}")
    print(f"  Ontology embedding dim: {metadata.ontology_embedding_dim}")
    print(f"  Number of histologies: {len(metadata.target_dicts[0]['histologies'])}")
def main():
    """Command-line entry point: parse arguments, validate paths, export."""
    parser = argparse.ArgumentParser(
        description="Export Aeon PyTorch Lightning checkpoint to pickle format"
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to PyTorch Lightning checkpoint (.ckpt)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Path to save exported model (.pkl)",
    )
    parser.add_argument(
        "--metadata-dir",
        type=Path,
        default=Path("data/metadata"),
        help="Directory containing metadata files (default: data/metadata)",
    )
    args = parser.parse_args()

    checkpoint, output, metadata_dir = args.checkpoint, args.output, args.metadata_dir

    # Fail fast on missing inputs before any heavy work happens.
    if not checkpoint.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
    if not metadata_dir.exists():
        raise FileNotFoundError(f"Metadata directory not found: {metadata_dir}")

    # Make sure the destination directory exists before writing.
    output.parent.mkdir(parents=True, exist_ok=True)

    export_checkpoint(checkpoint, output, metadata_dir)


if __name__ == "__main__":
    main()