Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import hashlib | |
| import json | |
| import os | |
| from pathlib import Path | |
| from uuid import NAMESPACE_URL, UUID, uuid5 | |
| import numpy as np | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.orm import Session | |
# Synchronous database URL; the DATABASE_URL_SYNC environment variable wins
# over the built-in default.
DATABASE_URL = os.getenv(
    "DATABASE_URL_SYNC",
    "postgresql+psycopg://archstyle:archstyle@postgres:5432/archstyle",
)

# Maps a class name to the display name stored alongside it.
ALIAS = dict(
    [
        (
            "Moscow Luzhkov style architecture",
            "Late 20th century Moscow architecture",
        ),
    ]
)

# Base directories, each overridable through its environment variable.
RUNS_BASE = Path(os.getenv("RUNS_RES_DIR", "/runs_res"))
REPO_BASE = Path(os.getenv("REPO_DIR", "/repo"))
| def _results_root() -> Path: | |
| override = os.environ.get("RESULTS_DIR") | |
| if override: | |
| return Path(override) | |
| if (RUNS_BASE / "aggregate").is_dir(): | |
| return RUNS_BASE / "aggregate" | |
| return RUNS_BASE | |
# Root of the result artifacts, resolved once at import time.
RESULTS = _results_root()

# (split name, archive path) pairs, in the order they are ingested.
EMB_FILES = (
    ("test", str(RESULTS.joinpath("embeddings/embeddings_test.npz"))),
    ("val", str(RESULTS.joinpath("embeddings/embeddings_val.npz"))),
    ("train", str(RESULTS.joinpath("embeddings/embeddings_train.npz"))),
)

# Per-model tables read when seeding model metadata.
SUMMARY_CSV = str(RESULTS.joinpath("summary_table.csv"))
COMPUTE_CSV = str(RESULTS.joinpath("compute_cost_table.csv"))

# Candidate locations of idx_to_class.json; the first existing file is used.
IDX_FALLBACKS = (
    str(REPO_BASE.joinpath("pipeline/results/splits/idx_to_class.json")),
    str(REPO_BASE.joinpath("results/splits/idx_to_class.json")),
    str(RESULTS.joinpath("splits/idx_to_class.json")),
)
| def _stable_uuid(path: str) -> UUID: | |
| return uuid5(NAMESPACE_URL, f"archstyle://{path}") | |
| def _sha256(path: str) -> str: | |
| return hashlib.sha256(path.encode("utf-8")).hexdigest() | |
def _ensure_pgvector(engine):
    """Create the ``vector`` extension if it is not already installed.

    Opens a connection from *engine*, runs ``CREATE EXTENSION IF NOT EXISTS``
    and commits before the connection is released.
    """
    create_extension = text("CREATE EXTENSION IF NOT EXISTS vector")
    with engine.connect() as connection:
        connection.execute(create_extension)
        connection.commit()
def seed_classes(session: Session) -> int:
    """Upsert the class index into the ``classes`` table.

    The mapping is read as JSON from the first existing file among
    ``IDX_FALLBACKS``; its keys are converted to ``int`` and used as ``idx``.
    Each row's ``display_name`` comes from ``ALIAS`` when the class name has
    an entry there, and otherwise equals the class name.

    Args:
        session: SQLAlchemy session used to run the upsert. This function
            does not commit; the caller owns the transaction.

    Returns:
        The number of rows sent to the database, or 0 when no mapping file
        exists or the mapping is empty.
    """
    src = next((p for p in map(Path, IDX_FALLBACKS) if p.is_file()), None)
    if src is None:
        print("classes: idx_to_class.json not found; skipping")
        return 0

    idx = json.loads(src.read_text(encoding="utf-8"))
    rows = [
        {
            "idx": int(k),
            "name": v,
            "display_name": ALIAS.get(v, v),
        }
        for k, v in idx.items()
    ]
    if not rows:
        # An empty parameter list is not an executemany: the INSERT would be
        # executed once with no values for its bind parameters. Skip instead.
        print("classes: upserted 0")
        return 0

    session.execute(
        text(
            """
            INSERT INTO classes (idx, name, display_name)
            VALUES (:idx, :name, :display_name)
            ON CONFLICT (idx) DO UPDATE SET
            name = EXCLUDED.name,
            display_name = EXCLUDED.display_name
            """
        ),
        rows,
    )
    print(f"classes: upserted {len(rows)}")
    return len(rows)
