gesda_knowledge_graph_demo / graph_UI /db /vector_client.py
henryschultz
huggingface deployment
7eaced5
Raw
History Blame Contribute Delete
3.01 kB
"""Cached vector-search client for Streamlit (Qdrant + EPFL API embedder)."""
from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import config as _cfg # noqa: F401
import logging
from typing import Any
import streamlit as st
logger = logging.getLogger(__name__)
@st.cache_resource(show_spinner=False)
def get_vector_resources():
    """Build the embedder and Qdrant store once and hand them back as a triple.

    Returns:
        ``(embedder, qdrant_store, error_msg)``. On success ``error_msg`` is
        the empty string. If anything raises while importing the project
        modules, loading the config, or constructing either object, the
        failure is logged as a warning and ``(None, None, str(exc))`` is
        returned instead of propagating the exception.

    NOTE(review): the failure triple is an ordinary return value, so
    ``st.cache_resource`` presumably caches it like a success; a transient
    start-up error would then persist until the cache is cleared. Confirm
    this is intended.
    """
    try:
        # Imported lazily so a missing project module is reported through the
        # error triple rather than breaking the import of this file.
        from src.config import load_config
        from src.embeddings import build_embedder
        from src.qdrant_store import QdrantStore

        config_file = _cfg.PROJECT_ROOT / "config" / "config.yaml"
        settings = load_config(config_path=str(config_file))
        encoder = build_embedder(settings)

        qdrant_settings = settings["qdrant"]
        store_kwargs = {
            "mode": qdrant_settings.get("mode", "local"),
            "local_path": qdrant_settings.get("local_path"),
            "url": qdrant_settings.get("url"),
            "host": qdrant_settings.get("host", "localhost"),
            "port": qdrant_settings.get("port", 6333),
            "api_key": qdrant_settings.get("api_key"),
            "collection_name": _cfg.KG_NODES_COLLECTION,
            "distance": qdrant_settings.get("distance", "cosine"),
            "vector_dim": settings["embeddings"]["dimension"],
        }
        store = QdrantStore(**store_kwargs)
    except Exception as failure:
        logger.warning("Vector search unavailable: %s", failure)
        return None, None, str(failure)

    return encoder, store, ""
def vector_search(
    query: str,
    top_k: int = 20,
    threshold: float = 0.5,
    node_label: str | None = None,
) -> tuple[list[dict[str, Any]], str]:
    """Run a vector similarity search against the KG-nodes Qdrant collection.

    Args:
        query: Text to search for; embedded via ``embedder.encode_query``.
        top_k: Forwarded to the store's ``search`` as ``top_k``.
        threshold: Forwarded to the store's ``search`` as ``score_threshold``.
        node_label: When truthy, only points whose ``node_type`` payload field
            equals this value are requested.

    Returns:
        ``(results, error_message)``. On success ``error_message`` is ``""``
        and each result is a dict with the keys ``score`` (rounded to four
        decimals), ``node_type``, ``node_id``, ``attribute_name``,
        ``original_text``, ``pref_label_en`` and ``radar_version``. On any
        failure ``results`` is empty and ``error_message`` holds the
        stringified exception (or the resource-initialisation error).
    """
    embedder, qdrant, err = get_vector_resources()
    if embedder is None or qdrant is None:
        return [], err

    try:
        # Local import: a missing qdrant_client is reported via the error
        # string instead of breaking the import of this module.
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        query_filter = None
        if node_label:
            query_filter = Filter(
                must=[
                    FieldCondition(
                        key="node_type",
                        match=MatchValue(value=node_label),
                    )
                ]
            )

        query_vec = embedder.encode_query(query)
        hits = qdrant.search(
            query_vec,
            top_k=top_k,
            score_threshold=threshold,
            query_filter=query_filter,
        )

        results: list[dict[str, Any]] = []
        for hit in hits:
            # ``or {}`` also covers a "payload" key that is present but None
            # (presumably possible for points stored without a payload), so
            # one such hit no longer discards the whole result set.
            payload = hit.get("payload") or {}
            results.append({
                "score": round(hit["score"], 4),
                "node_type": payload.get("node_type", ""),
                "node_id": payload.get("node_id", ""),
                "attribute_name": payload.get("attribute_name", ""),
                "original_text": payload.get("original_text", ""),
                "pref_label_en": payload.get("pref_label_en", ""),
                "radar_version": payload.get("radar_version"),
            })
        return results, ""
    except Exception as exc:
        # Best-effort contract: callers get the message, but also leave a
        # trace in the log so failures are not silent.
        logger.warning("Vector search failed: %s", exc)
        return [], str(exc)