tmp_1 / viewer /viewer.py
rockeycoss's picture
Add files using upload-large-folder tool
c474610 verified
Raw
History Blame Contribute Delete
5.02 kB
#!/usr/bin/env python
"""Static server + JSON API for the xj_mjhq30k model-comparison viewer.
Serves the repo statically and adds:
- GET /api/manifest : live re-scan of the sibling output dirs (so new images
  appear without re-running gen_manifest.py)
- GET /api/selected : list the prompt ids already saved under ../selected/
- POST /api/save : copy each model's image for the current prompt, plus the
  prompt text, into ../selected/<idx>/
Use this (not plain http.server) so the Save button works:
python test/xj_mjhq30k_inference_outputs/viewer/viewer.py # :8765
python test/xj_mjhq30k_inference_outputs/viewer/viewer.py --port N
Then open http://<host>:<port>/test/xj_mjhq30k_inference_outputs/viewer/index.html
"""
import argparse
import json
import shutil
import sys
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import urlparse
# Directory layout is derived from this file's location, i.e.
#   <repo>/test/xj_mjhq30k_inference_outputs/viewer/viewer.py
HERE = Path(__file__).resolve().parent # .../xj_mjhq30k_inference_outputs/viewer
OUTPUTS = HERE.parent # .../xj_mjhq30k_inference_outputs
REPO = HERE.parents[2] # repo root (also the static-serving root)
SELECTED = OUTPUTS / "selected"  # destination directory for POST /api/save
MANIFEST = HERE / "manifest.json"  # cached manifest, used when the live scan fails
PROMPTS_JSON = REPO / "assets" / "xj_mjhq30k_prompts.json"
# Passed through to build_manifest; presumably the prefix that makes image
# paths relative to viewer/ — confirm against gen_manifest.py.
REL_PREFIX = ".."
# Put this script's directory first on sys.path so the sibling gen_manifest.py
# is importable regardless of the current working directory. Must run before
# the import below.
sys.path.insert(0, str(HERE))
from gen_manifest import build_manifest # noqa: E402
def save_selection(idx: int) -> dict:
    """Copy every model's image for prompt ``idx`` into ``SELECTED/<stem>/``.

    The manifest is rebuilt live via ``build_manifest`` so newly added model
    dirs are picked up; if that raises for any reason, the cached
    ``manifest.json`` next to this script is used instead.

    ``stem`` is ``idx`` zero-padded to ``manifest["pad"]`` digits. The output
    directory receives:
      - ``<model name>.png`` for each model whose ``<stem>.png`` exists,
      - ``prompt.txt`` with the prompt text ("" if ``idx`` has no prompt),
      - ``meta.json`` with the id, the saved model names and the prompt.

    Args:
        idx: Prompt index; the HTTP caller rejects negative values.

    Returns:
        ``{"ok": True, "path": <out dir relative to REPO>, "saved_models": [...]}``.
    """
    # Prefer the live manifest so the dir list is current; fall back to the file.
    try:
        manifest = build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX)
    except Exception:  # noqa: BLE001 - deliberate best-effort fallback
        manifest = json.loads(MANIFEST.read_text(encoding="utf-8"))

    stem = str(idx).zfill(manifest["pad"])
    out = SELECTED / stem
    out.mkdir(parents=True, exist_ok=True)

    saved = []
    for m in manifest["models"]:
        src = OUTPUTS / m["dir"] / f"{stem}.png"
        if src.exists():
            shutil.copy(src, out / f"{m['name']}.png")
            saved.append(m["name"])

    prompts = manifest["prompts"]
    # Bounds-check both ends: a negative idx would otherwise silently pick a
    # prompt from the end of the list.
    prompt = prompts[idx] if 0 <= idx < len(prompts) else ""
    # Prompt text is arbitrary; write UTF-8 explicitly instead of relying on
    # the platform's locale encoding.
    (out / "prompt.txt").write_text(prompt, encoding="utf-8")
    meta = {"id": idx, "saved_models": saved, "prompt": prompt}
    (out / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    return {"ok": True, "path": str(out.relative_to(REPO)), "saved_models": saved}
class Handler(SimpleHTTPRequestHandler):
    """Static file handler rooted at REPO plus the viewer's small JSON API.

    Routes:
      GET  /api/manifest -> live ``build_manifest`` result (500 + {"error"} on failure)
      GET  /api/selected -> {"ids": [...]}: sorted numeric sub-dir names of SELECTED
      POST /api/save     -> ``save_selection(idx)`` for a JSON body {"idx": <int>}
    Any other GET falls through to static file serving; any other POST is 404.
    """

    def __init__(self, *a, **k):
        # Serve files relative to the repo root regardless of the process cwd.
        super().__init__(*a, directory=str(REPO), **k)

    def log_message(self, *a):  # quiet
        # Suppress the default per-request access log.
        return

    def _json(self, payload, status=200):
        """Send ``payload`` as an uncached JSON response with ``status``."""
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _parse_idx(raw):
        """Return the integer ``idx`` from a raw POST body (bytes).

        An empty body is treated as ``{}`` and yields -1, the "missing" value
        that ``do_POST`` rejects. Raises ValueError for malformed JSON, a
        non-object payload, or an ``idx`` that ``int()`` cannot convert.
        """
        body = json.loads(raw or b"{}")  # JSONDecodeError subclasses ValueError
        if not isinstance(body, dict):
            raise ValueError("request body must be a JSON object")
        try:
            return int(body.get("idx", -1))
        except (TypeError, ValueError) as err:
            raise ValueError(f"bad idx: {err}") from err

    def do_GET(self):
        path = urlparse(self.path).path
        if path == "/api/manifest":
            try:
                self._json(build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX))
            except Exception as e:  # noqa: BLE001
                self._json({"error": str(e)}, 500)
            return
        if path == "/api/selected":
            ids = []
            if SELECTED.exists():
                # isdecimal(), not isdigit(): isdigit() also accepts characters
                # such as superscripts that int() rejects with ValueError.
                ids = sorted(
                    int(p.name)
                    for p in SELECTED.iterdir()
                    if p.is_dir() and p.name.isdecimal()
                )
            self._json({"ids": ids})
            return
        super().do_GET()

    def do_POST(self):
        if urlparse(self.path).path != "/api/save":
            self.send_error(404)
            return
        # Malformed input must produce a 400 rather than an unhandled exception
        # (which would drop the connection without any response).
        try:
            length = int(self.headers.get("Content-Length", 0))
            # Never pass a negative size: rfile.read(-1) blocks until EOF.
            idx = self._parse_idx(self.rfile.read(max(length, 0)))
        except ValueError as e:
            self._json({"ok": False, "error": str(e)}, 400)
            return
        if idx < 0:
            self._json({"ok": False, "error": "bad idx"}, 400)
            return
        try:
            self._json(save_selection(idx))
        except Exception as e:  # noqa: BLE001
            self._json({"ok": False, "error": str(e)}, 500)
