File size: 19,850 Bytes
7eaced5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
"""
Hardcoded Queries tab — runs QueryExecutor methods against Neo4j.

Query definitions live in queries/hardcoded.py.  Each entry declares which
QueryExecutor method to call and what UI inputs to render.  No code changes
are needed here when adding new queries.

Special input type "node_picker":
    Runs a vector search on the user's term for a given node_type, shows a
    scrollable list of matches, and on click loads the selected label and
    auto-triggers the main query.

    Required field in input spec:
        node_type : Neo4j label to search (e.g. "UNESCOconcept", "Breakthrough")
    Optional fields:
        top_k     : number of candidates shown (default 7)
"""
from __future__ import annotations

import html
from collections import defaultdict

import pandas as pd
import streamlit as st

from db.neo4j_client import get_neo4j_resources
from db.vector_client import vector_search
from queries.hardcoded import HARDCODED_QUERIES
from queries.mappings import apply_mapping


# Per-type defaults used by the node picker.
# Both tables are keyed by the node label given as ``node_type`` in a
# "node_picker" input spec; the picker falls back to a bullet icon and a
# 0.50 threshold for labels that are not listed here.

# Emoji shown in front of every search hit and on the selected-node badge.
_PICKER_ICONS: dict[str, str] = {
    "UNESCOconcept":  "📚",
    "Breakthrough":   "💡",
    "Platform":       "🏗️",
    "Emerging topic": "🌱",
    "SDGtarget":      "🎯",
    "SDGgoal":        "🌐",
    "SDGindicator":   "📊",
    "OECDfield":      "🔬",
}

# Value passed as ``threshold`` to vector_search() for each node type.
_PICKER_THRESHOLDS: dict[str, float] = {
    "UNESCOconcept":  0.55,
    "Breakthrough":   0.60,
    "Platform":       0.55,
    "Emerging topic": 0.55,
    "SDGtarget":      0.55,
    "SDGgoal":        0.45,
    "SDGindicator":   0.50,
    "OECDfield":      0.50,
}


# ---------------------------------------------------------------------------
# Node picker (vector search → scrollable results → click to select)
# ---------------------------------------------------------------------------

# How many times more hits than ``top_k`` are requested from the vector index:
# several hits can share one label and are merged before the list is cut.
_PICKER_OVERFETCH = 15


def _dedupe_picker_hits(
    results: list[dict], top_k: int
) -> list[tuple[str, tuple[float, str | None, str | None]]]:
    """Keep the best-scoring hit per label and return the ``top_k`` best labels.

    The label is ``pref_label_en`` when truthy, otherwise ``original_text``;
    hits without a label are dropped.  Each returned entry is
    ``(label, (score, radar_version, node_id))``, sorted by score descending.
    """
    best: dict[str, tuple[float, str | None, str | None]] = {}
    for hit in results:
        label = hit.get("pref_label_en") or hit.get("original_text", "")
        if not label:
            continue
        # A first hit is compared against 0.0, so non-positive scores are never kept.
        if hit["score"] > best.get(label, (0.0, None, None))[0]:
            best[label] = (hit["score"], hit.get("radar_version"), hit.get("node_id"))
    return sorted(best.items(), key=lambda item: item[1][0], reverse=True)[:top_k]


