| from __future__ import annotations |
|
|
| import ast |
| import operator |
| import re |
| from dataclasses import dataclass |
| from typing import Callable |
|
|
| from langchain_core.messages import HumanMessage, SystemMessage |
| from langchain_mistralai import ChatMistralAI |
|
|
| from memory_agent.retrieval import HybridRetriever |
| from memory_agent.storage import FaissMemoryStore |
|
|
# Acknowledgement returned by KnowledgeService.store_knowledge after a write.
UNDERSTOOD_MESSAGE = "understood"
# Reply used when stored knowledge cannot answer a question (no retrieval hits
# or an empty model response); it is also handed to the prompt template.
FALLBACK_MESSAGE = (
    "I don't have enough stored knowledge yet to answer that. Share a few "
    "facts about your domain, and I will learn and help."
)
|
|
|
|
| @dataclass(slots=True) |
| class ParsedFact: |
| key: str |
| value: str |
|
|
|
|
class FactParser:
    """Extract a key/value fact from a short natural-language statement.

    Recognised forms, tried in this order:

    * commands: ``set price to 10``, ``update rate as 0.2``,
      ``remember name = Bob``;
    * assignments: ``price = 10``, ``price: 10``, ``the price is 10``.
    """

    _ASSIGNMENT_PATTERNS = (
        # Explicit commands go first: the generic assignment pattern below also
        # matches "set price = 10" and would keep the verb in the key.
        re.compile(
            r"^\s*(?:set|update|remember)\s+(.+?)(?:\s+(?:to|as)\s+|\s*=\s*)(.+?)\s*$",
            re.IGNORECASE,
        ),
        # The key is matched lazily and "is" must be a standalone word, so the
        # split happens at the first real separator instead of at an "is"
        # buried inside a later word (e.g. "history", "Isaacs").
        re.compile(
            r"^\s*([A-Za-z][A-Za-z0-9 _-]{0,80}?)\s*(?:=|:|\bis\b)\s*(.+?)\s*$",
            re.IGNORECASE,
        ),
    )

    def parse(self, text: str) -> ParsedFact | None:
        """Return the first fact found in ``text``, or ``None``.

        A pattern only counts when both its cleaned key and cleaned value are
        non-empty; otherwise the next pattern is tried.
        """
        for pattern in self._ASSIGNMENT_PATTERNS:
            match = pattern.match(text)
            if not match:
                continue
            key = self._clean_key(match.group(1))
            value = self._clean_value(match.group(2))
            if key and value:
                return ParsedFact(key=key, value=value)
        return None

    @staticmethod
    def _clean_key(raw: str) -> str:
        """Trim periods/whitespace and drop the articles "the", "a", "an".

        If removing articles would leave nothing (e.g. a variable named
        ``a``), the key is kept as written instead of being discarded.
        """
        cleaned = re.sub(r"\s+", " ", raw.strip().strip(".")).strip()
        without_articles = re.sub(r"\b(the|a|an)\b", "", cleaned, flags=re.IGNORECASE)
        without_articles = re.sub(r"\s+", " ", without_articles).strip()
        return without_articles or cleaned

    @staticmethod
    def _clean_value(raw: str) -> str:
        """Trim whitespace and trailing periods, then unwrap one pair of quotes.

        Only *trailing* periods are removed so values such as ``.5`` keep
        their leading decimal point.
        """
        cleaned = raw.strip().rstrip(".").strip()
        if (cleaned.startswith('"') and cleaned.endswith('"')) or (
            cleaned.startswith("'") and cleaned.endswith("'")
        ):
            return cleaned[1:-1].strip()
        return cleaned
