File size: 4,241 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Vektoru datubāzes klients ar lokālu persistenci un semantisko meklēšanu."""

from __future__ import annotations

import json
import logging
import math
import os
import uuid
from pathlib import Path
from threading import RLock
from typing import Any

logger = logging.getLogger(__name__)


def _cosine_similarity(left: list[float], right: list[float]) -> float:
    if not left or not right or len(left) != len(right):
        return 0.0
    numerator = sum(a * b for a, b in zip(left, right, strict=False))
    left_norm = math.sqrt(sum(value * value for value in left))
    right_norm = math.sqrt(sum(value * value for value in right))
    if left_norm == 0 or right_norm == 0:
        return 0.0
    return numerator / (left_norm * right_norm)


class VectorDbClient:
    """Generic vector database client with a safe local fallback implementation.

    Records are held in memory as ``{collection_name: [record, ...]}`` and
    mirrored to a single JSON file. Every record is a dict with the keys
    ``id``, ``text``, ``metadata`` and ``vector``.
    """

    def __init__(self, backend: str = "local_json", storage_path: str | None = None) -> None:
        """Create the client and load any previously persisted collections.

        Args:
            backend: Backend identifier; stored on the instance as ``backend``.
            storage_path: Location of the JSON store. When omitted, the
                ``MARIS_VECTOR_STORE_PATH`` environment variable is used, and
                failing that ``~/.maris/vector-store.json``.
        """
        self.backend = backend
        configured_path = storage_path or os.getenv(
            "MARIS_VECTOR_STORE_PATH", "~/.maris/vector-store.json"
        )
        self._storage_path = Path(configured_path).expanduser()
        self._lock = RLock()
        self._collections: dict[str, list[dict[str, Any]]] = {}
        self._load()

    def _load(self) -> None:
        """Populate the in-memory collections from the storage file, if it exists.

        Loading is best-effort: an unreadable or unparsable file is logged and
        ignored. Collections whose value is not a list are dropped, and so are
        entries that are not dicts, because ``search`` can only work with
        dict records.
        """
        if not self._storage_path.exists():
            return
        try:
            payload = json.loads(self._storage_path.read_text(encoding="utf-8"))
        except Exception as exc:  # noqa: BLE001
            logger.warning("Neizdevās ielādēt vector store no %s: %s", self._storage_path, exc)
            return
        if not isinstance(payload, dict):
            return

        collections: dict[str, list[dict[str, Any]]] = {}
        for name, entries in payload.items():
            if not isinstance(name, str) or not isinstance(entries, list):
                continue
            records = [entry for entry in entries if isinstance(entry, dict)]
            skipped = len(entries) - len(records)
            if skipped:
                logger.warning("Izlaisti %d nederīgi ieraksti kolekcijā %s", skipped, name)
            collections[name] = records
        self._collections = collections

    def _persist(self) -> None:
        """Write all collections to the storage file via a temp file and rename.

        Raises:
            TypeError / ValueError: If the collections are not JSON serialisable.
            OSError: If writing or replacing the file fails; the temporary file
                is removed before the error propagates.
        """
        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self._collections, ensure_ascii=False)
        # Append ".tmp" to the full name instead of swapping the suffix, so the
        # temp file can never be the store itself (e.g. a store named "x.tmp")
        # nor clash with a sibling store that shares the same stem.
        tmp_path = self._storage_path.with_name(self._storage_path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            os.replace(tmp_path, self._storage_path)
        except OSError:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                # Cleanup is best-effort; the original failure is what matters.
                pass
            raise

    def add(self, collection: str, text: str, metadata: dict[str, Any]) -> None:
        """Embed ``text``, append it to ``collection`` and persist the store.

        Args:
            collection: Collection name; blank names map to ``"default"``.
            text: Text to store. Blank text is ignored and nothing is written.
            metadata: Extra data stored with the record (copied shallowly).
                ``None`` is treated as empty metadata.

        Raises:
            Exception: Whatever ``_persist`` raises. In that case the record is
                removed from memory again so memory and disk stay consistent.
        """
        from memory.vector_store.embeddings import embed_text

        normalized_collection = collection.strip() or "default"
        normalized_text = text.strip()
        if not normalized_text:
            return

        vector = embed_text(normalized_text)
        record = {
            "id": str(uuid.uuid4()),
            "text": normalized_text,
            "metadata": dict(metadata or {}),
            "vector": vector,
        }
        with self._lock:
            created = normalized_collection not in self._collections
            bucket = self._collections.setdefault(normalized_collection, [])
            bucket.append(record)
            try:
                self._persist()
            except Exception:
                # Undo the append: otherwise an unserialisable record would stay
                # in memory and make every later persist fail as well.
                bucket.pop()
                if created:
                    del self._collections[normalized_collection]
                raise
        logger.debug("Pievienots vektorveikalā: %s", normalized_collection)

    def search(
        self, collection: str, query: str, top_k: int = 5
    ) -> list[dict[str, Any]]:
        """Find the records in ``collection`` most similar to ``query``.

        Args:
            collection: Collection name; blank names map to ``"default"``.
            query: Query text. A blank query yields an empty result.
            top_k: Maximum number of hits; values below zero are treated as 0.

        Returns:
            Dicts with ``id``, ``text``, ``metadata`` and ``score`` (cosine
            similarity rounded to 6 decimals), best match first. Records whose
            similarity is not positive, or that have no vector, are left out.
        """
        from memory.vector_store.embeddings import embed_text

        normalized_collection = collection.strip() or "default"
        normalized_query = query.strip()
        if not normalized_query:
            return []

        vector = embed_text(normalized_query)
        # Snapshot under the lock so scoring does not block concurrent writers.
        with self._lock:
            records = list(self._collections.get(normalized_collection, []))

        ranked: list[dict[str, Any]] = []
        for record in records:
            stored_vector = record.get("vector")
            if stored_vector is None:
                # A missing or JSON-null vector cannot be scored.
                continue
            score = _cosine_similarity(vector, list(stored_vector))
            if score <= 0:
                continue
            ranked.append(
                {
                    "id": record.get("id"),
                    "text": record.get("text"),
                    "metadata": record.get("metadata", {}),
                    "score": round(score, 6),
                }
            )
        ranked.sort(key=lambda item: float(item["score"]), reverse=True)
        logger.debug("Meklē vektorveikalā: %s", normalized_collection)
        return ranked[: max(top_k, 0)]