File size: 8,386 Bytes
6059138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
from __future__ import annotations

import ast
import operator
import re
from dataclasses import dataclass
from typing import Callable

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mistralai import ChatMistralAI

from memory_agent.retrieval import HybridRetriever
from memory_agent.storage import FaissMemoryStore

# Reply used by KnowledgeService.answer_from_knowledge when retrieval returns
# nothing or the LLM produces an empty answer; it is also handed to the user
# prompt template as ``fallback_message``.
FALLBACK_MESSAGE = (
    "I don't have enough stored knowledge yet to answer that. "
    "Share a few facts about your domain, and I will learn and help."
)
# Acknowledgement returned by KnowledgeService.store_knowledge after an upsert.
UNDERSTOOD_MESSAGE = "understood"


@dataclass(slots=True)
class ParsedFact:
    """A single key/value fact extracted from free-form text by ``FactParser``."""

    # Fact name after ``FactParser._clean_key`` has been applied.
    key: str
    # Fact value after ``FactParser._clean_value`` has been applied.
    value: str


class FactParser:
    """Extract one ``key``/``value`` fact from a line of free-form text.

    Two sentence shapes are recognised and tried in this order:

    1. Command form: ``set|update|remember <key> to|as|= <value>``.
    2. Assignment form: ``<key> = <value>``, ``<key>: <value>`` or
       ``<key> is <value>``.

    The command form is the more specific one, so it is tried first;
    otherwise ``"set x = 5"`` would be captured by the assignment form with
    the key ``"set x"``.
    """

    _ASSIGNMENT_PATTERNS = (
        re.compile(
            r"^\s*(?:set|update|remember)\s+(.+?)\s+(?:to|as|=)\s+(.+?)\s*$",
            re.IGNORECASE,
        ),
        # ``\bis\b`` restricts the copula to a standalone word. Without the
        # boundaries, "Paris rocks" would split into key "Par" / value "rocks".
        re.compile(
            r"^\s*([A-Za-z][A-Za-z0-9 _-]{0,80})\s*(?:=|:|\bis\b)\s*(.+?)\s*$",
            re.IGNORECASE,
        ),
    )

    def parse(self, text: str) -> ParsedFact | None:
        """Return the first fact recognised in *text*, or ``None``.

        A pattern whose match cleans down to an empty key or an empty value
        is skipped and the next pattern is tried.
        """
        for pattern in self._ASSIGNMENT_PATTERNS:
            match = pattern.match(text)
            if not match:
                continue
            key = self._clean_key(match.group(1))
            value = self._clean_value(match.group(2))
            if key and value:
                return ParsedFact(key=key, value=value)
        return None

    @staticmethod
    def _clean_key(raw: str) -> str:
        """Normalise a raw key: trim, drop articles, collapse whitespace.

        If removing the articles would leave nothing (for example the
        single-letter variable ``a``), the trimmed key is kept as-is so that
        such facts remain storable.
        """
        trimmed = raw.strip().strip(".")
        without_articles = re.sub(r"\b(the|a|an)\b", "", trimmed, flags=re.IGNORECASE)
        without_articles = re.sub(r"\s+", " ", without_articles).strip()
        if without_articles:
            return without_articles
        # The key consisted solely of articles; keep it rather than lose it.
        return re.sub(r"\s+", " ", trimmed).strip()

    @staticmethod
    def _clean_value(raw: str) -> str:
        """Normalise a raw value: trim, drop trailing periods, unquote.

        Only trailing periods are removed so that a leading decimal point
        (``.5``) survives. A value wrapped in matching single or double
        quotes is returned without the quotes.
        """
        cleaned = raw.strip().rstrip(".").strip()
        if (cleaned.startswith('"') and cleaned.endswith('"')) or (
            cleaned.startswith("'") and cleaned.endswith("'")
        ):
            return cleaned[1:-1].strip()
        return cleaned


