| |
| """Static server + /api/save for the xj_mjhq30k model-comparison viewer. |
| |
| Serves the repo statically and adds: |
| - GET /api/manifest : live re-scan of the sibling output dirs (so new images |
| appear without re-running gen_manifest.py) |
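  - POST /api/save     : copy each model's image for the current prompt, plus
                         the prompt text and a meta.json, into
                         ../selected/<zero-padded idx>/
  - GET  /api/selected : list the ids of prompts that already have a saved
                         selection under ../selected/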
| |
| Use this (not plain http.server) so the Save button works: |
| python test/xj_mjhq30k_inference_outputs/viewer/viewer.py # :8765 |
| python test/xj_mjhq30k_inference_outputs/viewer/viewer.py --port N |
| |
| Then open http://<host>:<port>/test/xj_mjhq30k_inference_outputs/viewer/index.html |
| """ |
| import argparse |
| import json |
| import shutil |
| import sys |
| from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer |
| from pathlib import Path |
| from urllib.parse import urlparse |
|
|
# Directory layout: this file sits in <repo>/test/xj_mjhq30k_inference_outputs/viewer/.
HERE = Path(__file__).resolve().parent
OUTPUTS = HERE.parent                 # holds the per-model output dirs + selected/
REPO = OUTPUTS.parents[1]             # repo root; also the static-serving root
SELECTED = OUTPUTS.joinpath("selected")
MANIFEST = HERE.joinpath("manifest.json")
PROMPTS_JSON = REPO.joinpath("assets", "xj_mjhq30k_prompts.json")
# Prefix passed through to build_manifest (presumably used to build image
# paths relative to the viewer page — confirm in gen_manifest.py).
REL_PREFIX = ".."

# Make the sibling gen_manifest.py importable when this file is run as a script.
sys.path.insert(0, str(HERE))
from gen_manifest import build_manifest  # noqa: E402
|
|
|
|
def _load_manifest() -> dict:
    """Return a live manifest scan, falling back to the static manifest.json."""
    try:
        return build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX)
    except Exception:
        # Best effort: a failed live re-scan should not block saving, so use
        # the on-disk manifest.json (presumably written by gen_manifest.py).
        return json.loads(MANIFEST.read_text(encoding="utf-8"))


def save_selection(idx: int) -> dict:
    """Copy every model's image for prompt *idx* into ``SELECTED/<stem>/``.

    ``<stem>`` is *idx* zero-padded to the manifest's ``pad`` width.

    For each model in the manifest whose ``<dir>/<stem>.png`` exists, the
    image is copied to ``<name>.png``. The prompt text is written to
    ``prompt.txt``, and ``meta.json`` records the id, the models actually
    copied and the prompt.

    Args:
        idx: Non-negative prompt index.

    Returns:
        ``{"ok": True, "path": <output dir relative to REPO>,
        "saved_models": [...]}``.

    Raises:
        ValueError: If *idx* is negative.
        OSError: On filesystem failures while copying or writing.
    """
    if idx < 0:
        # A negative idx would create a "-05"-style dir and read prompts[-n].
        raise ValueError(f"idx must be non-negative, got {idx}")

    manifest = _load_manifest()
    stem = str(idx).zfill(manifest["pad"])
    out = SELECTED / stem
    out.mkdir(parents=True, exist_ok=True)

    saved = []
    for m in manifest["models"]:
        src = OUTPUTS / m["dir"] / f"{stem}.png"
        if src.exists():  # models may not have produced every prompt yet
            shutil.copy(src, out / f"{m['name']}.png")
            saved.append(m["name"])

    prompts = manifest["prompts"]
    prompt = prompts[idx] if idx < len(prompts) else ""
    # Explicit UTF-8: prompts can contain non-ASCII text, and the locale
    # default encoding may not be able to represent it.
    (out / "prompt.txt").write_text(prompt, encoding="utf-8")
    meta = {"id": idx, "saved_models": saved, "prompt": prompt}
    (out / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    return {"ok": True, "path": str(out.relative_to(REPO)), "saved_models": saved}
|
|
|
|
class Handler(SimpleHTTPRequestHandler):
    """Static file handler rooted at the repo, plus the viewer's JSON API.

    GET  /api/manifest -> live manifest from ``build_manifest`` (500 on error)
    GET  /api/selected -> ``{"ids": [...]}``, the prompt ids already saved
    POST /api/save     -> body ``{"idx": N}``; runs ``save_selection(N)``

    Every other GET falls through to ``SimpleHTTPRequestHandler``.
    """

    def __init__(self, *a, **k):
        # Serve files relative to the repo root rather than the process CWD.
        super().__init__(*a, directory=str(REPO), **k)

    def log_message(self, *a):
        # Silence the default per-request access log.
        return

    def _json(self, payload, status=200):
        """Send *payload* as a JSON response with caching disabled."""
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Dispatch the JSON API endpoints; serve static files otherwise."""
        path = urlparse(self.path).path
        if path == "/api/manifest":
            try:
                self._json(build_manifest(OUTPUTS, PROMPTS_JSON, REL_PREFIX))
            except Exception as e:
                self._json({"error": str(e)}, 500)
            return
        if path == "/api/selected":
            ids = []
            if SELECTED.exists():
                # isdecimal (not isdigit): isdigit accepts e.g. "²", which int() rejects.
                ids = sorted(
                    int(p.name)
                    for p in SELECTED.iterdir()
                    if p.is_dir() and p.name.isdecimal()
                )
            self._json({"ids": ids})
            return
        super().do_GET()

    def _read_idx(self) -> int:
        """Parse the request body and return its ``idx`` field as an int.

        An empty body or a missing ``idx`` yields -1.

        Raises:
            ValueError, TypeError, OverflowError: On a malformed
                Content-Length header, a body that is not a JSON object, or
                an ``idx`` that cannot be converted to int.
        """
        length = int(self.headers.get("Content-Length", 0))
        if length < 0:
            # rfile.read(-1) would block until the client closes the socket.
            raise ValueError("negative Content-Length")
        body = json.loads(self.rfile.read(length) or b"{}")
        if not isinstance(body, dict):
            raise ValueError("request body must be a JSON object")
        return int(body.get("idx", -1))

    def do_POST(self):
        """Handle ``POST /api/save``; every other POST path gets a 404."""
        if urlparse(self.path).path != "/api/save":
            self.send_error(404)
            return
        try:
            idx = self._read_idx()
        except (ValueError, TypeError, OverflowError) as e:
            # Malformed input must still get an HTTP response, not a dropped
            # connection.
            self._json({"ok": False, "error": f"bad request: {e}"}, 400)
            return
        if idx < 0:
            self._json({"ok": False, "error": "bad idx"}, 400)
            return
        try:
            self._json(save_selection(idx))
        except Exception as e:
            self._json({"ok": False, "error": str(e)}, 500)
|
|
|
|
def main():
    """Parse CLI args, print the viewer URLs, and serve until Ctrl-C."""
    import socket

    ap = argparse.ArgumentParser(description="Model-comparison viewer server.")
    ap.add_argument("--host", default="0.0.0.0", help="bind address")
    ap.add_argument("--port", type=int, default=8765, help="listen port")
    args = ap.parse_args()
    SELECTED.mkdir(parents=True, exist_ok=True)
    host = socket.gethostname()
    # Derive the page URL from this file's location under the served root
    # (REPO), so the banner stays correct if the viewer directory moves.
    page = "/" + (HERE / "index.html").relative_to(REPO).as_posix()
    # Bind before printing the banner, so a busy port fails first. The `with`
    # block closes the listening socket on every exit path.
    with ThreadingHTTPServer((args.host, args.port), Handler) as srv:
        print("=" * 60, flush=True)
        print(f"  open  http://{host}:{args.port}{page}", flush=True)
        print(f"        http://localhost:{args.port}{page}", flush=True)
        print(f"  (remote? ssh -L {args.port}:localhost:{args.port} {host})", flush=True)
        print(f"  saves -> {SELECTED.relative_to(REPO)}/<idx>/", flush=True)
        print("  Ctrl-C to stop.", flush=True)
        print("=" * 60, flush=True)
        try:
            srv.serve_forever()
        except KeyboardInterrupt:
            print("\n[viewer] shutting down")
|
|
|
|
if __name__ == "__main__":
    # Start the server only when run as a script, not on import.
    main()
|
|