def _read_csv_rows(path: str) -> list[dict]:
    """Read a CSV file into a list of dict rows, closing the file afterwards."""
    # newline="" is what the csv module asks for when it is given a file object.
    with open(path, encoding="utf-8", newline="") as fh:
        return list(csv.DictReader(fh))


def _model_family(name: str) -> str:
    """Classify a model name as transformer, ensemble or cnn by substring."""
    # NOTE(review): transformer tags are tested first, so an ensemble whose
    # name contains e.g. "vit" is labelled "transformer" — confirm intended.
    if any(tag in name for tag in ("vit", "swin", "dinov2")):
        return "transformer"
    return "ensemble" if "ensemble" in name else "cnn"


def _metric_value(key, row, fallback_row, default=None):
    """Return ``row[key]`` (or ``fallback_row[key]``) as a float, else *default*."""
    raw = row.get(key) or fallback_row.get(key)
    if raw in (None, ""):
        return default
    try:
        return float(raw)
    except (TypeError, ValueError):
        return default


def seed_models(session: Session) -> int:
    """Upsert one ``model_meta`` row per model listed in ``SUMMARY_CSV``.

    Each value is taken from the summary row first and from the matching
    ``COMPUTE_CSV`` row (matched on the ``model`` column) when the summary
    value is missing or empty. Values that are absent or not parseable as
    floats become NULL, except ``params_m`` (0.0) and ``image_size`` (224).
    Rows without a ``model`` value are skipped.

    On a name conflict only the metric columns listed in the ``DO UPDATE``
    clause are refreshed; ``family``, ``image_size`` and ``hf_repo`` keep
    their stored values.

    Args:
        session: SQLAlchemy session used for the upserts. This function does
            not commit; the caller owns the transaction.

    Returns:
        The number of rows upserted, or 0 when ``SUMMARY_CSV`` does not exist.
    """
    if not Path(SUMMARY_CSV).is_file():
        print("model_meta: no summary_table.csv; skipping")
        return 0

    summary = _read_csv_rows(SUMMARY_CSV)
    compute = {}
    if Path(COMPUTE_CSV).is_file():
        compute = {row.get("model"): row for row in _read_csv_rows(COMPUTE_CSV)}

    upsert = text(
        """
        INSERT INTO model_meta(name, family, params_m, gflops, accuracy,
        macro_f1, bal_acc, inference_ms, image_size, hf_repo)
        VALUES (:name, :family, :params_m, :gflops, :acc, :f1, :bal,
        :inf, :img, :repo)
        ON CONFLICT (name) DO UPDATE SET
        accuracy = EXCLUDED.accuracy,
        macro_f1 = EXCLUDED.macro_f1,
        bal_acc = EXCLUDED.bal_acc,
        inference_ms = EXCLUDED.inference_ms,
        params_m = EXCLUDED.params_m,
        gflops = EXCLUDED.gflops
        """
    )
    # The same for every row, so look it up once.
    hf_repo = os.environ.get("HF_MODEL_REPO", "kkkaredaw/archstyle55-backbones")

    n = 0
    for r in summary:
        name = r.get("model")
        if not name:
            continue
        c = compute.get(name, {})
        session.execute(
            upsert,
            {
                "name": name,
                "family": _model_family(name),
                "params_m": _metric_value("params_m", r, c, 0.0) or 0.0,
                "gflops": _metric_value("gflops", r, c),
                "acc": _metric_value("accuracy", r, c),
                "f1": _metric_value("macro_f1", r, c),
                "bal": _metric_value("bal_acc", r, c),
                "inf": _metric_value("inference_ms", r, c),
                "img": int(_metric_value("image_size", r, c, 224) or 224),
                "repo": hf_repo,
            },
        )
        n += 1
    print(f"model_meta: upserted {n}")
    return n
def _vector_literal(vec) -> str:
    """Format a 1-D sequence of numbers as a ``[v1,v2,...]`` vector literal."""
    return "[" + ",".join(f"{float(x):.6f}" for x in vec) + "]"


