#!/usr/bin/env python3
"""Export a concept co-occurrence graph from SurrealDB.
Outputs (default under ./exports):
- concept_graph.json: Sigma-friendly graph JSON
- concept_graph.html: self-contained HTML viewer (vis-network inlined)
Edges are co-occurrence within the same chunk:
If a chunk mentions concepts A and B, edge(A,B) += 1.
"""
from __future__ import annotations
import argparse
import itertools
import json
import math
import os
import random
import urllib.request
from pathlib import Path
from typing import Any
from surrealdb import Surreal
def _query_rows(
conn: Surreal, surql: str, vars: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
res = conn.query(surql, vars or {})
if not isinstance(res, list):
return []
return [r for r in res if isinstance(r, dict)]
def _as_str_id(value: Any) -> str:
return str(value)
def _linear_scale(
value: float, vmin: float, vmax: float, out_min: float, out_max: float
) -> float:
if vmax <= vmin:
return out_min
t = (value - vmin) / (vmax - vmin)
t = max(0.0, min(1.0, t))
return out_min + t * (out_max - out_min)
def _fetch_text(url: str, timeout_s: int = 20) -> str:
with urllib.request.urlopen(url, timeout=timeout_s) as resp:
return resp.read().decode("utf-8", errors="replace")
# Client-side bootstrap for the exported viewer.  Kept as a plain (non f-)
# string so JS braces need no escaping.  It reads the graph from the
# <script id="graph-data"> JSON block and maps the payload's
# source/target edge fields onto vis-network's from/to.
_VIEWER_SCRIPT = """\
(function () {
  "use strict";
  var payload = JSON.parse(document.getElementById("graph-data").textContent);
  // Scale the payload's x/y up to canvas units for the initial layout.
  var scale = 300;
  var nodes = new vis.DataSet((payload.nodes || []).map(function (n) {
    return {
      id: n.id,
      label: n.label,
      size: n.size,
      x: n.x * scale,
      y: n.y * scale,
      color: n.color,
      title: n.label + " (frequency: " + n.frequency + ")"
    };
  }));
  var edges = new vis.DataSet((payload.edges || []).map(function (e) {
    return {
      id: e.id,
      from: e.source,
      to: e.target,
      value: e.weight,
      title: "co-occurrences: " + e.weight
    };
  }));
  document.getElementById("stats").textContent =
    nodes.length + " concepts, " + edges.length + " edges";
  new vis.Network(
    document.getElementById("graph"),
    { nodes: nodes, edges: edges },
    {
      nodes: { shape: "dot", font: { size: 12 } },
      edges: {
        color: { color: "#94a3b8", opacity: 0.6 },
        scaling: { min: 1, max: 8 },
        smooth: false
      },
      physics: { stabilization: { iterations: 200 } },
      interaction: { hover: true, tooltipDelay: 100 }
    }
  );
})();
"""


def _json_for_script(payload: dict[str, Any]) -> str:
    """Serialize ``payload`` as JSON that is safe to inline in a <script> element."""
    text = json.dumps(payload, ensure_ascii=False)
    # '&', '<' and '>' can only occur inside JSON string literals, where the
    # \\uXXXX escapes decode to the same characters.  Escaping them stops a
    # value containing "</script>" from terminating the data block early.
    return (
        text.replace("&", "\\u0026").replace("<", "\\u003c").replace(">", "\\u003e")
    )


def _write_html_viewer(output_html: Path, payload: dict[str, Any], title: str) -> None:
    """Write a self-contained HTML viewer for ``payload`` using vis-network.

    The vis-network JS/CSS are downloaded once at export time and inlined, so
    the page does not need CDN access when opened.  If the download fails,
    a small HTML page explaining the error is written instead.  This is a
    deliberate best-effort: the JSON export is still useful without the viewer.

    Args:
        output_html: Destination file; overwritten, UTF-8 encoded.
        payload: Graph dict with ``nodes`` (id/label/size/x/y/color/frequency)
            and ``edges`` (id/source/target/weight) lists.
        title: Page title; HTML-escaped before use.
    """
    from html import escape

    css_url = (
        "https://cdn.jsdelivr.net/npm/vis-network@9.1.9/styles/vis-network.min.css"
    )
    js_url = "https://cdn.jsdelivr.net/npm/vis-network@9.1.9/standalone/umd/vis-network.min.js"
    safe_title = escape(title)
    try:
        vis_css = _fetch_text(css_url)
        vis_js = _fetch_text(js_url)
    except Exception as exc:  # best effort: network/HTTP errors of any kind
        fallback = "\n".join(
            [
                "<!DOCTYPE html>",
                '<html lang="en">',
                "<head>",
                '<meta charset="utf-8">',
                f"<title>{safe_title}</title>",
                "</head>",
                "<body>",
                f"<h1>{safe_title}</h1>",
                "<p>Failed to download vis-network assets while generating HTML.</p>",
                f"<pre>Error: {escape(str(exc))}</pre>",
                "</body>",
                "</html>",
                "",
            ]
        )
        output_html.write_text(fallback, encoding="utf-8")
        return

    data_json = _json_for_script(payload)
    page = "\n".join(
        [
            "<!DOCTYPE html>",
            '<html lang="en">',
            "<head>",
            '<meta charset="utf-8">',
            '<meta name="viewport" content="width=device-width, initial-scale=1">',
            f"<title>{safe_title}</title>",
            "<style>",
            # Neutralize any literal "</style" so it cannot close the element.
            vis_css.replace("</style", "<\\/style"),
            "html, body { margin: 0; height: 100%; }",
            "body { display: flex; flex-direction: column;"
            " font-family: system-ui, sans-serif; }",
            "#header { padding: 8px 12px; border-bottom: 1px solid #e5e7eb; }",
            "#stats { color: #6b7280; margin-left: 8px; }",
            "#graph { flex: 1; min-height: 0; }",
            "</style>",
            "</head>",
            "<body>",
            f'<div id="header"><strong>{safe_title}</strong>'
            '<span id="stats"></span></div>',
            '<div id="graph"></div>',
            f'<script type="application/json" id="graph-data">{data_json}</script>',
            "<script>",
            # Same guard for the inlined library: "</script" would end the tag.
            vis_js.replace("</script", "<\\/script"),
            "</script>",
            "<script>",
            _VIEWER_SCRIPT,
            "</script>",
            "</body>",
            "</html>",
            "",
        ]
    )
    output_html.write_text(page, encoding="utf-8")
def _top_concepts(conn: Surreal, top: int) -> list[dict[str, Any]]:
    """Return up to ``top`` concepts, most-mentioned first.

    Frequency is derived by counting MENTIONS_CONCEPT edges per target
    concept, because the concept table doesn't reliably carry a frequency
    field in this dataset.  Each row selects ``out``, ``id``, ``name`` and
    ``frequency``.
    """
    surql = """
    SELECT out, out.id AS id, out.content AS name, count() AS frequency
    FROM MENTIONS_CONCEPT
    GROUP BY out
    ORDER BY frequency DESC
    LIMIT $top
    """
    params = {"top": top}
    return _query_rows(conn, surql, params)
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Export concept graph from SurrealDB")
    parser.add_argument("--top", type=int, default=50, help="Top concepts to include")
    parser.add_argument("--min-weight", type=int, default=1, help="Minimum edge weight")
    parser.add_argument(
        "--max-edges", type=int, default=500, help="Maximum edges output"
    )
    parser.add_argument(
        "--max-concepts-per-chunk",
        type=int,
        default=25,
        help="Cap concepts per chunk for co-occurrence",
    )
    parser.add_argument("--output-dir", default="exports", help="Output directory")
    parser.add_argument(
        "--db-url",
        default=os.getenv("KG_DB_URL", "ws://localhost:8000/rpc"),
        help="SurrealDB WS URL",
    )
    parser.add_argument(
        "--db-name",
        default=os.getenv("DB_NAME", "test_db"),
        help="SurrealDB database name",
    )
    parser.add_argument(
        "--namespace",
        default=os.getenv("DB_NS", "kaig"),
        help="SurrealDB namespace",
    )
    parser.add_argument("--username", default=os.getenv("DB_USER", "root"))
    parser.add_argument("--password", default=os.getenv("DB_PASS", "root"))
    parser.add_argument("--seed", type=int, default=7, help="Layout seed")
    return parser.parse_args()


def _as_frequency(value: Any) -> float:
    """Coerce a frequency value to float, falling back to 1.0 if missing or invalid."""
    if value is None:
        return 1.0
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return 1.0


def _build_nodes(
    top_rows: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[Any]]:
    """Place the top concepts evenly on the unit circle, sized by frequency.

    Rows without an ``id`` are skipped.  Returns ``(nodes, concept_rids)``.
    ``concept_rids`` holds the raw id values in node order, so they can be
    bound unchanged as query parameters.
    """
    rows = [row for row in top_rows if row.get("id") is not None]
    freqs = [_as_frequency(row.get("frequency")) for row in rows]
    freq_min = min(freqs, default=1.0)
    freq_max = max(freqs, default=1.0)

    nodes: list[dict[str, Any]] = []
    for i, (row, freq) in enumerate(zip(rows, freqs)):
        cid_s = _as_str_id(row["id"])
        # Fall back to the record key (after "table:") when content is empty.
        label = str(row.get("name") or cid_s.split(":", 1)[-1]).strip().strip('"')
        angle = (2.0 * math.pi * i) / max(1, len(rows))
        size = _linear_scale(freq, freq_min, freq_max, 5.0, 20.0)
        nodes.append(
            {
                "id": cid_s,
                "label": label,
                "size": round(size, 3),
                "x": round(math.cos(angle), 6),
                "y": round(math.sin(angle), 6),
                "color": "#2563eb",
                "frequency": freq,
            }
        )
    return nodes, [row["id"] for row in rows]


def _cooccurrence_edges(
    mappings: list[dict[str, Any]],
    id_set: set[str],
    *,
    max_per_chunk: int,
    min_weight: int,
    max_edges: int,
) -> list[dict[str, Any]]:
    """Count, per concept pair, the chunks that mention both concepts.

    Only concepts in ``id_set`` are considered.  Per chunk, only the first
    ``max_per_chunk`` distinct concepts seen (in row order) are kept.  Pairs
    below ``min_weight`` are dropped.  The rest are ordered by descending
    weight, with ties broken by (source, target) so that ``max_edges``
    truncation is deterministic.
    """
    chunk_to_concepts: dict[str, set[str]] = {}
    for row in mappings:
        chunk = row.get("chunk")
        concept = row.get("concept")
        if chunk is None or concept is None:
            continue
        c_id = _as_str_id(concept)
        if c_id not in id_set:
            continue
        members = chunk_to_concepts.setdefault(_as_str_id(chunk), set())
        if len(members) < max_per_chunk:
            members.add(c_id)

    edge_counts: dict[tuple[str, str], int] = {}
    for members in chunk_to_concepts.values():
        # Sorting makes each unordered pair appear once as (smaller, larger).
        for pair in itertools.combinations(sorted(members), 2):
            edge_counts[pair] = edge_counts.get(pair, 0) + 1

    kept = [(a, b, w) for (a, b), w in edge_counts.items() if w >= min_weight]
    kept.sort(key=lambda t: (-t[2], t[0], t[1]))
    kept = kept[: max(0, int(max_edges))]
    return [
        {"id": f"e{i}", "source": a, "target": b, "weight": int(w)}
        for i, (a, b, w) in enumerate(kept)
    ]


def main() -> int:
    """Export the concept co-occurrence graph and return the exit code.

    Writes ``concept_graph.json`` and ``concept_graph.html`` into
    ``--output-dir`` and prints both paths.  Raises ``SystemExit`` with a
    message when the database yields no concepts.
    """
    args = _parse_args()
    # NOTE: the circular layout below is fully deterministic, so --seed has no
    # visible effect today; the flag and the seeding are kept for CLI compat.
    random.seed(args.seed)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_json = out_dir / "concept_graph.json"
    out_html = out_dir / "concept_graph.html"

    conn = Surreal(args.db_url)
    try:
        conn.signin({"username": args.username, "password": args.password})
        conn.use(args.namespace, args.db_name)
        top_rows = _top_concepts(conn, args.top)
        if not top_rows:
            raise SystemExit("No concepts found")
        nodes, concept_rids = _build_nodes(top_rows)
        mappings = _query_rows(
            conn,
            """
            SELECT in AS chunk, out AS concept
            FROM MENTIONS_CONCEPT
            WHERE out IN $concepts
            """,
            {"concepts": concept_rids},
        )
    finally:
        # Release the DB connection before writing outputs.  The HTML step
        # downloads assets over the network and may be slow.
        conn.close()

    edges = _cooccurrence_edges(
        mappings,
        {node["id"] for node in nodes},
        max_per_chunk=args.max_concepts_per_chunk,
        min_weight=args.min_weight,
        max_edges=args.max_edges,
    )

    payload = {
        "meta": {
            "db_url": args.db_url,
            "namespace": args.namespace,
            "db_name": args.db_name,
            "top": args.top,
            "min_weight": args.min_weight,
            "max_edges": args.max_edges,
        },
        "nodes": nodes,
        "edges": edges,
    }
    out_json.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    _write_html_viewer(out_html, payload, "Concept Graph (Top Concepts)")
    print(out_json)
    print(out_html)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())