class SafeMathEvaluator:
    """Evaluate arithmetic expressions by walking the AST instead of ``eval``.

    Supported syntax: numeric literals, names, the binary operators
    ``+ - * / ** % //`` and unary ``+``/``-``. Names are normalised with
    ``FaissMemoryStore.normalize_key`` and looked up in the supplied fact map;
    their stored string values are converted with ``float``.

    Every failure (bad syntax, unknown or non-numeric variable, unsupported
    construct, division by zero, overflow, non-real result) is reported as
    ``ValueError`` so callers need to handle only one exception type.
    """

    _BINARY_OPERATORS: dict[type[ast.operator], Callable[[float, float], float]] = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.Mod: operator.mod,
        ast.FloorDiv: operator.floordiv,
    }
    _UNARY_OPERATORS: dict[type[ast.unaryop], Callable[[float], float]] = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def evaluate(self, expression: str, facts: dict[str, str]) -> float:
        """Return the numeric value of *expression*.

        Args:
            expression: Python-style arithmetic expression.
            facts: Mapping of normalised fact keys to their string values.

        Raises:
            ValueError: If the expression cannot be parsed or evaluated.
        """
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as error:
            raise ValueError(f"Invalid expression: {expression!r}") from error
        try:
            return float(self._evaluate_node(tree.body, facts))
        except ArithmeticError as error:
            # ZeroDivisionError and OverflowError both derive from ArithmeticError.
            raise ValueError(f"Cannot evaluate expression: {error}") from error

    def _evaluate_node(self, node: ast.AST, facts: dict[str, str]) -> float:
        """Recursively evaluate one AST node; raise ``ValueError`` if unsupported."""
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return float(node.value)

        if isinstance(node, ast.Name):
            key = FaissMemoryStore.normalize_key(node.id)
            if key not in facts:
                raise ValueError(f"Unknown variable: {node.id}")
            try:
                return float(facts[key])
            except ValueError as error:
                raise ValueError(f"Variable '{node.id}' is not numeric.") from error

        if isinstance(node, ast.BinOp):
            op_type = type(node.op)
            if op_type not in self._BINARY_OPERATORS:
                raise ValueError("Unsupported operator in expression.")
            left = self._evaluate_node(node.left, facts)
            right = self._evaluate_node(node.right, facts)
            result = self._BINARY_OPERATORS[op_type](left, right)
            if isinstance(result, complex):
                # A negative base with a fractional exponent yields a complex
                # number, which cannot be converted to float further up.
                raise ValueError("Expression has no real-valued result.")
            return result

        if isinstance(node, ast.UnaryOp):
            op_type = type(node.op)
            if op_type not in self._UNARY_OPERATORS:
                raise ValueError("Unsupported unary operator in expression.")
            operand = self._evaluate_node(node.operand, facts)
            return self._UNARY_OPERATORS[op_type](operand)

        raise ValueError("Unsupported expression.")


