Debito committed on
Commit
48d761f
·
verified ·
1 Parent(s): 6ff68fe

Upload checkpoint_manager.py

Browse files
Files changed (1) hide show
  1. checkpoints/checkpoint_manager.py +557 -0
checkpoints/checkpoint_manager.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Checkpoint Manager for Mamba Swarm
3
+ Handles saving, loading, and managing model checkpoints
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import time
9
+ import shutil
10
+ import logging
11
+ import torch
12
+ import threading
13
+ from typing import Dict, List, Any, Optional, Tuple
14
+ from dataclasses import dataclass, asdict
15
+ from pathlib import Path
16
+ from datetime import datetime
17
+ import pickle
18
+ import hashlib
19
+
20
@dataclass
class CheckpointMetadata:
    """Bookkeeping record describing one checkpoint file on disk.

    Instances are serialised to JSON with ``dataclasses.asdict`` and rebuilt
    with ``CheckpointMetadata(**dict)``, so the field names are part of the
    on-disk metadata format and the field order is part of the positional
    constructor signature.
    """

    # Identifier of the checkpoint; also used as the key in the metadata index.
    checkpoint_id: str
    # Seconds since the epoch (a ``time.time()`` value).
    timestamp: float
    # Training position the checkpoint was taken at.
    epoch: int
    step: int
    # Loss value recorded alongside the checkpoint.
    loss: float
    # Free-form configuration dictionaries stored with the checkpoint.
    model_config: Dict[str, Any]
    training_config: Dict[str, Any]
    # Named metric values recorded for this checkpoint (may be empty).
    metrics: Dict[str, float]
    # Location of the checkpoint file, stored as a string path.
    file_path: str
    # Size of the checkpoint file in bytes.
    file_size: int
    # Hex digest used to verify the file has not changed since it was written.
    checksum: str
