# abka03 — "Deploy StyleSteer-VLM demo" (commit e6f24ae, verified)
"""I/O utilities — JSON save/load with timestamps, checkpointing."""
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
# Module-level logger; save_json and save_tensor report each file they write through it.
logger = logging.getLogger(__name__)
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy and PyTorch values.

    Conversions performed by :meth:`default`:
      * ``np.integer``  -> ``int``
      * ``np.floating`` -> ``float``
      * ``np.bool_``    -> ``bool``
      * ``np.ndarray``  -> nested ``list`` via ``ndarray.tolist()``
      * ``torch.Tensor`` -> nested ``list`` (a Python scalar for 0-d tensors)

    Any other unsupported object is delegated to :class:`json.JSONEncoder`,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither an np.integer nor a Python bool, so the stock
        # encoder rejects it; map it to a JSON true/false.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            # detach(): Tensor.numpy() raises for tensors that require grad.
            # Converting with Tensor.tolist() directly (no NumPy round-trip)
            # also works for dtypes NumPy cannot represent, e.g. bfloat16.
            return obj.detach().cpu().tolist()
        return super().default(obj)
def save_json(data: Any, path: str, timestamp: bool = True):
    """Save ``data`` as pretty-printed JSON, optionally with a timestamped copy.

    Args:
        data: Value to serialize; plain JSON types plus anything
            :class:`NumpyEncoder` handles (NumPy scalars/arrays, torch tensors).
        path: Destination file. Missing parent directories are created.
        timestamp: When True, additionally write
            ``<stem>_<YYYYmmdd_HHMMSS><suffix>`` next to ``path``; ``path``
            itself always holds the latest data. The stamp has one-second
            resolution, so two saves within the same second write to the same
            timestamped file.

    Raises:
        TypeError: If ``data`` contains a value the encoder cannot serialize.
            This happens before any file is opened, so an existing file at
            ``path`` is left untouched.
    """
    # Serialize before touching the filesystem: streaming json.dump into a
    # file opened with "w" would leave it truncated/half-written if an
    # unsupported value is hit midway. This also encodes the data only once
    # when two files are written.
    text = json.dumps(data, indent=2, cls=NumpyEncoder)

    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    if timestamp:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        path_with_ts = p.parent / f"{p.stem}_{ts}{p.suffix}"
        path_with_ts.write_text(text, encoding="utf-8")
        # Also save without timestamp (latest)
        p.write_text(text, encoding="utf-8")
        logger.info(f"Saved: {path} (+ timestamped copy)")
    else:
        p.write_text(text, encoding="utf-8")
        logger.info(f"Saved: {path}")
def load_json(path: str) -> Any:
    """Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    uses) rather than the platform's locale encoding. This makes non-ASCII
    text load identically on every OS.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def save_tensor(tensor: torch.Tensor, path: str):
    """Write ``tensor`` to ``path`` with ``torch.save``.

    Any missing parent directories are created first, and the written path
    and tensor shape are logged at INFO level.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, path)
    message = f"Saved tensor: {path} (shape={tensor.shape})"
    logger.info(message)
def load_tensor(path: str, device: str = "cpu") -> torch.Tensor:
    """Read a tensor written by ``torch.save`` and place it on ``device``.

    ``weights_only=True`` is passed so ``torch.load`` uses its restricted
    unpickler instead of unpickling arbitrary objects.
    """
    loaded = torch.load(
        path,
        map_location=device,
        weights_only=True,
    )
    return loaded
def get_checkpoint_path(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> str:
    """Build the checkpoint directory for one (backbone, track, style, method) run.

    Layout: ``<results_dir>/captions/<backbone>/<track>/<style>/<method>``,
    joined with ``os.path.join`` (no normalization is applied).
    """
    segments = ("captions", backbone, track, style, method)
    return os.path.join(results_dir, *segments)
def check_checkpoint(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> bool:
    """Return True if ``results.json`` already exists in this run's checkpoint directory."""
    results_file = os.path.join(
        get_checkpoint_path(results_dir, track, style, method, backbone),
        "results.json",
    )
    return os.path.exists(results_file)
def save_checkpoint(
    data: Any,
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
):
    """Persist ``data`` as ``results.json`` in this run's checkpoint directory.

    The directory is created if needed. The save uses ``timestamp=True``, so a
    timestamped copy is written next to ``results.json`` on every call.
    """
    checkpoint_dir = get_checkpoint_path(results_dir, track, style, method, backbone)
    os.makedirs(checkpoint_dir, exist_ok=True)
    target = os.path.join(checkpoint_dir, "results.json")
    save_json(data, target, timestamp=True)
def load_checkpoint(
    results_dir: str,
    track: str,
    style: str,
    method: str,
    backbone: str = "primary",
) -> Optional[Any]:
    """Return the parsed ``results.json`` for this run, or None if it does not exist."""
    results_file = os.path.join(
        get_checkpoint_path(results_dir, track, style, method, backbone),
        "results.json",
    )
    if not os.path.exists(results_file):
        return None
    return load_json(results_file)