RAG / qdrant_shared.py
bakhil-aissa's picture
Upload qdrant_shared.py
4d84ee1 verified
import logging
import os
from pathlib import Path
from typing import Any

import yaml
from qdrant_client import QdrantClient
# Default settings file: ``config.yaml`` located next to this module.
DEFAULT_CONFIG_PATH = Path(__file__).resolve().with_name("config.yaml")

# Baseline Qdrant settings; load_qdrant_settings copies this dict and layers
# config-file and environment values on top of it.
_DEFAULTS: dict[str, Any] = dict(
    host="localhost",
    port=6333,
    path=None,
    client_timeout=10,
    use_memory=False,
    collection_name="doc_ai_production",
)

# Clients already built by get_shared_qdrant_client, keyed by a tuple that
# describes the backend (memory / local path / HTTP endpoint).
_CLIENT_CACHE: dict[tuple, QdrantClient] = {}
def _env_bool(name: str, default: bool | None = None) -> bool | None:
raw = os.environ.get(name)
if raw is None:
return default
v = raw.strip().lower()
if v in ("1", "true", "yes", "y", "on"):
return True
if v in ("0", "false", "no", "n", "off"):
return False
return default
def load_qdrant_settings(config_path: str | Path | None = None) -> dict[str, Any]:
    """Build the effective Qdrant settings from defaults, YAML and environment.

    Sources are applied in this order, later ones overriding earlier ones:

    1. ``_DEFAULTS``.
    2. The ``qdrant`` mapping of the YAML file. Top-level ``qdrant_host`` /
       ``qdrant_port`` keys are used as fallbacks when the mapping has no
       host / port. Entries whose value is ``None`` never override.
    3. Environment variables ``QDRANT_HOST``, ``QDRANT_PORT``,
       ``QDRANT_PATH``, ``QDRANT_COLLECTION`` and ``QDRANT_USE_MEMORY``.

    Args:
        config_path: YAML file to read. Falls back to ``DEFAULT_CONFIG_PATH``
            when falsy. A path that is not an existing file is skipped.

    Returns:
        A new dict holding the merged settings.
    """
    log = logging.getLogger(__name__)
    path = Path(config_path or DEFAULT_CONFIG_PATH)
    merged = dict(_DEFAULTS)

    if path.is_file():
        with open(path, "r", encoding="utf-8") as f:
            full = yaml.safe_load(f) or {}

        # A YAML document may be a list or scalar; only a mapping can carry
        # settings, so anything else is ignored instead of crashing on .get().
        if not isinstance(full, dict):
            log.warning("Ignoring %s: top-level YAML value is not a mapping", path)
            full = {}

        section = full.get("qdrant") or {}
        if not isinstance(section, dict):
            log.warning("Ignoring 'qdrant' section in %s: expected a mapping", path)
            section = {}
        q = dict(section)

        # Top-level qdrant_host / qdrant_port fill in for a missing host / port.
        if not q.get("host") and full.get("qdrant_host") is not None:
            q["host"] = full["qdrant_host"]
        if q.get("port") is None and full.get("qdrant_port") is not None:
            q["port"] = full["qdrant_port"]

        merged.update({k: v for k, v in q.items() if v is not None})

    # Env overrides used across app/pipeline.
    if os.environ.get("QDRANT_HOST"):
        merged["host"] = os.environ["QDRANT_HOST"]

    port_env = os.environ.get("QDRANT_PORT")
    if port_env:
        try:
            merged["port"] = int(port_env)
        except ValueError:
            # Keep the previously resolved port, but make the bad value visible.
            log.warning(
                "Ignoring non-integer QDRANT_PORT=%r; keeping port %r",
                port_env,
                merged["port"],
            )

    if os.environ.get("QDRANT_PATH", "").strip():
        merged["path"] = os.environ["QDRANT_PATH"].strip()

    if os.environ.get("QDRANT_COLLECTION"):
        merged["collection_name"] = os.environ["QDRANT_COLLECTION"]

    use_memory_env = _env_bool("QDRANT_USE_MEMORY")
    if use_memory_env is not None:
        merged["use_memory"] = use_memory_env

    return merged
def get_shared_qdrant_client(
    *,
    use_memory: bool,
    host: str,
    port: int,
    timeout: float,
    qdrant_path: str | None,
) -> QdrantClient:
    """Return a cached ``QdrantClient`` for the requested backend.

    The backend is chosen in this order: in-memory when *use_memory* is
    truthy, local on-disk storage when *qdrant_path* is truthy, otherwise an
    HTTP client for *host* / *port*. A client is built at most once per
    distinct cache key and stored in ``_CLIENT_CACHE``.

    Args:
        use_memory: Select the ``":memory:"`` backend.
        host: Server host; only used for the HTTP backend.
        port: Server port; coerced with ``int`` for the HTTP backend.
        timeout: Client timeout; coerced with ``float`` for the HTTP backend.
        qdrant_path: Directory for local storage. It is user-expanded and
            resolved, and created on first use if missing.

    Returns:
        The cached client matching the arguments.
    """
    # Step 1: work out which cache slot these arguments map to.
    storage_dir = None
    if use_memory:
        cache_key: tuple = ("memory",)
    elif qdrant_path:
        storage_dir = str(Path(str(qdrant_path)).expanduser().resolve())
        cache_key = ("path", storage_dir)
    else:
        cache_key = ("http", host, int(port), float(timeout))

    # Step 2: reuse an existing client when there is one.
    if cache_key in _CLIENT_CACHE:
        return _CLIENT_CACHE[cache_key]

    # Step 3: build the client for the selected backend and remember it.
    backend = cache_key[0]
    if backend == "memory":
        client = QdrantClient(":memory:")
    elif backend == "path":
        # The storage directory is only created when a client is first built.
        Path(storage_dir).mkdir(parents=True, exist_ok=True)
        client = QdrantClient(path=storage_dir)
    else:
        _, http_host, http_port, http_timeout = cache_key
        client = QdrantClient(host=http_host, port=http_port, timeout=http_timeout)

    _CLIENT_CACHE[cache_key] = client
    return client