33
+
34
class CheckpointManager:
    """Manages model checkpoints for Mamba Swarm.

    Checkpoints are written with ``torch.save`` into ``checkpoint_dir`` as
    ``<checkpoint_id>.pt``. A ``metadata.json`` file in the same directory
    indexes them (one ``CheckpointMetadata`` per checkpoint) and records which
    checkpoint is currently considered "best" according to ``best_metric``.

    Mutating public operations (save, delete, import, cleanup-all) run their
    index updates while holding ``self.lock``. The private ``_``-prefixed
    helpers do not take the lock themselves and are expected to be called with
    it already held.
    """

    # Read size used when hashing checkpoint files. The digest does not depend
    # on the chunk size; a larger chunk just means fewer read calls.
    _CHECKSUM_CHUNK_SIZE = 1024 * 1024

    def __init__(self,
                 checkpoint_dir: str = "./checkpoints",
                 max_checkpoints: int = 10,
                 save_interval: int = 1000,
                 best_metric: str = "loss",
                 best_metric_mode: str = "min"):
        """Create the manager and load any existing metadata index.

        Args:
            checkpoint_dir: Directory holding checkpoint files and the
                metadata index; created if it does not exist.
            max_checkpoints: Number of checkpoints to keep; older ones are
                pruned after each save.
            save_interval: ``save_checkpoint`` only writes when ``step`` is a
                multiple of this value (unless ``force_save`` is set).
            best_metric: Name of the metric used to pick the best checkpoint.
            best_metric_mode: ``"min"`` if smaller is better, ``"max"`` if
                larger is better. With any other value the first ranked
                checkpoint stays best and is never replaced by comparison.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.max_checkpoints = max_checkpoints
        self.save_interval = save_interval
        self.best_metric = best_metric
        self.best_metric_mode = best_metric_mode

        self.logger = logging.getLogger(__name__)
        self.lock = threading.Lock()

        # Create checkpoint directory
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Metadata storage
        self.metadata_file = self.checkpoint_dir / "metadata.json"
        self.checkpoints_metadata: Dict[str, CheckpointMetadata] = {}

        # Best checkpoint tracking
        self.best_checkpoint_id: Optional[str] = None
        self.best_metric_value: Optional[float] = None

        # Load existing metadata
        self._load_metadata()

    def save_checkpoint(self,
                        model_state: Dict[str, Any],
                        optimizer_state: Optional[Dict[str, Any]] = None,
                        scheduler_state: Optional[Dict[str, Any]] = None,
                        epoch: int = 0,
                        step: int = 0,
                        loss: float = 0.0,
                        metrics: Optional[Dict[str, float]] = None,
                        model_config: Optional[Dict[str, Any]] = None,
                        training_config: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[str]:
        """Save a checkpoint and register it in the metadata index.

        Args:
            model_state: Model state to store under ``"model_state"``.
            optimizer_state: Optional optimizer state.
            scheduler_state: Optional scheduler state.
            epoch: Epoch number recorded with the checkpoint.
            step: Step number recorded with the checkpoint.
            loss: Loss value recorded with the checkpoint.
            metrics: Optional metric values. When empty or ``None``, the best
                checkpoint is ranked using ``{"loss": loss}`` instead.
            model_config: Optional model configuration to store.
            training_config: Optional training configuration to store.
            force_save: Save even if ``step`` is not a multiple of
                ``save_interval``.

        Returns:
            The new checkpoint id, or ``None`` when the save was skipped
            because of the save interval.

        Raises:
            Exception: Any error raised while writing or registering the
                checkpoint is logged and re-raised after the partial file and
                its in-memory index entry have been removed.
        """
        # Check if we should save based on interval
        if not force_save and step % self.save_interval != 0:
            return None

        checkpoint_id = self._generate_checkpoint_id(epoch, step)
        recorded_metrics = metrics or {}
        recorded_model_config = model_config or {}
        recorded_training_config = training_config or {}

        checkpoint_data = {
            "model_state": model_state,
            "optimizer_state": optimizer_state,
            "scheduler_state": scheduler_state,
            "epoch": epoch,
            "step": step,
            "loss": loss,
            "metrics": recorded_metrics,
            "model_config": recorded_model_config,
            "training_config": recorded_training_config,
            "timestamp": time.time()
        }

        checkpoint_path = self.checkpoint_dir / f"{checkpoint_id}.pt"

        with self.lock:
            try:
                torch.save(checkpoint_data, checkpoint_path)

                metadata = CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    timestamp=checkpoint_data["timestamp"],
                    epoch=epoch,
                    step=step,
                    loss=loss,
                    model_config=recorded_model_config,
                    training_config=recorded_training_config,
                    metrics=recorded_metrics,
                    file_path=str(checkpoint_path),
                    file_size=checkpoint_path.stat().st_size,
                    checksum=self._calculate_checksum(checkpoint_path)
                )
                self.checkpoints_metadata[checkpoint_id] = metadata

                # Rank on the supplied metrics, falling back to the loss alone.
                self._update_best_checkpoint(checkpoint_id, metrics or {"loss": loss})
                self._cleanup_old_checkpoints()
                self._save_metadata()

                self.logger.info(f"Saved checkpoint {checkpoint_id} at step {step}")
                return checkpoint_id

            except Exception as e:
                self.logger.error(f"Failed to save checkpoint: {e}")
                # Clean up partial file
                if checkpoint_path.exists():
                    checkpoint_path.unlink()
                # Roll back the in-memory index so it never references the
                # file that was just removed.
                self.checkpoints_metadata.pop(checkpoint_id, None)
                if self.best_checkpoint_id == checkpoint_id:
                    self._find_new_best_checkpoint()
                raise

    def load_checkpoint(self, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Load a checkpoint by id, defaulting to the best checkpoint.

        Args:
            checkpoint_id: Id to load; ``None`` selects ``best_checkpoint_id``.

        Returns:
            The dictionary stored by ``save_checkpoint`` (tensors mapped to
            CPU), or ``None`` if the id is unknown, the file is missing, the
            checksum does not match, or loading fails.
        """
        # Use best checkpoint if none specified
        if checkpoint_id is None:
            checkpoint_id = self.best_checkpoint_id

        if checkpoint_id is None or checkpoint_id not in self.checkpoints_metadata:
            self.logger.warning(f"Checkpoint {checkpoint_id} not found")
            return None

        metadata = self.checkpoints_metadata[checkpoint_id]
        checkpoint_path = Path(metadata.file_path)

        if not checkpoint_path.exists():
            self.logger.error(f"Checkpoint file {checkpoint_path} does not exist")
            return None

        try:
            # Refuse to deserialize a file that changed since it was indexed.
            if not self._verify_checksum(checkpoint_path, metadata.checksum):
                self.logger.error(f"Checkpoint {checkpoint_id} failed checksum verification")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            self.logger.info(f"Loaded checkpoint {checkpoint_id} from step {metadata.step}")
            return checkpoint_data

        except Exception as e:
            self.logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}")
            return None

    def load_best_checkpoint(self) -> Optional[Dict[str, Any]]:
        """Load the checkpoint currently tracked as best, if any."""
        return self.load_checkpoint(self.best_checkpoint_id)

    def load_latest_checkpoint(self) -> Optional[Dict[str, Any]]:
        """Load the checkpoint with the largest recorded timestamp.

        Returns:
            The loaded checkpoint dictionary, or ``None`` if no checkpoints
            are indexed or loading fails.
        """
        if not self.checkpoints_metadata:
            return None

        latest_id = max(
            self.checkpoints_metadata,
            key=lambda cid: self.checkpoints_metadata[cid].timestamp,
        )
        return self.load_checkpoint(latest_id)

    def list_checkpoints(self, sort_by: str = "timestamp") -> List["CheckpointMetadata"]:
        """Return metadata for all indexed checkpoints.

        Args:
            sort_by: ``"timestamp"`` or ``"step"`` (both descending) or
                ``"loss"`` (ascending). Any other value returns the entries in
                index order.
        """
        # sort key -> (key function, reverse flag)
        orderings = {
            "timestamp": (lambda m: m.timestamp, True),
            "step": (lambda m: m.step, True),
            "loss": (lambda m: m.loss, False),
        }

        checkpoints = list(self.checkpoints_metadata.values())
        if sort_by in orderings:
            key_func, descending = orderings[sort_by]
            checkpoints.sort(key=key_func, reverse=descending)
        return checkpoints

    def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """Delete a checkpoint file and its index entry.

        If the deleted checkpoint was the best one, a new best is elected from
        the remaining entries.

        Returns:
            ``True`` on success; ``False`` if the id is unknown or deletion
            fails.
        """
        with self.lock:
            # Look the entry up under the lock so a concurrent delete or
            # cleanup cannot remove it between the check and the removal.
            metadata = self.checkpoints_metadata.get(checkpoint_id)
            if metadata is None:
                self.logger.warning(f"Checkpoint {checkpoint_id} not found")
                return False

            try:
                checkpoint_path = Path(metadata.file_path)
                if checkpoint_path.exists():
                    checkpoint_path.unlink()

                del self.checkpoints_metadata[checkpoint_id]

                if checkpoint_id == self.best_checkpoint_id:
                    self._find_new_best_checkpoint()

                self._save_metadata()

                self.logger.info(f"Deleted checkpoint {checkpoint_id}")
                return True

            except Exception as e:
                self.logger.error(f"Failed to delete checkpoint {checkpoint_id}: {e}")
                return False

    def get_checkpoint_info(self, checkpoint_id: str) -> Optional["CheckpointMetadata"]:
        """Return the metadata for ``checkpoint_id``, or ``None`` if unknown."""
        return self.checkpoints_metadata.get(checkpoint_id)

    def export_checkpoint(self, checkpoint_id: str, export_path: str) -> bool:
        """Copy a checkpoint and its metadata to another location.

        The checkpoint file is copied to ``export_path`` and its metadata is
        written next to it as JSON, using the same path with the suffix
        replaced by ``.json``. Missing parent directories are created.

        Returns:
            ``True`` on success; ``False`` if the id is unknown or copying
            fails.
        """
        if checkpoint_id not in self.checkpoints_metadata:
            self.logger.error(f"Checkpoint {checkpoint_id} not found")
            return False

        metadata = self.checkpoints_metadata[checkpoint_id]
        source_path = Path(metadata.file_path)
        export_path = Path(export_path)

        try:
            # Exporting into a directory that does not exist yet should work.
            export_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source_path, export_path)

            metadata_export_path = export_path.with_suffix('.json')
            with open(metadata_export_path, 'w') as f:
                json.dump(asdict(metadata), f, indent=2)

            self.logger.info(f"Exported checkpoint {checkpoint_id} to {export_path}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to export checkpoint {checkpoint_id}: {e}")
            return False

    def import_checkpoint(self, checkpoint_path: str, metadata_path: Optional[str] = None) -> Optional[str]:
        """Import a checkpoint file from an external location.

        The file is copied into ``checkpoint_dir`` and registered in the
        index. Metadata comes from ``metadata_path`` (JSON in the format
        written by ``export_checkpoint``) when given; otherwise it is
        reconstructed from the checkpoint's own contents, with missing keys
        replaced by defaults.

        Returns:
            The id the checkpoint was registered under, or ``None`` if the
            file does not exist or the import fails.
        """
        checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            self.logger.error(f"Checkpoint file {checkpoint_path} does not exist")
            return None

        try:
            if metadata_path:
                with open(metadata_path, 'r') as f:
                    metadata = CheckpointMetadata(**json.load(f))
            else:
                # SECURITY NOTE: torch.load unpickles the file. Only import
                # checkpoints from sources you trust.
                checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
                epoch = checkpoint_data.get("epoch", 0)
                step = checkpoint_data.get("step", 0)
                metadata = CheckpointMetadata(
                    checkpoint_id=self._generate_checkpoint_id(epoch, step),
                    timestamp=checkpoint_data.get("timestamp", time.time()),
                    epoch=epoch,
                    step=step,
                    loss=checkpoint_data.get("loss", 0.0),
                    model_config=checkpoint_data.get("model_config", {}),
                    training_config=checkpoint_data.get("training_config", {}),
                    metrics=checkpoint_data.get("metrics", {}),
                    file_path="",  # Filled in after the copy below
                    file_size=0,   # Filled in after the copy below
                    checksum=""    # Filled in after the copy below
                )

            # Copy to checkpoint directory
            new_checkpoint_path = self.checkpoint_dir / f"{metadata.checkpoint_id}.pt"
            shutil.copy2(checkpoint_path, new_checkpoint_path)

            # Path, size and checksum always describe the local copy, even
            # when the metadata came from an exported JSON file.
            metadata.file_path = str(new_checkpoint_path)
            metadata.file_size = new_checkpoint_path.stat().st_size
            metadata.checksum = self._calculate_checksum(new_checkpoint_path)

            with self.lock:
                self.checkpoints_metadata[metadata.checkpoint_id] = metadata
                # Same ranking rule as save_checkpoint: fall back to the loss
                # when no metrics were recorded.
                self._update_best_checkpoint(
                    metadata.checkpoint_id,
                    metadata.metrics or {"loss": metadata.loss},
                )
                self._save_metadata()

            self.logger.info(f"Imported checkpoint {metadata.checkpoint_id}")
            return metadata.checkpoint_id

        except Exception as e:
            self.logger.error(f"Failed to import checkpoint: {e}")
            return None

    def _generate_checkpoint_id(self, epoch: int, step: int) -> str:
        """Build a checkpoint id from epoch, step and the current time.

        The timestamp has one-second resolution, so two ids generated for the
        same epoch and step within the same second are identical.
        """
        timestamp = int(time.time())
        return f"checkpoint_epoch_{epoch}_step_{step}_{timestamp}"

    def _calculate_checksum(self, file_path: Path) -> str:
        """Return the MD5 hex digest of the file at ``file_path``.

        MD5 is used here only to detect accidental corruption or modification;
        it is not a defence against deliberate tampering.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(self._CHECKSUM_CHUNK_SIZE), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def _verify_checksum(self, file_path: Path, expected_checksum: str) -> bool:
        """Return whether the file's current checksum equals ``expected_checksum``."""
        return self._calculate_checksum(file_path) == expected_checksum

    def _is_better(self, candidate: float, incumbent: float) -> bool:
        """Return whether ``candidate`` beats ``incumbent`` under ``best_metric_mode``.

        Strict comparison is used, so ties keep the incumbent. An unrecognised
        mode never reports an improvement.
        """
        if self.best_metric_mode == "min":
            return candidate < incumbent
        if self.best_metric_mode == "max":
            return candidate > incumbent
        return False

    def _update_best_checkpoint(self, checkpoint_id: str, metrics: Dict[str, float]):
        """Promote ``checkpoint_id`` to best if its metric beats the current best.

        Does nothing when ``metrics`` lacks ``best_metric``. The first ranked
        checkpoint becomes best unconditionally.
        """
        if self.best_metric not in metrics:
            return

        metric_value = metrics[self.best_metric]

        if self.best_metric_value is None:
            # First checkpoint
            self.best_checkpoint_id = checkpoint_id
            self.best_metric_value = metric_value
            return

        if self._is_better(metric_value, self.best_metric_value):
            self.best_checkpoint_id = checkpoint_id
            self.best_metric_value = metric_value
            self.logger.info(f"New best checkpoint: {checkpoint_id} ({self.best_metric}: {metric_value})")

    def _find_new_best_checkpoint(self):
        """Re-elect the best checkpoint from the entries currently indexed.

        Each entry is ranked the same way ``save_checkpoint`` ranks a new
        checkpoint: by its recorded metrics, or by ``{"loss": loss}`` when no
        metrics were recorded. If no entry provides ``best_metric`` (or the
        index is empty), the best id and value are reset to ``None``.
        """
        best_id = None
        best_value = None

        for checkpoint_id, metadata in self.checkpoints_metadata.items():
            ranked_metrics = metadata.metrics or {"loss": metadata.loss}
            if self.best_metric not in ranked_metrics:
                continue

            metric_value = ranked_metrics[self.best_metric]
            if best_value is None or self._is_better(metric_value, best_value):
                best_id = checkpoint_id
                best_value = metric_value

        self.best_checkpoint_id = best_id
        self.best_metric_value = best_value

    def _cleanup_old_checkpoints(self):
        """Prune the oldest checkpoints until at most ``max_checkpoints`` remain.

        The best checkpoint is never pruned. When it is among the oldest
        entries, the next-oldest entry is removed in its place so the limit is
        still met.
        """
        excess = len(self.checkpoints_metadata) - self.max_checkpoints
        if excess <= 0:
            return

        # Oldest first. sorted() returns a list, so deleting from the dict
        # while walking it is safe.
        oldest_first = sorted(
            self.checkpoints_metadata.items(),
            key=lambda item: item[1].timestamp
        )

        for checkpoint_id, metadata in oldest_first:
            if excess <= 0:
                break

            # Don't delete the best checkpoint
            if checkpoint_id == self.best_checkpoint_id:
                continue

            checkpoint_path = Path(metadata.file_path)
            if checkpoint_path.exists():
                checkpoint_path.unlink()

            del self.checkpoints_metadata[checkpoint_id]
            excess -= 1
            self.logger.info(f"Cleaned up old checkpoint: {checkpoint_id}")

    def _load_metadata(self):
        """Populate the in-memory index from ``metadata.json`` if it exists.

        Errors are logged rather than raised. Entries parsed before an error
        stay loaded.
        """
        if not self.metadata_file.exists():
            return

        try:
            with open(self.metadata_file, 'r') as f:
                data = json.load(f)

            for checkpoint_id, metadata_dict in data.get("checkpoints", {}).items():
                self.checkpoints_metadata[checkpoint_id] = CheckpointMetadata(**metadata_dict)

            self.best_checkpoint_id = data.get("best_checkpoint_id")
            self.best_metric_value = data.get("best_metric_value")

            self.logger.info(f"Loaded metadata for {len(self.checkpoints_metadata)} checkpoints")

        except Exception as e:
            self.logger.error(f"Failed to load metadata: {e}")

    def _save_metadata(self):
        """Write the in-memory index to ``metadata.json``.

        The data is written to a sibling ``.tmp`` file first and then moved
        over the real file, so an interrupted write does not truncate the
        existing index. Failures are logged, not raised.
        """
        temp_file = self.metadata_file.with_suffix('.tmp')
        try:
            data = {
                "checkpoints": {
                    checkpoint_id: asdict(metadata)
                    for checkpoint_id, metadata in self.checkpoints_metadata.items()
                },
                "best_checkpoint_id": self.best_checkpoint_id,
                "best_metric_value": self.best_metric_value,
                "last_updated": time.time()
            }

            with open(temp_file, 'w') as f:
                json.dump(data, f, indent=2)

            temp_file.replace(self.metadata_file)

        except Exception as e:
            self.logger.error(f"Failed to save metadata: {e}")
            # Best effort: don't leave a half-written temp file behind. A
            # failure here is ignored because the original error is already
            # logged above.
            try:
                if temp_file.exists():
                    temp_file.unlink()
            except OSError:
                pass

    def get_storage_usage(self) -> Dict[str, Any]:
        """Summarise disk usage from the sizes recorded in the index.

        Sizes come from the stored metadata, not from a fresh scan of the
        directory.
        """
        checkpoint_count = len(self.checkpoints_metadata)
        total_size = sum(
            metadata.file_size for metadata in self.checkpoints_metadata.values()
        )

        return {
            "total_size_bytes": total_size,
            "total_size_mb": total_size / (1024 * 1024),
            "total_size_gb": total_size / (1024 * 1024 * 1024),
            "checkpoint_count": checkpoint_count,
            "average_size_mb": (total_size / checkpoint_count / (1024 * 1024)) if checkpoint_count > 0 else 0,
            "checkpoint_directory": str(self.checkpoint_dir)
        }

    def cleanup_all_checkpoints(self):
        """Remove every indexed checkpoint file and the metadata file.

        This is destructive: the in-memory index and best-checkpoint tracking
        are reset as well.
        """
        with self.lock:
            for metadata in self.checkpoints_metadata.values():
                checkpoint_path = Path(metadata.file_path)
                if checkpoint_path.exists():
                    checkpoint_path.unlink()

            self.checkpoints_metadata.clear()
            self.best_checkpoint_id = None
            self.best_metric_value = None

            # Remove metadata file
            if self.metadata_file.exists():
                self.metadata_file.unlink()

            self.logger.info("Cleaned up all checkpoints")