def _flush_embedding_batch(session, rows_img, rows_emb) -> None:
    """Insert one batch of image and embedding rows, then commit."""
    if not rows_img:
        return
    session.execute(
        text(
            """
            INSERT INTO images (id, sha256, source, style_label, blob_url)
            VALUES (:id, :sha256, :source, :style_label, :blob_url)
            ON CONFLICT (sha256) DO NOTHING
            """
        ),
        rows_img,
    )
    session.execute(
        text(
            """
            INSERT INTO embeddings (image_id, model, vec)
            VALUES (:image_id, :model, CAST(:vec AS vector))
            ON CONFLICT (image_id, model) DO NOTHING
            """
        ),
        rows_emb,
    )
    session.commit()


def seed_embeddings(session: Session, *, max_rows: int | None = None,
                    chunk_size: int = 256,
                    model: str = "dinov2_vitb14") -> int:
    """Load the embedding archives in ``EMB_FILES`` into the database.

    For every archive that exists, one ``images`` row and one ``embeddings``
    row are written per sample, in batches of ``chunk_size``; each batch is
    committed on its own. Image ids and ``sha256`` values are derived from
    the sample's path string, and a path already seen earlier in this call
    (in the same or a previous archive) is skipped. Both inserts use
    ``ON CONFLICT ... DO NOTHING``, so rows already present are left as is.

    Args:
        session: SQLAlchemy session used for the inserts and commits.
        max_rows: Optional cap on the number of samples read per archive.
        chunk_size: Number of samples per insert batch.
        model: Value stored in the ``model`` column of every embedding row.

    Returns:
        The number of samples submitted for insertion. Rows the database
        skipped because of a conflict are still counted.
    """
    inserted = 0
    seen_sha: set[str] = set()

    for source, path in EMB_FILES:
        if not Path(path).is_file():
            print(f"embeddings: skip missing {path}")
            continue

        # NOTE(security): allow_pickle=True can execute arbitrary code while
        # loading object arrays; only point this at archives you trust.
        # mmap_mode is deliberately not passed: it has no effect on .npz
        # archives, whose members are read fully into memory on access.
        with np.load(path, allow_pickle=True) as archive:
            feats = archive["features"]
            labels = archive["labels"]
            paths = archive["paths"]
            class_names = archive["class_names"]

        n_rows = len(feats) if max_rows is None else min(len(feats), max_rows)
        print(f"embeddings: ingest {n_rows} from {path}")

        rows_img: list[dict] = []
        rows_emb: list[dict] = []
        for i in range(n_rows):
            p = str(paths[i])
            sha = _sha256(p)
            if sha in seen_sha:
                continue
            seen_sha.add(sha)
            uid = _stable_uuid(p)
            rows_img.append({
                "id": uid,
                "sha256": sha,
                "source": f"split:{source}",
                "style_label": str(class_names[int(labels[i])]),
                "blob_url": p,
            })
            rows_emb.append({
                "image_id": uid,
                "model": model,
                "vec": _vector_literal(feats[i]),
            })
            if len(rows_img) >= chunk_size:
                _flush_embedding_batch(session, rows_img, rows_emb)
                inserted += len(rows_img)
                rows_img.clear()
                rows_emb.clear()
                print(f" inserted {inserted}")

        # Write whatever is left over from the last partial batch.
        _flush_embedding_batch(session, rows_img, rows_emb)
        inserted += len(rows_img)
        rows_img.clear()
        rows_emb.clear()
        print(f" inserted {inserted}")

    return inserted
def main() -> int:
    """Parse the command line and run the seeding steps in order.

    Classes and model metadata are seeded and committed first; embeddings
    follow unless ``--skip-embeddings`` is given. ``--max-rows`` is passed
    through to the embedding step. Returns 0.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--max-rows",
        type=int,
        default=None,
        help="cap rows per file (debug)",
    )
    cli.add_argument("--skip-embeddings", action="store_true")
    options = cli.parse_args()

    db_engine = create_engine(DATABASE_URL, pool_pre_ping=True)
    _ensure_pgvector(db_engine)

    with Session(db_engine) as db:
        seed_classes(db)
        seed_models(db)
        db.commit()
        want_embeddings = not options.skip_embeddings
        if want_embeddings:
            seed_embeddings(db, max_rows=options.max_rows)

    print("done.")
    return 0
# Script entry point: exit with whatever status main() returns.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)