class KnowledgeService:
    """Store facts per namespace and answer questions from what was stored.

    Questions that contain an arithmetic expression over stored facts are
    computed locally with ``SafeMathEvaluator``. Everything else is answered
    by retrieving documents and asking the LLM with the configured user
    prompt template.
    """

    # Any character that suggests a candidate string is an arithmetic expression.
    _OPERATOR_HINT = re.compile(r"[+\-*/^%()]")
    # Text enclosed in single or double quotes.
    _QUOTED_PATTERN = re.compile(r'["\']([^"\']+)["\']')
    # Expression text following a "calculate" / "equation" / "expression" cue.
    _INLINE_PATTERN = re.compile(
        r"(?:calculate|equation|expression)\s*[:\-,]?\s*([A-Za-z0-9_+\-*/^%(). ]+)",
        flags=re.IGNORECASE,
    )
    # ``operand (operator operand)+`` where an operand is an identifier, a
    # number, or a parenthesised group.
    _FRAGMENT_PATTERN = re.compile(
        r"((?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*"
        r"(?:[+\-*/^%]\s*(?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*)+)",
        flags=re.IGNORECASE,
    )

    def __init__(
        self,
        store: FaissMemoryStore,
        retriever: HybridRetriever,
        llm: ChatMistralAI,
        top_k: int,
        user_prompt_template: str,
    ) -> None:
        """Wire up the collaborators.

        Args:
            store: Backend used to upsert knowledge and fetch the fact map.
            retriever: Used to retrieve documents for a question.
            llm: Chat model invoked with the retrieved context.
            top_k: Number of results requested from the retriever.
            user_prompt_template: ``str.format`` template receiving
                ``fallback_message``, ``context`` and ``question``.
        """
        self._store = store
        self._retriever = retriever
        self._llm = llm
        self._top_k = top_k
        self._user_prompt_template = user_prompt_template
        self._fact_parser = FactParser()
        self._math_evaluator = SafeMathEvaluator()

    def store_knowledge(
        self,
        namespace: str,
        knowledge_text: str,
        fact_key: str | None = None,
        fact_value: str | None = None,
    ) -> str:
        """Upsert *knowledge_text* into *namespace* and acknowledge.

        An explicit ``fact_key`` / ``fact_value`` takes precedence over what
        ``FactParser`` extracts from the text. When both a key and a value are
        available, the stored content is rewritten as ``"<key> = <value>"``;
        otherwise the raw text is stored.

        Returns:
            ``UNDERSTOOD_MESSAGE``.
        """
        parsed_fact = self._fact_parser.parse(knowledge_text)
        key = fact_key or (parsed_fact.key if parsed_fact else None)
        value = fact_value or (parsed_fact.value if parsed_fact else None)
        content = knowledge_text
        if key and value:
            content = f"{key} = {value}"

        self._store.upsert_knowledge(
            namespace=namespace,
            content=content,
            fact_key=key,
            fact_value=value,
        )
        return UNDERSTOOD_MESSAGE

    def answer_from_knowledge(self, namespace: str, question: str) -> str:
        """Answer *question* using only knowledge stored in *namespace*.

        If an expression can be extracted from the question and evaluated
        against the stored facts, the computed result is returned. Otherwise
        the retriever and the LLM are used. ``FALLBACK_MESSAGE`` is returned
        when nothing is retrieved or the LLM answer is empty.
        """
        facts = self._store.fetch_fact_map(namespace=namespace)
        expression = self._extract_expression(question=question)
        if expression:
            try:
                value = self._math_evaluator.evaluate(expression=expression, facts=facts)
            except (ValueError, SyntaxError, ArithmeticError, TypeError):
                # Expression extraction is heuristic, so the candidate may not
                # be valid or computable (bad syntax, unknown variable,
                # division by zero, overflow, non-real result). Treat every
                # such case as "not a math question" and fall through.
                pass
            else:
                formatted_value = int(value) if value.is_integer() else round(value, 6)
                substituted = self._substitute_expression(expression=expression, facts=facts)
                return f"Based on stored knowledge: {substituted} = {formatted_value}."

        results = self._retriever.retrieve(namespace=namespace, query=question, k=self._top_k)
        if not results:
            return FALLBACK_MESSAGE

        context = "\n\n".join(
            f"- {doc.page_content}\n  metadata={doc.metadata}"
            for doc, _ in results
        )
        prompt = self._user_prompt_template.format(
            fallback_message=FALLBACK_MESSAGE,
            context=context,
            question=question,
        )
        response = self._llm.invoke(
            [
                SystemMessage(content="You are a memory-only answer tool."),
                HumanMessage(content=prompt),
            ]
        )
        content = str(response.content).strip()
        if not content:
            return FALLBACK_MESSAGE
        return content

    @staticmethod
    def _extract_expression(question: str) -> str | None:
        """Heuristically pull an arithmetic expression out of *question*.

        Tried in order: quoted text containing an operator or parenthesis,
        text following a "calculate" / "equation" / "expression" cue, and
        finally any ``operand operator operand`` fragment in the question.
        ``^`` is rewritten to ``**``. Returns ``None`` when nothing matches.
        """
        for candidate in KnowledgeService._QUOTED_PATTERN.findall(question):
            if KnowledgeService._OPERATOR_HINT.search(candidate):
                return candidate.replace("^", "**")

        inline = KnowledgeService._INLINE_PATTERN.search(question)
        if inline:
            expression = inline.group(1).strip().rstrip(".")
            fragment = KnowledgeService._FRAGMENT_PATTERN.search(expression)
            if fragment:
                return fragment.group(1).strip().replace("^", "**")
            if KnowledgeService._OPERATOR_HINT.search(expression):
                return expression.replace("^", "**")

        fragment = KnowledgeService._FRAGMENT_PATTERN.search(question)
        if fragment:
            return fragment.group(1).strip().replace("^", "**")
        return None

    @staticmethod
    def _substitute_expression(expression: str, facts: dict[str, str]) -> str:
        """Return *expression* with fact keys replaced by their stored values.

        Matching is case-insensitive and on word boundaries. All keys are
        substituted in a single pass, so a substituted value is never
        re-scanned and replaced again, and values are inserted literally
        (backslashes are not interpreted as regex replacement escapes).
        """
        # Longer keys first so a short key cannot pre-empt a longer one that
        # starts with the same text. Empty keys are skipped: they would match
        # at every word boundary.
        keys = sorted((key for key in facts if key), key=len, reverse=True)
        if not keys:
            return expression

        values_by_folded_key: dict[str, str] = {}
        for key in keys:
            values_by_folded_key.setdefault(key.lower(), str(facts[key]))

        pattern = re.compile(
            r"\b(?:" + "|".join(re.escape(key) for key in keys) + r")\b",
            flags=re.IGNORECASE,
        )
        return pattern.sub(
            lambda match: values_by_folded_key.get(match.group(0).lower(), match.group(0)),
            expression,
        )