511
+
512
# Example usage and testing
if __name__ == "__main__":
    # Manager that keeps at most five checkpoints in a scratch directory.
    checkpoint_manager = CheckpointManager(
        checkpoint_dir="./test_checkpoints",
        max_checkpoints=5,
        save_interval=100,
    )

    # Simulate a short training run: loss falls and accuracy rises with step.
    for step in range(0, 1000, 100):
        progress = step / 1000.0
        model_state = {"layer_weights": torch.randn(10, 10)}
        metrics = {"loss": 1.0 - progress, "accuracy": progress}

        checkpoint_id = checkpoint_manager.save_checkpoint(
            model_state=model_state,
            optimizer_state={"param_groups": [{"lr": 0.001}]},
            step=step,
            loss=metrics["loss"],
            metrics=metrics,
            force_save=True,
        )
        print(f"Saved checkpoint: {checkpoint_id}")

    # Show what survived pruning (default ordering from list_checkpoints).
    print("\nAvailable checkpoints:")
    for metadata in checkpoint_manager.list_checkpoints():
        print(f"  {metadata.checkpoint_id}: step {metadata.step}, loss {metadata.loss:.3f}")

    # Load the best checkpoint; only its id is reported.
    best_checkpoint = checkpoint_manager.load_best_checkpoint()
    print(f"\nLoaded best checkpoint: {checkpoint_manager.best_checkpoint_id}")

    # Report disk usage as recorded in the metadata index.
    usage = checkpoint_manager.get_storage_usage()
    print(f"\nStorage usage: {usage['total_size_mb']:.2f} MB ({usage['checkpoint_count']} checkpoints)")

    # Remove everything the demo created.
    checkpoint_manager.cleanup_all_checkpoints()
    print("Cleaned up test checkpoints")