def _render_node_picker(inp: dict, query_id: str) -> str:
    """
    Render a vector-search-backed node picker for any node_type.

    Args:
        inp: the "node_picker" input spec.  Reads ``node_type`` and ``key``
            (required) plus the optional ``label``, ``placeholder``, ``top_k``
            (default 7), ``value_prop`` (``"node_id"`` stores the node id
            instead of the label) and ``display_node_id`` (prefix the node id
            in the buttons and the badge).
        query_id: id of the owning query; used to namespace session-state keys.

    Returns:
        The stored value of the selected node, or '' if nothing is selected yet.

    Side-effect: sets session_state[f"hq_{query_id}_auto_run"] = True on click.
    """
    node_type = inp["node_type"]
    icon = _PICKER_ICONS.get(node_type, "•")
    threshold = _PICKER_THRESHOLDS.get(node_type, 0.50)
    top_k = inp.get("top_k", 7)
    value_prop = inp.get("value_prop")
    display_node_id = inp.get("display_node_id", False)

    key_base = f"hq_{query_id}_{inp['key']}"
    selected_key = f"{key_base}_selected"
    display_key = f"{selected_key}_display"
    version_key = f"{selected_key}_version"
    results_key = f"{key_base}_results"
    mapped_key = f"{key_base}_mapped"

    selected = st.session_state.get(selected_key, "")

    # ── Already selected → show badge + clear button ──────────────────────
    if selected:
        display = st.session_state.get(display_key, selected)
        version = st.session_state.get(version_key)
        version_suffix = f" ({version})" if version else ""
        col_sel, col_clr = st.columns([5, 1])
        with col_sel:
            st.success(f"{icon} **{display}**{version_suffix}")
        with col_clr:
            if st.button("✕ Clear", key=f"{key_base}_clear"):
                # Drop the "searched as" caption too, otherwise it lingers
                # above an empty picker after the selection is cleared.
                for stale_key in (selected_key, display_key, version_key,
                                  results_key, mapped_key):
                    st.session_state.pop(stale_key, None)
                st.rerun()
        return selected

    # ── Search input ──────────────────────────────────────────────────────
    search_term = st.text_input(
        inp.get("label", f"Search for a {node_type}"),
        placeholder=inp.get("placeholder", ""),
        key=f"{key_base}_search_input",
    )

    if st.button("Search", key=f"{key_base}_search_btn", disabled=not search_term.strip()):
        mapped_term, was_mapped = apply_mapping(search_term)
        with st.spinner(f"Searching {node_type}…"):
            results, err = vector_search(
                mapped_term,
                top_k=top_k * _PICKER_OVERFETCH,
                threshold=threshold,
                node_label=node_type,
            )
        if err:
            # Forget the previous search so its hits cannot reappear on the
            # next rerun as if they belonged to the failed one.
            st.session_state.pop(results_key, None)
            st.session_state.pop(mapped_key, None)
            st.error(f"Vector search error: {err}")
            return ""
        st.session_state[mapped_key] = mapped_term if was_mapped else ""
        st.session_state[results_key] = _dedupe_picker_hits(results, top_k)

    # ── Scrollable results ────────────────────────────────────────────────
    mapped_label = st.session_state.get(mapped_key, "")
    if mapped_label:
        st.caption(f"Searched as: *{mapped_label}*")

    top = st.session_state.get(results_key)
    if top is not None:
        if top:
            with st.container(height=220):
                for i, (lbl, (score, version, node_id)) in enumerate(top):
                    show_id = bool(display_node_id and node_id)
                    version_suffix = f"  ({version})" if version else ""
                    node_id_prefix = f"[{node_id}]  " if show_id else ""
                    stored_value = node_id if value_prop == "node_id" and node_id else lbl
                    if st.button(
                        f"{icon} {node_id_prefix}{lbl}{version_suffix}  ·  {score:.3f}",
                        key=f"{key_base}_pick_{i}",
                        use_container_width=True,
                    ):
                        st.session_state[selected_key] = stored_value
                        st.session_state[display_key] = f"[{node_id}] {lbl}" if show_id else lbl
                        st.session_state[version_key] = version
                        # Tell render() to fire the main query on the rerun.
                        st.session_state[f"hq_{query_id}_auto_run"] = True
                        st.rerun()
        else:
            st.warning(f"No {node_type} nodes found above threshold.")

    return ""


# ---------------------------------------------------------------------------
# Generic input collection
# ---------------------------------------------------------------------------

