| """Cached vector-search client for Streamlit (Qdrant + EPFL API embedder).""" |
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
| import config as _cfg |
|
|
| import logging |
| from typing import Any |
|
|
| import streamlit as st |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_vector_resources():
    """Build the embedder and the Qdrant store once and reuse them.

    Returns:
        A 3-tuple ``(embedder, qdrant_store, error_msg)``. On success
        ``error_msg`` is ``""``. On any failure (import error, bad config,
        store construction error) the first two items are ``None`` and
        ``error_msg`` is ``str(exc)``.

    Note:
        The failure tuple is an ordinary return value, so
        ``st.cache_resource`` stores it like any other result; the cache has
        to be cleared before another construction attempt is made.
    """
    try:
        # Imported lazily so that a missing or broken project package is
        # reported through the error message instead of failing at import.
        from src.config import load_config
        from src.embeddings import build_embedder
        from src.qdrant_store import QdrantStore

        config_file = _cfg.PROJECT_ROOT / "config" / "config.yaml"
        cfg = load_config(config_path=str(config_file))

        embedder = build_embedder(cfg)

        qdrant_cfg = cfg["qdrant"]
        store_kwargs = {
            "mode": qdrant_cfg.get("mode", "local"),
            "local_path": qdrant_cfg.get("local_path"),
            "url": qdrant_cfg.get("url"),
            "host": qdrant_cfg.get("host", "localhost"),
            "port": qdrant_cfg.get("port", 6333),
            "api_key": qdrant_cfg.get("api_key"),
            "collection_name": _cfg.KG_NODES_COLLECTION,
            "distance": qdrant_cfg.get("distance", "cosine"),
            "vector_dim": cfg["embeddings"]["dimension"],
        }
        store = QdrantStore(**store_kwargs)
    except Exception as exc:
        logger.warning("Vector search unavailable: %s", exc)
        return None, None, str(exc)
    else:
        return embedder, store, ""
|
|
|
|
def vector_search(
    query: str,
    top_k: int = 20,
    threshold: float = 0.5,
    node_label: str | None = None,
) -> tuple[list[dict[str, Any]], str]:
    """Run a vector similarity search against the kg_nodes Qdrant collection.

    Args:
        query: Free-text query; it is embedded with ``embedder.encode_query``.
        top_k: Passed to the store's ``search`` as ``top_k``.
        threshold: Passed to the store's ``search`` as ``score_threshold``.
        node_label: When truthy, only hits whose ``node_type`` payload field
            matches this value are requested.

    Returns:
        ``(results, error_message)``. ``results`` is a list of dicts with the
        keys ``score`` (rounded to 4 decimals), ``node_type``, ``node_id``,
        ``attribute_name``, ``original_text``, ``pref_label_en`` and
        ``radar_version``. On any error ``results`` is empty and
        ``error_message`` holds ``str(exc)``; on success it is ``""``.
    """
    embedder, qdrant, err = get_vector_resources()
    if embedder is None:
        return [], err

    try:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        query_filter = None
        if node_label:
            query_filter = Filter(
                must=[
                    FieldCondition(
                        key="node_type",
                        match=MatchValue(value=node_label),
                    )
                ]
            )

        query_vec = embedder.encode_query(query)
        hits = qdrant.search(
            query_vec,
            top_k=top_k,
            score_threshold=threshold,
            query_filter=query_filter,
        )

        def _to_row(hit: dict[str, Any]) -> dict[str, Any]:
            # ``or {}`` also covers a hit whose payload key is present but
            # None, which ``.get("payload", {})`` would pass through as None.
            payload = hit.get("payload") or {}
            return {
                "score": round(hit["score"], 4),
                "node_type": payload.get("node_type", ""),
                "node_id": payload.get("node_id", ""),
                "attribute_name": payload.get("attribute_name", ""),
                "original_text": payload.get("original_text", ""),
                "pref_label_en": payload.get("pref_label_en", ""),
                "radar_version": payload.get("radar_version"),
            }

        return [_to_row(hit) for hit in hits], ""
    except Exception as exc:
        # Best-effort API: the caller only receives the message, so record
        # the failure in the log as get_vector_resources does.
        logger.warning("Vector search failed: %s", exc)
        return [], str(exc)
|
|