#!/usr/bin/env python
"""Static server + /api/save for the xj_mjhq30k model-comparison viewer.

Serves the repo statically and adds:
  - GET  /api/manifest : live re-scan of the sibling output dirs (so new images
                         appear without re-running gen_manifest.py)
  - GET  /api/selected : indices of the prompts already saved under ../selected/
  - POST /api/save     : copy each model's image for the current prompt, plus
                         the prompt text, into ../selected/<idx>/

Use this (not plain http.server) so the Save button works:
    python test/xj_mjhq30k_inference_outputs/viewer/viewer.py            # :8765
    python test/xj_mjhq30k_inference_outputs/viewer/viewer.py --port N

Then open http://<host>:<port>/test/xj_mjhq30k_inference_outputs/viewer/index.html
"""
import argparse
import json
import shutil
import sys
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import urlparse
HERE = Path(__file__).resolve().parent # .../xj_mjhq30k_inference_outputs/viewer
OUTPUTS = HERE.parent # .../xj_mjhq30k_inference_outputs
REPO = HERE.parents[2] # repo root
SELECTED = OUTPUTS / "selected"
MANIFEST = HERE / "manifest.json"
PROMPTS_JSON = REPO / "assets" / "xj_mjhq30k_prompts.json"
REL_PREFIX = ".."
sys.path.insert(0, str(HERE))
from gen_manifest import build_manifest # noqa: E402
def save_selection(idx: int) -> dict:
    """Copy every model's image for prompt *idx* into ``SELECTED/<stem>/``.

    The model list comes from a live ``build_manifest`` scan; if that fails for
    any reason, the on-disk ``MANIFEST`` file is used instead. Besides one
    ``<model name>.png`` per model whose image exists, the output directory
    gets ``prompt.txt`` (the raw prompt) and ``meta.json`` (id, saved model
    names, prompt).

    Args:
        idx: Zero-based prompt index. It is zero-padded to ``manifest["pad"]``
            digits to form both the source image stem and the output dir name.

    Returns:
        ``{"ok": True, "path": <output dir relative to REPO>,
        "saved_models": [<model names copied>]}``.

    Raises:
        ValueError: if *idx* is negative.
    """
    if idx < 0:
        # A negative idx would zfill to e.g. "-01" and silently pick a prompt
        # counted from the end of the list.
        raise ValueError(f"idx must be non-negative, got {idx}")
    # Prefer the live manifest so the dir list is current; fall back to the file.
    try:
        manifest = build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX)
    except Exception:  # noqa: BLE001 - deliberate best-effort fallback
        manifest = json.loads(MANIFEST.read_text(encoding="utf-8"))
    stem = str(idx).zfill(manifest["pad"])
    out = SELECTED / stem
    out.mkdir(parents=True, exist_ok=True)
    saved = []
    for m in manifest["models"]:
        src = OUTPUTS / m["dir"] / f"{stem}.png"
        if src.exists():
            shutil.copy(src, out / f"{m['name']}.png")
            saved.append(m["name"])
    prompts = manifest["prompts"]
    prompt = prompts[idx] if idx < len(prompts) else ""
    # Prompts may contain non-ASCII text; don't depend on the locale encoding.
    (out / "prompt.txt").write_text(prompt, encoding="utf-8")
    meta = {"id": idx, "saved_models": saved, "prompt": prompt}
    (out / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    return {"ok": True, "path": str(out.relative_to(REPO)), "saved_models": saved}
class Handler(SimpleHTTPRequestHandler):
    """Serve REPO as static files plus the viewer's small JSON API.

    Routes handled here (any other GET is served as a static file from REPO):
      GET  /api/manifest  live build_manifest() result; 500 + {"error": ...} on failure
      GET  /api/selected  {"ids": [...]}: sorted numeric dir names under SELECTED
      POST /api/save      JSON body {"idx": <int>} -> save_selection(idx)
    """

    def __init__(self, *a, **k):
        # Root static serving at the repo instead of the process CWD.
        super().__init__(*a, directory=str(REPO), **k)

    def log_message(self, *a):  # quiet
        # Drops all request logging (log_error also routes through here).
        return

    def _json(self, payload, status=200):
        """Write *payload* as an uncached JSON response with HTTP *status*."""
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Dispatch the /api/* GET routes; fall back to static file serving."""
        path = urlparse(self.path).path
        if path == "/api/manifest":
            try:
                self._json(build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX))
            except Exception as e:  # noqa: BLE001
                self._json({"error": str(e)}, 500)
            return
        if path == "/api/selected":
            ids = []
            if SELECTED.exists():
                # isdecimal(), not isdigit(): names like "²" pass isdigit()
                # but make int() raise, which would abort this request.
                ids = [
                    int(p.name)
                    for p in SELECTED.iterdir()
                    if p.is_dir() and p.name.isdecimal()
                ]
            self._json({"ids": sorted(ids)})
            return
        super().do_GET()

    def _read_idx(self):
        """Parse the request body and return its ``idx`` field as an int.

        A missing/empty body is treated as ``{}``, and a missing ``idx`` as -1.

        Raises:
            ValueError: if Content-Length, the JSON body, or ``idx`` is malformed.
        """
        try:
            # Clamp: a negative length would make rfile.read() block until EOF.
            length = max(0, int(self.headers.get("Content-Length", 0)))
        except (TypeError, ValueError) as e:
            raise ValueError("bad Content-Length") from e
        try:
            # JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            body = json.loads(self.rfile.read(length) or b"{}")
        except ValueError as e:
            raise ValueError("bad JSON body") from e
        if not isinstance(body, dict):
            raise ValueError("JSON body must be an object")
        try:
            # OverflowError covers JSON "Infinity"; ValueError covers NaN/strings.
            return int(body.get("idx", -1))
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError("bad idx") from e

    def do_POST(self):
        """Handle POST /api/save; any other POST path gets a 404."""
        if urlparse(self.path).path != "/api/save":
            self.send_error(404)
            return
        try:
            idx = self._read_idx()
        except ValueError as e:
            # Malformed requests get a JSON 400 rather than an unhandled
            # exception that drops the connection without any response.
            self._json({"ok": False, "error": str(e)}, 400)
            return
        if idx < 0:
            self._json({"ok": False, "error": "bad idx"}, 400)
            return
        try:
            self._json(save_selection(idx))
        except Exception as e:  # noqa: BLE001
            self._json({"ok": False, "error": str(e)}, 500)
def main():
    """Parse CLI args, print the viewer URLs, and serve until Ctrl-C."""
    import socket

    ap = argparse.ArgumentParser(
        description="Serve the xj_mjhq30k model-comparison viewer with /api/save."
    )
    ap.add_argument("--host", default="0.0.0.0",
                    help="address to bind (default: %(default)s)")
    ap.add_argument("--port", type=int, default=8765,
                    help="port to listen on (default: %(default)s)")
    args = ap.parse_args()
    SELECTED.mkdir(parents=True, exist_ok=True)
    host = socket.gethostname()
    page = "/test/xj_mjhq30k_inference_outputs/viewer/index.html"
    # The context manager calls server_close() on every exit path (not only
    # Ctrl-C), so the listening socket is released even if serving raises.
    with ThreadingHTTPServer((args.host, args.port), Handler) as srv:
        rule = "=" * 60
        banner = [
            rule,
            f" open http://{host}:{args.port}{page}",
            f" http://localhost:{args.port}{page}",
            f" (remote? ssh -L {args.port}:localhost:{args.port} {host})",
            f" saves -> {SELECTED.relative_to(REPO)}/<idx>/",
            " Ctrl-C to stop.",
            rule,
        ]
        print("\n".join(banner), flush=True)
        try:
            srv.serve_forever()
        except KeyboardInterrupt:
            print("\n[viewer] shutting down")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0.
    sys.exit(main())
|