"""Qdrant connection settings and a process-wide client cache."""

import logging
import os
from pathlib import Path
from typing import Any

import yaml
from qdrant_client import QdrantClient

_logger = logging.getLogger(__name__)

DEFAULT_CONFIG_PATH = Path(__file__).resolve().parent / "config.yaml"

# Baseline values; the config file and then environment variables override them.
_DEFAULTS: dict[str, Any] = {
    "host": "localhost",
    "port": 6333,
    "path": None,
    "client_timeout": 10,
    "use_memory": False,
    "collection_name": "doc_ai_production",
}

# Clients created by get_shared_qdrant_client, keyed by connection mode
# and the parameters that identify the target.
_CLIENT_CACHE: dict[tuple, QdrantClient] = {}


def _env_bool(name: str, default: bool | None = None) -> bool | None:
    """Read environment variable *name* as a boolean flag.

    The value is trimmed and compared case-insensitively. ``1``, ``true``,
    ``yes``, ``y`` and ``on`` give ``True``; ``0``, ``false``, ``no``, ``n``
    and ``off`` give ``False``. An unset variable or any other value gives
    *default*.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    v = raw.strip().lower()
    if v in ("1", "true", "yes", "y", "on"):
        return True
    if v in ("0", "false", "no", "n", "off"):
        return False
    return default


def _as_mapping(value: Any, what: str) -> dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict.

    A truthy non-dict value (for example a YAML list or bare scalar where a
    mapping was expected) is reported with a warning naming *what*, so a
    malformed config is ignored rather than crashing the loader.
    """
    if isinstance(value, dict):
        return value
    if value:
        _logger.warning(
            "Ignoring %s: expected a mapping, got %s",
            what,
            type(value).__name__,
        )
    return {}


def load_qdrant_settings(config_path: str | Path | None = None) -> dict[str, Any]:
    """Build the effective Qdrant settings dictionary.

    Sources are applied in increasing precedence:

    1. ``_DEFAULTS``.
    2. The ``qdrant`` mapping of the YAML file at *config_path* (or
       ``DEFAULT_CONFIG_PATH`` when *config_path* is falsy), if that file
       exists. Top-level ``qdrant_host`` / ``qdrant_port`` keys are used when
       the ``qdrant`` mapping has no host / port. ``None`` values are skipped.
    3. Environment variables ``QDRANT_HOST``, ``QDRANT_PORT``,
       ``QDRANT_PATH``, ``QDRANT_COLLECTION`` and ``QDRANT_USE_MEMORY``.

    Returns:
        A new dict containing at least every key of ``_DEFAULTS``.

    Raises:
        yaml.YAMLError: If the config file is not valid YAML.
    """
    path = Path(config_path or DEFAULT_CONFIG_PATH)
    merged = dict(_DEFAULTS)

    if path.is_file():
        with open(path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # safe_load may return a list or scalar; only a mapping is usable here.
        full = _as_mapping(loaded, f"config root in {path}")
        q = dict(_as_mapping(full.get("qdrant"), f"'qdrant' section in {path}"))
        if not q.get("host") and full.get("qdrant_host") is not None:
            q["host"] = full["qdrant_host"]
        if q.get("port") is None and full.get("qdrant_port") is not None:
            q["port"] = full["qdrant_port"]
        merged.update({k: v for k, v in q.items() if v is not None})

    # Env overrides used across app/pipeline.
    if os.environ.get("QDRANT_HOST"):
        merged["host"] = os.environ["QDRANT_HOST"]

    port_raw = os.environ.get("QDRANT_PORT")
    if port_raw:
        try:
            merged["port"] = int(port_raw)
        except ValueError:
            # Best effort: keep the port resolved so far, but make the
            # misconfiguration visible instead of dropping it silently.
            _logger.warning(
                "Ignoring invalid QDRANT_PORT=%r; keeping port %r",
                port_raw,
                merged["port"],
            )

    path_raw = os.environ.get("QDRANT_PATH", "").strip()
    if path_raw:
        merged["path"] = path_raw

    if os.environ.get("QDRANT_COLLECTION"):
        merged["collection_name"] = os.environ["QDRANT_COLLECTION"]

    use_memory_env = _env_bool("QDRANT_USE_MEMORY")
    if use_memory_env is not None:
        merged["use_memory"] = use_memory_env

    return merged


def get_shared_qdrant_client(
    *,
    use_memory: bool,
    host: str,
    port: int,
    timeout: float,
    qdrant_path: str | None,
) -> QdrantClient:
    """Return a cached ``QdrantClient`` for the requested connection mode.

    The mode is chosen in this order:

    1. *use_memory* truthy: one shared ``QdrantClient(":memory:")``.
    2. *qdrant_path* truthy: a client opened with ``path=`` on the
       user-expanded, resolved directory, which is created if missing;
       cached per resolved path.
    3. Otherwise: a client built from *host*, *port* and *timeout*, cached
       per ``(host, int(port), float(timeout))``.

    *host*, *port* and *timeout* are only used in the third mode.
    """
    if use_memory:
        key = ("memory",)
        if key not in _CLIENT_CACHE:
            _CLIENT_CACHE[key] = QdrantClient(":memory:")
        return _CLIENT_CACHE[key]

    if qdrant_path:
        # Normalise so different spellings of one directory share a client.
        p = str(Path(str(qdrant_path)).expanduser().resolve())
        key = ("path", p)
        if key not in _CLIENT_CACHE:
            Path(p).mkdir(parents=True, exist_ok=True)
            _CLIENT_CACHE[key] = QdrantClient(path=p)
        return _CLIENT_CACHE[key]

    key = ("http", host, int(port), float(timeout))
    if key not in _CLIENT_CACHE:
        _CLIENT_CACHE[key] = QdrantClient(
            host=host, port=int(port), timeout=float(timeout)
        )
    return _CLIENT_CACHE[key]