def _collect_inputs(query_def: dict) -> tuple[dict, bool]:
    """Draw one widget per input spec of *query_def* and gather the values.

    Widgets are rendered in three passes: node pickers first, then plain text
    fields, then every remaining input side by side in one row of columns.

    Returns:
        ``(values, complete)`` where ``values`` maps each spec's ``key`` to
        the widget value and ``complete`` is False as soon as a node picker
        has no selection or a text field is blank.
    """
    def widget_key(spec: dict) -> str:
        # Streamlit widget key, namespaced by query id and input key.
        return f'hq_{query_def["id"]}_{spec["key"]}'

    values: dict = {}
    complete = True

    pickers = [spec for spec in query_def["inputs"] if spec["type"] == "node_picker"]
    texts = [spec for spec in query_def["inputs"] if spec["type"] == "text"]
    others = [spec for spec in query_def["inputs"] if spec["type"] not in ("text", "node_picker")]

    # Pass 1: vector-search-backed node pickers, full width.
    for spec in pickers:
        picked = _render_node_picker(spec, query_def["id"])
        values[spec["key"]] = picked
        if not picked:
            complete = False

    # Pass 2: free-text fields.
    for spec in texts:
        typed = st.text_input(
            spec["label"],
            placeholder=spec.get("placeholder", ""),
            help=spec.get("help"),
            key=widget_key(spec),
        )
        values[spec["key"]] = typed
        if not typed.strip():
            complete = False

    # Pass 3: selects / checkboxes / numbers share a single row of columns.
    if not others:
        return values, complete

    for column, spec in zip(st.columns(len(others)), others):
        with column:
            kind = spec["type"]
            if kind == "select":
                choices = spec["options"]
                # Fall back to the first option when the default is not offered.
                start = choices.index(spec["default"]) if spec["default"] in choices else 0
                names = spec.get("option_labels", {})
                values[spec["key"]] = st.selectbox(
                    spec["label"],
                    choices,
                    index=start,
                    # Bind the mapping as a default so each select keeps its own labels.
                    format_func=lambda v, _m=names: _m.get(v, str(v)),
                    help=spec.get("help"),
                    key=widget_key(spec),
                )
            elif kind == "checkbox":
                values[spec["key"]] = st.checkbox(
                    spec["label"],
                    value=spec["default"],
                    help=spec.get("help"),
                    key=widget_key(spec),
                )
            elif kind == "number":
                number = st.number_input(
                    spec["label"],
                    min_value=spec.get("min", 1),
                    max_value=spec.get("max", 500),
                    value=spec["default"],
                    help=spec.get("help"),
                    key=widget_key(spec),
                )
                values[spec["key"]] = int(number)

    return values, complete


# ---------------------------------------------------------------------------
# Result helpers
# ---------------------------------------------------------------------------

def _flatten_results(results: list[dict]) -> list[dict]:
    """Return copies of *results* that are ready for dataframe / CSV display.

    * List values become comma-separated strings.
    * When a row has a ``breakthrough_name``, its ``radar_version`` column is
      removed and, if it holds a value, appended to the name as
      ``"name (version)"``.
    * Rows without a ``breakthrough_name`` keep ``radar_version`` as a normal
      column, so the value is not lost when there is nothing to fold it into.

    The input rows are not modified.
    """
    flat: list[dict] = []
    for row in results:
        flat_row = {
            key: ", ".join(str(item) for item in value) if isinstance(value, list) else value
            for key, value in row.items()
        }
        if "breakthrough_name" in flat_row:
            version = flat_row.pop("radar_version", None)
            # Skip empty versions so the name is not decorated with "()".
            if version not in (None, ""):
                flat_row["breakthrough_name"] = f"{flat_row['breakthrough_name']} ({version})"
        flat.append(flat_row)
    return flat


# Badge colours (hex strings) for the concept-to-concept relationship names.
# The path renderers in this module look names up here and fall back to
# "#888" for anything that is not listed.
_SKOS_COLORS: dict[str, str] = {
    "IS_BROADER":          "#3498DB",
    "IS_NARROWER":         "#9B59B6",
    "IS_BROADER_CONCEPT":  "#3498DB",
    "IS_NARROWER_CONCEPT": "#9B59B6",
    "IS_RELATED_CONCEPT":  "#F39C12",
}
# Badge colours for the relationship badges drawn next to a breakthrough node
# (the rel_type_1 / rel_type_2 / b_rel result columns); same "#888" fallback.
_REL_COLORS: dict[str, str] = {"ADVANCES": "#2ECC71", "REQUIRES": "#E74C3C"}


