File size: 3,109 Bytes
36d0b76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Save privileged teacher images to disk on detail steps."""
from __future__ import annotations

import json
import os
from typing import Any, Optional

from PIL import Image

from opsd_utils import debug_log as opsd_debug

# Number of dumps already written per step, keyed by global_step. Consulted and
# incremented by maybe_save_privileged_images() to enforce the
# "max_samples_per_detail" cap; cleared by configure().
_saved_counts: dict[int, int] = {}
# Fallback base output directory installed by configure(); used when a call to
# maybe_save_privileged_images() does not pass its own output_dir.
_output_dir: Optional[str] = None
# Fallback config dict installed by configure(); used when a call to
# maybe_save_privileged_images() does not pass its own privileged_debug_cfg.
_cfg: dict[str, Any] = {}


def configure(output_dir: Optional[str] = None, privileged_debug_cfg: Optional[dict[str, Any]] = None) -> None:
    """Install the module-wide defaults and reset the per-step save counters.

    Args:
        output_dir: Base directory stored as the module-level fallback.
        privileged_debug_cfg: Config mapping; a shallow copy is stored as the
            module-level fallback. A falsy value (``None`` or an empty
            mapping) stores an empty dict.
    """
    global _output_dir, _cfg
    _output_dir = output_dir
    if privileged_debug_cfg:
        # Copy so later mutation of the caller's dict does not leak in here.
        _cfg = dict(privileged_debug_cfg)
    else:
        _cfg = {}
    # Counters from a previous configuration no longer apply.
    _saved_counts.clear()


def _image_subdir() -> str:
    """Return the ``image_subdir`` entry of the module config, or ``logs/images`` if unset."""
    fallback = "logs/images"
    configured = _cfg.get("image_subdir", fallback)
    return configured


def maybe_save_privileged_images(
    global_step: Optional[int],
    sample_idx: int,
    full_img: Optional[Image.Image],
    crop_img: Optional[Image.Image],
    meta: Optional[dict[str, Any]] = None,
    output_dir: Optional[str] = None,
    privileged_debug_cfg: Optional[dict[str, Any]] = None,
) -> Optional[str]:
    """Save teacher privileged images and a JSON metadata sidecar on detail steps.

    Nothing is written (and ``None`` is returned) when any of these holds:
    ``global_step`` is ``None``; ``opsd_debug.should_log_detail(global_step)``
    is false; the config disables ``save_images``; the per-step cap has been
    reached; or no base output directory is available.

    Otherwise the files ``<prefix>_full.png`` (if ``full_img`` is given),
    ``<prefix>_crop.png`` (if ``crop_img`` is given) and ``<prefix>_meta.json``
    are written. The metadata file is written, and the per-step counter is
    incremented, even when both images are ``None``.

    Config keys read (from ``privileged_debug_cfg`` if given, else from the
    module config set by ``configure``):
        save_images: truthy to enable saving (default ``True``).
        max_samples_per_detail: cap on saves per ``global_step`` (default 2).
        image_subdir: directory under the base output dir
            (default ``"logs/images"``).

    Args:
        global_step: Current step; ``None`` disables saving.
        sample_idx: Index of the sample, used in the file names.
        full_img: Full image to save, or ``None`` to skip it.
        crop_img: Cropped image to save, or ``None`` to skip it.
        meta: Extra metadata copied into the JSON sidecar. The keys
            ``global_step``, ``sample_idx`` and ``saved_paths`` are always
            overwritten with the values from this call.
        output_dir: Base output directory; falls back to the module default.
        privileged_debug_cfg: Per-call config; falls back to the module config.

    Returns:
        The path prefix shared by the written files, or ``None`` if nothing
        was saved.
    """
    if global_step is None:
        return None
    if not opsd_debug.should_log_detail(global_step):
        return None

    cfg = privileged_debug_cfg if privileged_debug_cfg is not None else _cfg
    if not cfg.get("save_images", True):
        return None

    max_samples = int(cfg.get("max_samples_per_detail", 2))
    count = _saved_counts.get(global_step, 0)
    if count >= max_samples:
        return None

    base_out = output_dir or _output_dir
    if not base_out:
        opsd_debug.log(
            "privileged_debug",
            "skip image save (no output_dir)",
            global_step=global_step,
            sample_idx=sample_idx,
        )
        return None

    # ``cfg`` is already the effective config (per-call or module-level), so a
    # single lookup covers both cases.
    subdir = os.path.join(base_out, cfg.get("image_subdir", "logs/images"))
    os.makedirs(subdir, exist_ok=True)

    prefix = os.path.join(subdir, f"step_{int(global_step):06d}_idx_{sample_idx}")
    saved_paths: list[str] = []

    if full_img is not None:
        full_path = f"{prefix}_full.png"
        full_img.save(full_path)
        saved_paths.append(full_path)

    if crop_img is not None:
        crop_path = f"{prefix}_crop.png"
        crop_img.save(crop_path)
        saved_paths.append(crop_path)

    meta_path = f"{prefix}_meta.json"
    meta_payload = dict(meta or {})
    # Values from this call win over same-named keys supplied in ``meta``.
    meta_payload.update(
        {
            "global_step": global_step,
            "sample_idx": sample_idx,
            "saved_paths": saved_paths,
        }
    )
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta_payload, f, ensure_ascii=False, indent=2)

    _saved_counts[global_step] = count + 1

    log_fields: dict[str, Any] = {
        "global_step": global_step,
        "sample_idx": sample_idx,
        "prefix": prefix,
        "saved_paths": saved_paths,
        "meta_path": meta_path,
    }
    for key, value in (meta or {}).items():
        if key in ("full_size", "crop_size"):
            continue
        # Keyword expansion requires string keys, and a key that duplicates an
        # explicit field would raise "got multiple values for keyword
        # argument"; the explicit values win, as in the JSON payload.
        if isinstance(key, str) and key not in log_fields:
            log_fields[key] = value

    opsd_debug.log_detail(
        "privileged_debug",
        "privileged images saved",
        **log_fields,
    )
    return prefix