| """Save privileged teacher images to disk on detail steps.""" |
| from __future__ import annotations |
|
|
| import json |
| import os |
| from typing import Any, Optional |
|
|
| from PIL import Image |
|
|
| from opsd_utils import debug_log as opsd_debug |
|
|
# Per-step tally of samples already written, keyed by global step; consulted
# to enforce the "max_samples_per_detail" quota and wiped by configure().
_saved_counts: dict[int, int] = {}

# Fallback output root used when a call does not pass its own output_dir.
_output_dir: Optional[str] = None

# Fallback privileged-debug options used when a call does not pass its own.
_cfg: dict[str, Any] = {}
|
|
|
|
def configure(output_dir: Optional[str] = None, privileged_debug_cfg: Optional[dict[str, Any]] = None) -> None:
    """Install the module-wide defaults and reset the per-step save counters.

    Args:
        output_dir: Default root directory for saved artifacts; ``None``
            leaves the module without a default.
        privileged_debug_cfg: Default option mapping. It is shallow-copied,
            so later mutation of the caller's dict does not affect the
            module; ``None`` or an empty mapping yields an empty config.
    """
    global _output_dir, _cfg
    _output_dir = output_dir
    if privileged_debug_cfg:
        _cfg = dict(privileged_debug_cfg)
    else:
        _cfg = {}
    # Reconfiguring starts a fresh quota for every step.
    _saved_counts.clear()
|
|
|
|
def _image_subdir() -> str:
    """Return the module config's ``image_subdir``, or ``"logs/images"`` if unset."""
    if "image_subdir" in _cfg:
        return _cfg["image_subdir"]
    return "logs/images"
|
|
|
|
def maybe_save_privileged_images(
    global_step: Optional[int],
    sample_idx: int,
    full_img: Optional[Image.Image],
    crop_img: Optional[Image.Image],
    meta: Optional[dict[str, Any]] = None,
    output_dir: Optional[str] = None,
    privileged_debug_cfg: Optional[dict[str, Any]] = None,
) -> Optional[str]:
    """Save teacher privileged images plus a metadata JSON on detail steps.

    Nothing is written and ``None`` is returned when any of these holds:
    ``global_step`` is ``None``; ``opsd_debug.should_log_detail(global_step)``
    is false; the ``save_images`` option (default ``True``) is disabled; the
    per-step quota ``max_samples_per_detail`` (default 2) is already used up;
    or neither ``output_dir`` nor the configured default directory is set.

    Otherwise files are written under ``<output dir>/<image_subdir>``
    (``image_subdir`` defaults to ``"logs/images"``) using the prefix
    ``step_<global_step:06d>_idx_<sample_idx>``: ``_full.png`` and
    ``_crop.png`` for whichever images are not ``None``, and always
    ``_meta.json``. The per-step counter is incremented even when both
    images are ``None``.

    Args:
        global_step: Current training step, or ``None`` to skip.
        sample_idx: Index of the sample; used in file names and metadata.
        full_img: Full image to save, if any.
        crop_img: Cropped image to save, if any.
        meta: Extra metadata merged into the JSON file and the log record.
            The keys ``global_step``, ``sample_idx`` and ``saved_paths`` are
            overwritten by this function's own values in the JSON.
        output_dir: Overrides the directory set via ``configure``.
        privileged_debug_cfg: Overrides the options set via ``configure``.

    Returns:
        The path prefix (without suffix/extension) of the written files, or
        ``None`` when nothing was saved.
    """
    if global_step is None or not opsd_debug.should_log_detail(global_step):
        return None

    cfg = privileged_debug_cfg if privileged_debug_cfg is not None else _cfg
    if not cfg.get("save_images", True):
        return None

    max_samples = int(cfg.get("max_samples_per_detail", 2))
    count = _saved_counts.get(global_step, 0)
    if count >= max_samples:
        return None

    base_out = output_dir or _output_dir
    if not base_out:
        opsd_debug.log(
            "privileged_debug",
            "skip image save (no output_dir)",
            global_step=global_step,
            sample_idx=sample_idx,
        )
        return None

    # The same lookup applies whether cfg is the module default or a per-call
    # override, so there is no need to branch on which one is in use.
    subdir = os.path.join(base_out, cfg.get("image_subdir", "logs/images"))
    os.makedirs(subdir, exist_ok=True)

    prefix = os.path.join(subdir, f"step_{int(global_step):06d}_idx_{sample_idx}")
    saved_paths: list[str] = []
    for suffix, img in (("full", full_img), ("crop", crop_img)):
        if img is None:
            continue
        img_path = f"{prefix}_{suffix}.png"
        img.save(img_path)
        saved_paths.append(img_path)

    extra = dict(meta or {})
    meta_path = f"{prefix}_meta.json"
    meta_payload = {
        **extra,
        "global_step": global_step,
        "sample_idx": sample_idx,
        "saved_paths": saved_paths,
    }
    # Serialize before opening the file so that a non-serializable value
    # raises without leaving a truncated JSON file behind.
    meta_text = json.dumps(meta_payload, ensure_ascii=False, indent=2)
    with open(meta_path, "w", encoding="utf-8") as f:
        f.write(meta_text)

    _saved_counts[global_step] = count + 1

    # Keys passed explicitly to log_detail must not also arrive through the
    # **meta spread, or the call raises "got multiple values for keyword
    # argument". The explicit values win, mirroring the JSON payload.
    not_forwarded = {
        "full_size",
        "crop_size",
        "global_step",
        "sample_idx",
        "prefix",
        "saved_paths",
        "meta_path",
    }
    log_extras = {k: v for k, v in extra.items() if k not in not_forwarded}
    opsd_debug.log_detail(
        "privileged_debug",
        "privileged images saved",
        global_step=global_step,
        sample_idx=sample_idx,
        prefix=prefix,
        saved_paths=saved_paths,
        meta_path=meta_path,
        **log_extras,
    )
    return prefix
|
|