def _rel_badge(rel: str, color: str) -> str:
    """Return an inline HTML ``<span>`` pill showing relationship name *rel*.

    Args:
        rel: text shown inside the pill.  It is HTML-escaped because the
            result is rendered with ``unsafe_allow_html=True`` and the value
            comes from query results.
        color: CSS colour used for the text; the same string with ``22`` and
            ``66`` appended is used for background and border (presumably an
            alpha suffix, which assumes a ``#RRGGBB`` value).
    """
    text = html.escape(str(rel))
    return (f'<span style="font-size:10px; background:{color}22; color:{color}; '
            f'border:1px solid {color}66; border-radius:4px; padding:1px 5px; '
            f'margin:0 3px; white-space:nowrap;">{text}</span>')


def _node_span(label: str, icon: str, bold: bool = False) -> str:
    """Return an inline HTML ``<span>`` showing *icon* followed by *label*.

    Args:
        label: node text.  It is HTML-escaped because the result is rendered
            with ``unsafe_allow_html=True`` and the value comes from query
            results.
        icon: prefix inserted as-is (callers in this module pass literal emoji).
        bold: use font-weight 700 instead of 400.
    """
    weight = "700" if bold else "400"
    text = html.escape(str(label))
    return f'<span style="font-weight:{weight}; color:#1a1e2a;">{icon} {text}</span>'


def _path_card(parts: list[str]) -> None:
    """Render the HTML fragments in *parts*, space-separated, as one bordered card."""
    opening = (
        '<div style="background:#fff; border:1px solid #dde2ec; border-radius:8px; '
        'padding:8px 12px; margin-bottom:6px; line-height:2;">'
    )
    body = " ".join(parts)
    # The fragments are raw HTML, so Streamlit must be told not to escape them.
    st.markdown(opening + body + '</div>', unsafe_allow_html=True)


def _is_bridge_result(results: list[dict]) -> bool:
    """Tell whether the first row has the columns of a bridge result.

    Only the first row is inspected; it must contain ``b1_name``, ``b2_name``
    and ``rel_type_1``.  An empty result is never a bridge result.
    """
    required = {"b1_name", "b2_name", "rel_type_1"}
    return bool(results) and required.issubset(results[0].keys())


def _is_sdg_contributor_result(results: list[dict]) -> bool:
    """Tell whether the first row has the columns of an SDG-contributor result.

    Only the first row is inspected; it must contain ``sdg_target_id`` and
    ``breakthrough_name``.  An empty result never matches.
    """
    required = {"sdg_target_id", "breakthrough_name"}
    return bool(results) and required.issubset(results[0].keys())


def _is_neighborhood_result(results: list[dict]) -> bool:
    """Tell whether the first row has the columns of a neighborhood result.

    Only the first row is inspected; it must contain ``center_concept`` and
    ``relationship_1``.  An empty result never matches.
    """
    required = {"center_concept", "relationship_1"}
    return bool(results) and required.issubset(results[0].keys())


def _breakthrough_label(name: str, version) -> str:
    """Return *name*, suffixed with ``" (version)"`` when *version* is truthy."""
    if not version:
        return name
    return f"{name} ({version})"


def _render_bridge_results(results: list[dict]) -> None:
    """Draw one path card per row of a bridge result.

    Each card reads left to right: breakthrough 1, ``rel_type_1`` badge, the
    positional ``concept_N`` / ``skos_rel_N`` columns in ascending order
    (null values skipped), ``rel_type_2`` badge, breakthrough 2.  A missing or
    null relationship type is drawn as a plain arrow in the fallback colour.
    """
    for row in results:
        parts: list[str] = [
            _node_span(_breakthrough_label(row["b1_name"], row.get("b1_radar_version")), "💡", bold=True),
        ]
        rel_1 = row.get("rel_type_1") or "→"
        parts.append(_rel_badge(rel_1, _REL_COLORS.get(rel_1, "#888")))

        # Walk concept_1, concept_2, ... until the first index that has no column.
        i = 1
        while f"concept_{i}" in row:
            concept = row[f"concept_{i}"]
            if concept is not None:
                parts.append(_node_span(concept, "📚"))
            skos_rel = row.get(f"skos_rel_{i}")
            if skos_rel is not None:
                parts.append(_rel_badge(skos_rel, _SKOS_COLORS.get(skos_rel, "#888")))
            i += 1

        rel_2 = row.get("rel_type_2") or "→"
        parts.append(_rel_badge(rel_2, _REL_COLORS.get(rel_2, "#888")))
        parts.append(_node_span(_breakthrough_label(row["b2_name"], row.get("b2_radar_version")), "💡", bold=True))
        _path_card(parts)


