from __future__ import annotations import ast import operator import re from dataclasses import dataclass from typing import Callable from langchain_core.messages import HumanMessage, SystemMessage from langchain_mistralai import ChatMistralAI from memory_agent.retrieval import HybridRetriever from memory_agent.storage import FaissMemoryStore FALLBACK_MESSAGE = ( "I don't have enough stored knowledge yet to answer that. " "Share a few facts about your domain, and I will learn and help." ) UNDERSTOOD_MESSAGE = "understood" @dataclass(slots=True) class ParsedFact: key: str value: str class FactParser: _ASSIGNMENT_PATTERNS = ( re.compile(r"^\s*([A-Za-z][A-Za-z0-9 _-]{0,80})\s*(?:=|:|is)\s*(.+?)\s*$", re.IGNORECASE), re.compile( r"^\s*(?:set|update|remember)\s+(.+?)\s+(?:to|as|=)\s+(.+?)\s*$", re.IGNORECASE, ), ) def parse(self, text: str) -> ParsedFact | None: for pattern in self._ASSIGNMENT_PATTERNS: match = pattern.match(text) if not match: continue key = self._clean_key(match.group(1)) value = self._clean_value(match.group(2)) if key and value: return ParsedFact(key=key, value=value) return None @staticmethod def _clean_key(raw: str) -> str: cleaned = raw.strip().strip(".") cleaned = re.sub(r"\b(the|a|an)\b", "", cleaned, flags=re.IGNORECASE) cleaned = re.sub(r"\s+", " ", cleaned) return cleaned.strip() @staticmethod def _clean_value(raw: str) -> str: cleaned = raw.strip().strip(".").strip() if (cleaned.startswith('"') and cleaned.endswith('"')) or ( cleaned.startswith("'") and cleaned.endswith("'") ): return cleaned[1:-1].strip() return cleaned class SafeMathEvaluator: _BINARY_OPERATORS: dict[type[ast.operator], Callable[[float, float], float]] = { ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv, ast.Pow: operator.pow, ast.Mod: operator.mod, ast.FloorDiv: operator.floordiv, } _UNARY_OPERATORS: dict[type[ast.unaryop], Callable[[float], float]] = { ast.UAdd: operator.pos, ast.USub: operator.neg, } def evaluate(self, expression: str, facts: dict[str, str]) -> float: tree = ast.parse(expression, mode="eval") return float(self._evaluate_node(tree.body, facts)) def _evaluate_node(self, node: ast.AST, facts: dict[str, str]) -> float: if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): return float(node.value) if isinstance(node, ast.Name): key = FaissMemoryStore.normalize_key(node.id) if key not in facts: raise ValueError(f"Unknown variable: {node.id}") try: return float(facts[key]) except ValueError as error: raise ValueError(f"Variable '{node.id}' is not numeric.") from error if isinstance(node, ast.BinOp): op_type = type(node.op) if op_type not in self._BINARY_OPERATORS: raise ValueError("Unsupported operator in expression.") left = self._evaluate_node(node.left, facts) right = self._evaluate_node(node.right, facts) return self._BINARY_OPERATORS[op_type](left, right) if isinstance(node, ast.UnaryOp): op_type = type(node.op) if op_type not in self._UNARY_OPERATORS: raise ValueError("Unsupported unary operator in expression.") operand = self._evaluate_node(node.operand, facts) return self._UNARY_OPERATORS[op_type](operand) raise ValueError("Unsupported expression.") class KnowledgeService: def __init__( self, store: FaissMemoryStore, retriever: HybridRetriever, llm: ChatMistralAI, top_k: int, user_prompt_template: str, ) -> None: self._store = store self._retriever = retriever self._llm = llm self._top_k = top_k self._user_prompt_template = user_prompt_template self._fact_parser = FactParser() self._math_evaluator = SafeMathEvaluator() def store_knowledge( self, namespace: str, knowledge_text: str, fact_key: str | None = None, fact_value: str | None = None, ) -> str: parsed_fact = self._fact_parser.parse(knowledge_text) key = fact_key or (parsed_fact.key if parsed_fact else None) value = fact_value or (parsed_fact.value if parsed_fact else None) content = knowledge_text if key and value: content = f"{key} = {value}" self._store.upsert_knowledge( namespace=namespace, content=content, fact_key=key, fact_value=value, ) return UNDERSTOOD_MESSAGE def answer_from_knowledge(self, namespace: str, question: str) -> str: facts = self._store.fetch_fact_map(namespace=namespace) expression = self._extract_expression(question=question) if expression: try: value = self._math_evaluator.evaluate(expression=expression, facts=facts) formatted_value = int(value) if value.is_integer() else round(value, 6) substituted = self._substitute_expression(expression=expression, facts=facts) return f"Based on stored knowledge: {substituted} = {formatted_value}." except ValueError: pass results = self._retriever.retrieve(namespace=namespace, query=question, k=self._top_k) if not results: return FALLBACK_MESSAGE context = "\n\n".join( f"- {doc.page_content}\n metadata={doc.metadata}" for doc, _ in results ) prompt = self._user_prompt_template.format( fallback_message=FALLBACK_MESSAGE, context=context, question=question, ) response = self._llm.invoke( [ SystemMessage(content="You are a memory-only answer tool."), HumanMessage(content=prompt), ] ) content = str(response.content).strip() if not content: return FALLBACK_MESSAGE return content @staticmethod def _extract_expression(question: str) -> str | None: quoted = re.findall(r'["\']([^"\']+)["\']', question) for candidate in quoted: if re.search(r"[+\-*/^%()]", candidate): return candidate.replace("^", "**") inline = re.search( r"(?:calculate|equation|expression)\s*[:\-,]?\s*([A-Za-z0-9_+\-*/^%(). ]+)", question, flags=re.IGNORECASE, ) if inline: expression = inline.group(1).strip().rstrip(".") fragment = re.search( r"((?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*" r"(?:[+\-*/^%]\s*(?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*)+)", expression, flags=re.IGNORECASE, ) if fragment: return fragment.group(1).strip().replace("^", "**") if re.search(r"[+\-*/^%()]", expression): return expression.replace("^", "**") fragment = re.search( r"((?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*" r"(?:[+\-*/^%]\s*(?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*)+)", question, flags=re.IGNORECASE, ) if fragment: return fragment.group(1).strip().replace("^", "**") return None @staticmethod def _substitute_expression(expression: str, facts: dict[str, str]) -> str: substituted = expression # Replace longer keys first so short keys do not partially replace longer names. for key in sorted(facts.keys(), key=len, reverse=True): value = facts[key] pattern = re.compile(rf"\b{re.escape(key)}\b", flags=re.IGNORECASE) substituted = pattern.sub(value, substituted) return substituted