File size: 3,013 Bytes
7eaced5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Cached vector-search client for Streamlit (Qdrant + EPFL API embedder)."""
from __future__ import annotations

import sys
from pathlib import Path

# Prepend the parent of this file's directory to sys.path *before* the
# `config` import below; presumably this is what makes `config` (and the
# `src.*` packages imported lazily further down) resolvable — confirm
# against the project layout.
sys.path.insert(0, str(Path(__file__).parent.parent))
import config as _cfg  # noqa: F401

import logging
from typing import Any

import streamlit as st

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@st.cache_resource(show_spinner=False)
def get_vector_resources():
    """Build the embedder and Qdrant store once and hand back the cached pair.

    Returns:
        A 3-tuple ``(embedder, qdrant_store, error_msg)``. On success
        ``error_msg`` is the empty string; if anything raises while importing
        or constructing the resources, the tuple is ``(None, None, str(exc))``
        and a warning is logged.

    NOTE(review): the failure tuple is *returned* rather than raised, so
    ``st.cache_resource`` will presumably cache it like any other result —
    confirm whether a retry path is wanted.
    """
    try:
        # Project imports are deferred so that an import failure is reported
        # through the error tuple instead of breaking module import.
        from src.config import load_config
        from src.embeddings import build_embedder
        from src.qdrant_store import QdrantStore

        yaml_path = _cfg.PROJECT_ROOT / "config" / "config.yaml"
        settings = load_config(config_path=str(yaml_path))

        embedder = build_embedder(settings)

        qdrant_settings = settings["qdrant"]
        store_kwargs = {
            "mode": qdrant_settings.get("mode", "local"),
            "local_path": qdrant_settings.get("local_path"),
            "url": qdrant_settings.get("url"),
            "host": qdrant_settings.get("host", "localhost"),
            "port": qdrant_settings.get("port", 6333),
            "api_key": qdrant_settings.get("api_key"),
            "collection_name": _cfg.KG_NODES_COLLECTION,
            "distance": qdrant_settings.get("distance", "cosine"),
            "vector_dim": settings["embeddings"]["dimension"],
        }
        store = QdrantStore(**store_kwargs)
    except Exception as exc:
        logger.warning("Vector search unavailable: %s", exc)
        return None, None, str(exc)
    else:
        return embedder, store, ""


def vector_search(
    query: str,
    top_k: int = 20,
    threshold: float = 0.5,
    node_label: str | None = None,
) -> tuple[list[dict[str, Any]], str]:
    """Run a vector similarity search against the knowledge-graph node collection.

    The collection is the one the cached store from ``get_vector_resources``
    was created with (``_cfg.KG_NODES_COLLECTION``).

    Args:
        query: Text handed to the embedder's ``encode_query``.
        top_k: Forwarded to the store's ``search`` as ``top_k``.
        threshold: Forwarded to the store's ``search`` as ``score_threshold``.
        node_label: When truthy, hits are restricted to payloads whose
            ``node_type`` equals this value; otherwise no filter is applied.

    Returns:
        ``(results, error_message)``. ``results`` is a list of flat dicts with
        the keys ``score`` (rounded to 4 decimals), ``node_type``, ``node_id``,
        ``attribute_name``, ``original_text``, ``pref_label_en`` and
        ``radar_version``. On any failure ``results`` is empty and
        ``error_message`` carries ``str(exc)``; on success it is ``""``.
    """
    embedder, qdrant, err = get_vector_resources()
    if embedder is None:
        # Resources could not be built; surface the stored reason.
        return [], err

    try:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        query_filter = None
        if node_label:
            query_filter = Filter(
                must=[FieldCondition(key="node_type", match=MatchValue(value=node_label))]
            )

        query_vec = embedder.encode_query(query)
        hits = qdrant.search(
            query_vec,
            top_k=top_k,
            score_threshold=threshold,
            query_filter=query_filter,
        )

        results: list[dict[str, Any]] = []
        for hit in hits:
            # `or {}` also covers a payload key that is present but None,
            # which `.get("payload", {})` would pass through and crash on.
            payload = hit.get("payload") or {}
            results.append({
                "score": round(hit["score"], 4),
                "node_type": payload.get("node_type", ""),
                "node_id": payload.get("node_id", ""),
                "attribute_name": payload.get("attribute_name", ""),
                "original_text": payload.get("original_text", ""),
                "pref_label_en": payload.get("pref_label_en", ""),
                "radar_version": payload.get("radar_version"),
            })
        return results, ""
    except Exception as exc:
        # Log like get_vector_resources does so failures are not silent,
        # while still reporting the error to the caller for display.
        logger.warning("Vector search failed: %s", exc)
        return [], str(exc)