def _render_sdg_contributor_results(results: list[dict]) -> None:
    """Draw one path card per row of an SDG-contributor result.

    Each card reads left to right: breakthrough, ``b_rel`` badge, the
    intermediate concepts with their ``skos_rel_N`` badges, a fixed
    ``CONTRIBUTES_TO`` badge, the SDG target id and the ``distance`` value.
    """
    _CONTRIBUTES_TO_COLOR = "#9B59B6"
    concept_prefix = "concept_"
    for row in results:
        parts: list[str] = []
        # Breakthrough on the left
        parts.append(_node_span(_breakthrough_label(row["breakthrough_name"], row.get("radar_version")), "💡", bold=True))
        b_rel = row.get("b_rel") or "→"
        parts.append(_rel_badge(b_rel, _REL_COLORS.get(b_rel, "#888")))
        # Intermediate concepts are stored in reverse order (concept_1 is closest
        # to the target, concept_N is closest to the breakthrough), so walk from
        # the highest populated index down to 1.  The start is the highest index
        # itself rather than a count of non-null columns, so a null in the
        # middle of the numbering does not cut off the concepts above it.
        populated = [
            int(key[len(concept_prefix):])
            for key, value in row.items()
            if key.startswith(concept_prefix)
            and key[len(concept_prefix):].isdigit()
            and value is not None
        ]
        for i in range(max(populated, default=0), 0, -1):
            concept = row.get(f"concept_{i}")
            if concept is not None:
                parts.append(_node_span(concept, "📚"))
            if i > 1:
                skos_rel = row.get(f"skos_rel_{i - 1}")
                if skos_rel is not None:
                    parts.append(_rel_badge(skos_rel, _SKOS_COLORS.get(skos_rel, "#888")))
        # CONTRIBUTES_TO → SDGtarget on the right
        parts.append(_rel_badge("CONTRIBUTES_TO", _CONTRIBUTES_TO_COLOR))
        parts.append(_node_span(row["sdg_target_id"], "🎯", bold=True))
        dist = row.get("distance", "")
        parts.append(f'<span style="font-size:11px; color:#6b7280; margin-left:6px;">dist={dist}</span>')
        _path_card(parts)


def _render_neighborhood_results(results: list[dict]) -> None:
    """Draw one path card per row of a neighborhood result.

    Each card starts with the center concept (bold) followed by the
    positional ``relationship_N`` / ``neighbor_concept_N`` columns in
    ascending order, skipping null values.  Rows that yield at most one
    fragment are not drawn.
    """
    for row in results:
        parts: list[str] = []
        # Center node
        center = row.get("center_concept")
        if center is not None:
            parts.append(_node_span(center, "📚", bold=True))
        # Walk the positional columns until an index has neither of the two.
        i = 1
        while f"relationship_{i}" in row or f"neighbor_concept_{i}" in row:
            rel = row.get(f"relationship_{i}")
            node = row.get(f"neighbor_concept_{i}")
            if rel is not None:
                parts.append(_rel_badge(rel, _SKOS_COLORS.get(rel, "#888")))
            if node is not None:
                parts.append(_node_span(node, "📚"))
            i += 1
        # A lone center node is not a path worth showing.
        if len(parts) > 1:
            _path_card(parts)


# ---------------------------------------------------------------------------
# Main render
# ---------------------------------------------------------------------------

