File size: 5,023 Bytes
c474610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
"""Static server + /api/save for the xj_mjhq30k model-comparison viewer.

Serves the repo root statically and adds:
  - GET  /api/manifest : live re-scan of the sibling output dirs (so new images
                         appear without re-running gen_manifest.py)
  - GET  /api/selected : ids of the prompts already saved under ../selected/
  - POST /api/save     : copy each model's image for the current prompt, plus the
                         prompt text, into ../selected/<idx>/

Use this (not plain http.server) so the Save button works:
    python test/xj_mjhq30k_inference_outputs/viewer/viewer.py            # :8765
    python test/xj_mjhq30k_inference_outputs/viewer/viewer.py --port N

Then open http://<host>:<port>/test/xj_mjhq30k_inference_outputs/viewer/index.html
"""
import argparse
import json
import shutil
import sys
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import urlparse

# Filesystem layout, derived from this file's resolved location so it does not
# depend on the process's current working directory.
HERE = Path(__file__).resolve().parent          # .../xj_mjhq30k_inference_outputs/viewer
OUTPUTS = HERE.parent                            # .../xj_mjhq30k_inference_outputs
REPO = HERE.parents[2]                           # repo root
SELECTED = OUTPUTS / "selected"                  # save_selection() writes <idx>/ dirs here
MANIFEST = HERE / "manifest.json"                # on-disk manifest, used only as a fallback
PROMPTS_JSON = REPO / "assets" / "xj_mjhq30k_prompts.json"
REL_PREFIX = ".."                                # passed through to build_manifest; presumably the path prefix from viewer/ to OUTPUTS — confirm in gen_manifest

# gen_manifest.py sits next to this file. sys.path[0] is only the script's own
# dir when run as a script, so add HERE explicitly before importing from it.
sys.path.insert(0, str(HERE))
from gen_manifest import build_manifest  # noqa: E402