def main():
    """Parse ``--host``/``--port``, then serve the viewer until Ctrl-C."""
    import socket

    ap = argparse.ArgumentParser()
    ap.add_argument("--host", default="0.0.0.0")
    ap.add_argument("--port", type=int, default=8765)
    args = ap.parse_args()

    SELECTED.mkdir(parents=True, exist_ok=True)
    host = socket.gethostname()
    # Derive the page URL from this script's location instead of hard-coding
    # it, so the printed link stays correct if the viewer dir is moved.
    page = "/" + (HERE / "index.html").relative_to(REPO).as_posix()

    # The context manager guarantees the listening socket is closed on any
    # exit path, not only on Ctrl-C.
    with ThreadingHTTPServer((args.host, args.port), Handler) as srv:
        print("=" * 60, flush=True)
        print(f" open http://{host}:{args.port}{page}", flush=True)
        print(f" http://localhost:{args.port}{page}", flush=True)
        print(f" (remote? ssh -L {args.port}:localhost:{args.port} {host})", flush=True)
        print(f" saves -> {SELECTED.relative_to(REPO)}/<idx>/", flush=True)
        print(" Ctrl-C to stop.", flush=True)
        print("=" * 60, flush=True)
        try:
            srv.serve_forever()
        except KeyboardInterrupt:
            print("\n[viewer] shutting down", flush=True)
if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()