def render() -> None:
    """Render the Hardcoded Queries tab.

    Flow: obtain the query executor, let the user pick one entry of
    HARDCODED_QUERIES, render its inputs, call the executor method named by
    the entry with the collected inputs merged with its ``fixed_params``,
    then show the Cypher text, the results (a path-card view for the
    recognised result shapes, a dataframe otherwise) and a CSV download.

    Queries with a node picker run automatically as soon as every input is
    filled; all other queries wait for the Run button.
    """
    st.subheader("Hardcoded Queries")
    st.caption(
        "Pre-configured queries to explore the structure, and get insights from the Knowledge Graph."
    )

    _, executor, neo4j_err = get_neo4j_resources()
    if executor is None:
        st.error(f"Neo4j unavailable: {neo4j_err}")
        return

    # ── Query selector ────────────────────────────────────────────────────────
    options = [q["id"] for q in HARDCODED_QUERIES]
    labels  = {q["id"]: f'{q["icon"]} {q["name"]}' for q in HARDCODED_QUERIES}

    col_sel, col_main = st.columns([1, 3])

    with col_sel:
        st.markdown('<div class="section-header">Query type</div>', unsafe_allow_html=True)
        selected_id = st.radio(
            "Select query",
            options,
            format_func=lambda oid: labels[oid],
            label_visibility="collapsed",
            key="hq_selected",
        )

    # A default keeps an empty query list from surfacing as StopIteration.
    query_def = next((q for q in HARDCODED_QUERIES if q["id"] == selected_id), None)
    if query_def is None:
        st.warning("No hardcoded queries are defined.")
        return

    with col_main:
        st.markdown(
            f'<div style="border-left:3px solid #3498DB; padding-left:12px; margin-bottom:16px;">'
            f'<div style="font-size:18px; font-weight:700; color:#1a1e2a;">'
            f'{query_def["icon"]} {query_def["name"]}</div>'
            f'<div style="font-size:13px; color:#6b7280; margin-top:4px;">'
            f'{query_def["description"]}</div>'
            f'<div style="font-size:12px; color:#6b7280; margin-top:4px;">'
            f'method: <code>{query_def["method"]}</code></div>'
            f'</div>',
            unsafe_allow_html=True,
        )

        has_picker = any(i["type"] == "node_picker" for i in query_def["inputs"])

        collected, all_filled = _collect_inputs(query_def)

        if has_picker:
            # No run button — the query fires automatically whenever every
            # input is filled, so changing any other input re-runs it
            # immediately.  The auto-run flag is only consumed here.
            st.session_state.pop(f"hq_{selected_id}_auto_run", None)
            if not all_filled:
                st.info("Search for a node above — click a result to run the query.")
                return
        else:
            auto_run = st.session_state.pop(f"hq_{selected_id}_auto_run", False)
            run = st.button("▶ Run", type="primary", key=f"hq_run_{selected_id}") or auto_run
            if not run:
                if query_def["inputs"]:
                    st.info("Fill in the parameters above and press **▶ Run**.")
                return
            if not all_filled:
                st.warning("Please fill in all required fields.")
                return

        # fixed_params win over user inputs that share a key.
        kwargs = {**collected, **query_def.get("fixed_params", {})}

        with st.spinner("Running query…"):
            try:
                results = getattr(executor, query_def["method"])(**kwargs)
            except Exception as exc:
                # UI boundary: report any failure instead of crashing the tab.
                st.error(f"Query error: {exc}")
                return

        if not results:
            st.info("Query returned no results.")
            return

        st.success(f"**{len(results)}** row(s) returned")

        with st.expander("Cypher query", expanded=False):
            st.code(executor.last_query.strip(), language="cypher")
            if executor.last_params:
                st.caption("Parameters: " + ", ".join(f"`{k}` = `{v}`" for k, v in executor.last_params.items()))

        # Flatten once; the same frame feeds the fallback table and the CSV.
        df = pd.DataFrame(_flatten_results(results))

        if _is_bridge_result(results):
            _render_bridge_results(results)
        elif _is_sdg_contributor_result(results):
            _render_sdg_contributor_results(results)
        elif _is_neighborhood_result(results):
            _render_neighborhood_results(results)
        else:
            st.dataframe(df, use_container_width=True, height=420)

        st.download_button(
            "Download CSV",
            data=df.to_csv(index=False),
            file_name=f"{selected_id}.csv",
            mime="text/csv",
        )