|
|
|
|
class SafeMathEvaluator:
    """Evaluate arithmetic expressions over stored facts without ``eval``.

    Only numeric literals, names resolved against the fact map, the binary
    operators in ``_BINARY_OPERATORS`` and unary ``+``/``-`` are accepted.
    Every operand is converted to ``float``, so ``**`` overflows instead of
    building huge integers. Every failure is reported as ``ValueError``.
    """

    _BINARY_OPERATORS: dict[type[ast.operator], Callable[[float, float], float]] = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.Mod: operator.mod,
        ast.FloorDiv: operator.floordiv,
    }
    _UNARY_OPERATORS: dict[type[ast.unaryop], Callable[[float], float]] = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def evaluate(self, expression: str, facts: dict[str, str]) -> float:
        """Evaluate ``expression`` using numeric values from ``facts``.

        Args:
            expression: Expression in Python syntax (``**`` for powers);
                surrounding whitespace is ignored.
            facts: Fact map; each name in the expression is looked up after
                ``FaissMemoryStore.normalize_key``.

        Returns:
            The result as a float.

        Raises:
            ValueError: If the expression cannot be parsed, uses an unsupported
                construct, references an unknown or non-numeric fact, divides
                by zero, overflows, or has no real-valued result.
        """
        try:
            tree = ast.parse(expression.strip(), mode="eval")
        except (SyntaxError, RecursionError) as error:
            raise ValueError(f"Invalid expression: {expression!r}") from error
        try:
            return float(self._evaluate_node(tree.body, facts))
        except (ZeroDivisionError, OverflowError) as error:
            raise ValueError(f"Cannot evaluate expression: {error}") from error
        except RecursionError as error:
            raise ValueError("Expression is nested too deeply.") from error

    def _evaluate_node(self, node: ast.AST, facts: dict[str, str]) -> float:
        """Recursively evaluate one AST node; unsupported nodes raise ValueError."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is an int subclass; reject True/False rather than using 1/0.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return float(value)
            raise ValueError("Unsupported expression.")

        if isinstance(node, ast.Name):
            key = FaissMemoryStore.normalize_key(node.id)
            if key not in facts:
                raise ValueError(f"Unknown variable: {node.id}")
            try:
                return float(facts[key])
            except ValueError as error:
                raise ValueError(f"Variable '{node.id}' is not numeric.") from error

        if isinstance(node, ast.BinOp):
            op_type = type(node.op)
            if op_type not in self._BINARY_OPERATORS:
                raise ValueError("Unsupported operator in expression.")
            left = self._evaluate_node(node.left, facts)
            right = self._evaluate_node(node.right, facts)
            result = self._BINARY_OPERATORS[op_type](left, right)
            # A negative base raised to a fractional power yields a complex number.
            if isinstance(result, complex):
                raise ValueError("Expression has no real-valued result.")
            return result

        if isinstance(node, ast.UnaryOp):
            op_type = type(node.op)
            if op_type not in self._UNARY_OPERATORS:
                raise ValueError("Unsupported unary operator in expression.")
            operand = self._evaluate_node(node.operand, facts)
            return self._UNARY_OPERATORS[op_type](operand)

        raise ValueError("Unsupported expression.")
|
|
|
|
class KnowledgeService:
    """Store user-provided facts and answer questions from them.

    A question containing an arithmetic expression is answered by evaluating
    it against the stored fact map. Any other question, or one whose
    expression cannot be evaluated, goes through hybrid retrieval and an LLM
    call over the retrieved context.
    """

    # Operand of an arithmetic fragment: optional unary sign, then an
    # identifier, a number, or a single (non-nested) parenthesised group.
    _OPERAND = r"(?:[+\-]\s*)?(?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))"
    # An operand followed by one or more "<operator> <operand>" pairs.
    _FRAGMENT_PATTERN = re.compile(
        rf"({_OPERAND}\s*(?:[+\-*/^%]\s*{_OPERAND}\s*)+)",
        flags=re.IGNORECASE,
    )
    # A dash right after the keyword is a separator only when followed by
    # whitespace, so "calculate -5 + 3" keeps its unary minus.
    _INLINE_PATTERN = re.compile(
        r"(?:calculate|equation|expression)\s*(?::|,|-(?=\s))?\s*([A-Za-z0-9_+\-*/^%(). ]+)",
        flags=re.IGNORECASE,
    )
    _QUOTED_PATTERN = re.compile(r'["\']([^"\']+)["\']')
    _OPERATOR_CHARS = re.compile(r"[+\-*/^%()]")
    # Anything the evaluator (or Python's parser beneath it) may raise for an
    # expression it cannot compute. Such questions fall back to retrieval
    # instead of failing the request.
    _MATH_ERRORS = (ValueError, ArithmeticError, SyntaxError, TypeError, RecursionError)

    def __init__(
        self,
        store: FaissMemoryStore,
        retriever: HybridRetriever,
        llm: ChatMistralAI,
        top_k: int,
        user_prompt_template: str,
    ) -> None:
        """Wire up the service.

        Args:
            store: Persistent knowledge store (written to and read as a fact map).
            retriever: Retriever used for non-arithmetic questions.
            llm: Chat model that phrases answers from retrieved context.
            top_k: Number of documents requested from the retriever.
            user_prompt_template: ``str.format`` template that receives
                ``fallback_message``, ``context`` and ``question``.
        """
        self._store = store
        self._retriever = retriever
        self._llm = llm
        self._top_k = top_k
        self._user_prompt_template = user_prompt_template
        self._fact_parser = FactParser()
        self._math_evaluator = SafeMathEvaluator()

    def store_knowledge(
        self,
        namespace: str,
        knowledge_text: str,
        fact_key: str | None = None,
        fact_value: str | None = None,
    ) -> str:
        """Persist ``knowledge_text`` in ``namespace`` and return UNDERSTOOD_MESSAGE.

        Explicit ``fact_key``/``fact_value`` take precedence over what the fact
        parser extracts from the text. When both a key and a value are known,
        the stored content is normalised to ``"<key> = <value>"``; otherwise
        the raw text is stored.
        """
        parsed_fact = self._fact_parser.parse(knowledge_text)
        key = fact_key or (parsed_fact.key if parsed_fact else None)
        value = fact_value or (parsed_fact.value if parsed_fact else None)
        content = f"{key} = {value}" if key and value else knowledge_text

        self._store.upsert_knowledge(
            namespace=namespace,
            content=content,
            fact_key=key,
            fact_value=value,
        )
        return UNDERSTOOD_MESSAGE

    def answer_from_knowledge(self, namespace: str, question: str) -> str:
        """Answer ``question`` using only knowledge stored in ``namespace``.

        Returns a computed result when the question holds an evaluable
        expression. Otherwise returns the LLM's answer over retrieved context,
        or FALLBACK_MESSAGE when nothing is retrieved or the model replies
        with empty text.
        """
        expression = self._extract_expression(question=question)
        if expression:
            facts = self._store.fetch_fact_map(namespace=namespace)
            try:
                value = self._math_evaluator.evaluate(expression=expression, facts=facts)
            except self._MATH_ERRORS:
                pass  # not computable from stored facts; use retrieval instead
            else:
                formatted_value = int(value) if value.is_integer() else round(value, 6)
                substituted = self._substitute_expression(expression=expression, facts=facts)
                return f"Based on stored knowledge: {substituted} = {formatted_value}."

        results = self._retriever.retrieve(namespace=namespace, query=question, k=self._top_k)
        if not results:
            return FALLBACK_MESSAGE

        context = "\n\n".join(
            f"- {doc.page_content}\n metadata={doc.metadata}"
            for doc, _ in results
        )
        prompt = self._user_prompt_template.format(
            fallback_message=FALLBACK_MESSAGE,
            context=context,
            question=question,
        )
        response = self._llm.invoke(
            [
                SystemMessage(content="You are a memory-only answer tool."),
                HumanMessage(content=prompt),
            ]
        )
        content = self._response_text(response.content)
        return content or FALLBACK_MESSAGE

    @staticmethod
    def _response_text(content: object) -> str:
        """Return the stripped text of a chat message's ``content``.

        LangChain message content may be a plain string or a list of content
        blocks (strings or dicts). Only plain strings and ``{"type": "text"}``
        blocks are kept, so a list is never rendered as its Python repr.
        """
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get("type") == "text":
                    parts.append(str(block.get("text", "")))
            return "".join(parts).strip()
        return str(content).strip()

    @classmethod
    def _extract_expression(cls, question: str) -> str | None:
        """Find an arithmetic expression in ``question`` (``^`` becomes ``**``).

        Search order: a quoted span containing an operator, then the text
        after "calculate"/"equation"/"expression", then any operand/operator
        run anywhere in the question. Returns ``None`` if nothing matches.
        """
        for candidate in cls._QUOTED_PATTERN.findall(question):
            if cls._OPERATOR_CHARS.search(candidate):
                return candidate.strip().replace("^", "**")

        inline = cls._INLINE_PATTERN.search(question)
        if inline:
            expression = inline.group(1).strip().rstrip(".")
            fragment = cls._FRAGMENT_PATTERN.search(expression)
            if fragment:
                return fragment.group(1).strip().replace("^", "**")
            if cls._OPERATOR_CHARS.search(expression):
                return expression.replace("^", "**")

        fragment = cls._FRAGMENT_PATTERN.search(question)
        if fragment:
            return fragment.group(1).strip().replace("^", "**")
        return None

    @staticmethod
    def _substitute_expression(expression: str, facts: dict[str, str]) -> str:
        """Replace fact names in ``expression`` with their stored values.

        Matching is case-insensitive on whole words, with longer keys tried
        first so "tax rate" wins over "tax". The substitution is done in a
        single pass with a callable replacement, so substituted values are
        never re-scanned and backslashes in values are inserted literally
        instead of being parsed as ``re`` template escapes.
        """
        ordered_keys = [key for key in sorted(facts, key=len, reverse=True) if key]
        if not ordered_keys:
            return expression

        replacements: dict[str, str] = {}
        for key in ordered_keys:
            replacements.setdefault(key.lower(), facts[key])

        pattern = re.compile(
            r"\b(?:" + "|".join(re.escape(key) for key in ordered_keys) + r")\b",
            flags=re.IGNORECASE,
        )
        return pattern.sub(
            lambda match: replacements.get(match.group(0).lower(), match.group(0)),
            expression,
        )
|
|