# NOTE(review): removed extraction residue ("Spaces:" / "Sleeping" x2) — not part of the module.
| """I/O utilities — JSON save/load with timestamps, checkpointing.""" | |
| import json | |
| import logging | |
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| import numpy as np | |
| import torch | |
| logger = logging.getLogger(__name__) | |
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars/arrays, numpy booleans, and
    torch tensors into plain Python types that ``json`` can serialise."""

    def default(self, obj):
        """Return a JSON-native equivalent of *obj*, or defer to the base class.

        Raises:
            TypeError: via ``super().default`` for unsupported types.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of Python bool/int, so without this
            # branch json.dumps raises TypeError on numpy booleans.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            # detach() first: .numpy() raises RuntimeError on tensors that
            # require grad; cpu() handles tensors living on an accelerator.
            return obj.detach().cpu().numpy().tolist()
        return super().default(obj)
def save_json(data: Any, path: str, timestamp: bool = True):
    """Serialise *data* as indented JSON at *path*.

    When ``timestamp`` is True, an extra copy named
    ``<stem>_<YYYYMMDD_HHMMSS><suffix>`` is written next to *path*, and
    *path* itself always holds the latest snapshot.

    Args:
        data: Any object serialisable by ``NumpyEncoder`` (JSON-native plus
            numpy/torch values).
        path: Destination file; parent directories are created as needed.
        timestamp: Also write a timestamped sibling copy for history.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    # The canonical "latest" file is always written; a dated sibling is
    # appended when requested. One loop replaces three copy-pasted dumps.
    targets = [p]
    if timestamp:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        targets.append(p.parent / f"{p.stem}_{ts}{p.suffix}")
    for target in targets:
        # Explicit utf-8: the default text encoding is platform-dependent.
        with open(target, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, cls=NumpyEncoder)
    if timestamp:
        logger.info(f"Saved: {path} (+ timestamped copy)")
    else:
        logger.info(f"Saved: {path}")
def load_json(path: str) -> Any:
    """Parse and return the JSON document stored at *path*."""
    return json.loads(Path(path).read_text())
def save_tensor(tensor: torch.Tensor, path: str):
    """Persist *tensor* to *path*, creating parent directories first."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, path)
    logger.info(f"Saved tensor: {path} (shape={tensor.shape})")
def load_tensor(path: str, device: str = "cpu") -> torch.Tensor:
    """Deserialise a tensor from *path* onto *device*.

    Uses ``weights_only=True`` so only tensor data (no arbitrary pickled
    objects) is loaded.
    """
    loaded = torch.load(path, map_location=device, weights_only=True)
    return loaded
def get_checkpoint_path(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> str:
    """Build the standard checkpoint directory for one run combination."""
    components = (results_dir, "captions", backbone, track, style, method)
    return os.path.join(*components)
def check_checkpoint(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> bool:
    """Return True when a results.json already exists for this combination."""
    checkpoint_dir = get_checkpoint_path(results_dir, track, style, method, backbone)
    return os.path.exists(os.path.join(checkpoint_dir, "results.json"))
def save_checkpoint(
    data: Any,
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
):
    """Write results.json (plus its timestamped copy) so a run can resume."""
    checkpoint_dir = get_checkpoint_path(results_dir, track, style, method, backbone)
    os.makedirs(checkpoint_dir, exist_ok=True)
    save_json(data, os.path.join(checkpoint_dir, "results.json"), timestamp=True)
def load_checkpoint(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> Optional[Any]:
    """Return the saved results for this combination, or None if absent."""
    checkpoint_dir = get_checkpoint_path(results_dir, track, style, method, backbone)
    try:
        # EAFP: attempting the read avoids a separate existence check.
        return load_json(os.path.join(checkpoint_dir, "results.json"))
    except FileNotFoundError:
        return None