def save_selection(idx: int) -> dict:
    """Copy every model's image for prompt ``idx`` into ``SELECTED/<stem>/``.

    ``stem`` is ``idx`` zero-padded to the manifest's ``pad`` width. For each
    model in the manifest, ``OUTPUTS/<dir>/<stem>.png`` is copied to
    ``<name>.png`` when it exists; models without that image are skipped.
    The prompt text is written to ``prompt.txt`` and a summary to
    ``meta.json`` (both UTF-8).

    Args:
        idx: prompt index; the HTTP handler rejects negative values before
            calling this. An index past the end of the prompt list saves an
            empty prompt.

    Returns:
        ``{"ok": True, "path": <output dir relative to REPO>,
        "saved_models": [<model names copied>]}``.
    """
    # Prefer the live manifest so the dir list is current; fall back to the
    # last manifest.json on disk if the re-scan fails for any reason.
    try:
        manifest = build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX)
    except Exception:  # noqa: BLE001 - deliberate best-effort fallback
        manifest = json.loads(MANIFEST.read_text(encoding="utf-8"))
    stem = str(idx).zfill(manifest["pad"])
    out = SELECTED / stem
    out.mkdir(parents=True, exist_ok=True)

    saved = []
    for m in manifest["models"]:
        src = OUTPUTS / m["dir"] / f"{stem}.png"
        if src.exists():
            shutil.copy(src, out / f"{m['name']}.png")
            saved.append(m["name"])

    prompts = manifest["prompts"]
    prompt = prompts[idx] if idx < len(prompts) else ""
    # Explicit UTF-8: prompts may contain non-ASCII text, and the default
    # encoding of write_text() is locale-dependent.
    (out / "prompt.txt").write_text(prompt, encoding="utf-8")
    meta = {"id": idx, "saved_models": saved, "prompt": prompt}
    (out / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    return {"ok": True, "path": str(out.relative_to(REPO)), "saved_models": saved}


class Handler(SimpleHTTPRequestHandler):
    """Static-file handler rooted at REPO plus the viewer's small JSON API.

    Routes:
      GET  /api/manifest -> freshly built manifest (500 + {"error": ...} on failure)
      GET  /api/selected -> {"ids": [...]} for numeric dir names under SELECTED
      POST /api/save     -> body {"idx": <int>}; 400 on malformed input,
                            otherwise the result of save_selection()
    Any other GET is served as a static file from REPO; any other POST is 404.
    """

    def __init__(self, *a, **k):
        # Serve files relative to the repo root regardless of the process cwd.
        super().__init__(*a, directory=str(REPO), **k)

    def log_message(self, *a):  # quiet
        # BaseHTTPRequestHandler routes log_error through here as well, so this
        # silences error logging too.
        return

    def _json(self, payload, status=200):
        """Send ``payload`` as a JSON response with caching disabled."""
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _parse_idx(raw):
        """Return the non-negative int ``idx`` from a /api/save request body.

        An empty body is treated as ``{}``. Raises ValueError, with a short
        message suitable for the client, when the body is not valid JSON, is
        not a JSON object, or has an ``idx`` that is missing, not convertible
        to int, or negative.
        """
        try:
            body = json.loads(raw or b"{}")
        except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass it
            raise ValueError("bad json") from None
        if not isinstance(body, dict):
            raise ValueError("bad idx")
        try:
            idx = int(body.get("idx", -1))
        except (TypeError, ValueError, OverflowError):  # null/list, "abc", 1e400
            raise ValueError("bad idx") from None
        if idx < 0:
            raise ValueError("bad idx")
        return idx

    def do_GET(self):
        path = urlparse(self.path).path
        if path == "/api/manifest":
            try:
                self._json(build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX))
            except Exception as e:  # noqa: BLE001
                self._json({"error": str(e)}, 500)
            return
        if path == "/api/selected":
            ids = []
            if SELECTED.exists():
                # isdecimal, not isdigit: "²".isdigit() is True but int("²") raises.
                ids = [int(p.name) for p in SELECTED.iterdir()
                       if p.is_dir() and p.name.isdecimal()]
            self._json({"ids": sorted(ids)})
            return
        super().do_GET()

    def do_POST(self):
        if urlparse(self.path).path != "/api/save":
            self.send_error(404)
            return
        try:
            # Clamp: rfile.read() with a negative size reads to EOF and can block.
            length = max(0, int(self.headers.get("Content-Length", 0)))
        except ValueError:
            self._json({"ok": False, "error": "bad Content-Length"}, 400)
            return
        try:
            idx = self._parse_idx(self.rfile.read(length))
        except ValueError as e:
            self._json({"ok": False, "error": str(e)}, 400)
            return
        try:
            self._json(save_selection(idx))
        except Exception as e:  # noqa: BLE001
            self._json({"ok": False, "error": str(e)}, 500)


def main():
    """Parse CLI args, print the viewer URLs, and serve until Ctrl-C."""
    import socket
    ap = argparse.ArgumentParser()
    # NOTE: the default 0.0.0.0 listens on every interface, and POST /api/save
    # writes files, so anyone who can reach the port can create saves.
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--port", type=int, default=8765)
    args = ap.parse_args()
    SELECTED.mkdir(parents=True, exist_ok=True)
    host = socket.gethostname()
    # Derived from this file's location (instead of hard-coded) so the URL
    # stays correct if the viewer directory moves within the repo.
    page = "/" + (HERE / "index.html").relative_to(REPO).as_posix()
    # The context manager calls server_close() on every exit path, not only Ctrl-C.
    with ThreadingHTTPServer((args.host, args.port), Handler) as srv:
        banner = [
            "=" * 60,
            f"  open   http://{host}:{args.port}{page}",
            f"         http://localhost:{args.port}{page}",
            f"  (remote? ssh -L {args.port}:localhost:{args.port} {host})",
            f"  saves  -> {SELECTED.relative_to(REPO)}/<idx>/",
            "  Ctrl-C to stop.",
            "=" * 60,
        ]
        print("\n".join(banner), flush=True)
        try:
            srv.serve_forever()
        except KeyboardInterrupt:
            print("\n[viewer] shutting down")


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    main()