# Marik1337 - "Add application file" (6059138)
from __future__ import annotations
import ast
import operator
import re
from dataclasses import dataclass
from typing import Callable
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_mistralai import ChatMistralAI
from memory_agent.retrieval import HybridRetriever
from memory_agent.storage import FaissMemoryStore
# Reply returned when no relevant knowledge is retrieved or the model's answer
# is empty; it is also handed to the prompt template as ``fallback_message``.
FALLBACK_MESSAGE = (
    "I don't have enough stored knowledge yet to answer that. Share a few facts "
    "about your domain, and I will learn and help."
)
# Acknowledgement returned after a piece of knowledge has been stored.
UNDERSTOOD_MESSAGE = "understood"
@dataclass(slots=True)
class ParsedFact:
key: str
value: str
class FactParser:
    """Extract one ``key -> value`` fact from a short natural-language statement.

    Two phrasings are recognised, tried in this order:

    * assignment style: ``"<key> = <value>"``, ``"<key>: <value>"`` or
      ``"<key> is <value>"`` (the key starts with a letter and may contain
      letters, digits, spaces, ``_`` and ``-``);
    * command style: ``"set|update|remember <key> to|as|= <value>"``.
    """

    _ASSIGNMENT_PATTERNS = (
        # ``is`` must be a standalone word. Without the \b anchors it matched
        # inside words, e.g. "Distribution center ships daily" was parsed as
        # the fact "D = tribution center ships daily".
        re.compile(
            r"^\s*([A-Za-z][A-Za-z0-9 _-]{0,80})\s*(?:=|:|\bis\b)\s*(.+?)\s*$",
            re.IGNORECASE,
        ),
        re.compile(
            r"^\s*(?:set|update|remember)\s+(.+?)\s+(?:to|as|=)\s+(.+?)\s*$",
            re.IGNORECASE,
        ),
    )

    def parse(self, text: str) -> ParsedFact | None:
        """Return the first fact found in ``text``, or ``None`` if there is none.

        Patterns are tried in order; a match whose cleaned key or value comes
        out empty is skipped in favour of the next pattern.
        """
        for pattern in self._ASSIGNMENT_PATTERNS:
            match = pattern.match(text)
            if not match:
                continue
            key = self._clean_key(match.group(1))
            value = self._clean_value(match.group(2))
            if key and value:
                return ParsedFact(key=key, value=value)
        return None

    @staticmethod
    def _clean_key(raw: str) -> str:
        """Normalise a raw key: trim whitespace/dots, drop articles, collapse spaces.

        If the key consists only of articles (e.g. the variable ``a`` in
        ``"a = 5"``), removing them would leave nothing and the fact would be
        discarded, so the trimmed key is kept unchanged in that case.
        """
        stripped = raw.strip().strip(".")
        cleaned = re.sub(r"\b(the|a|an)\b", "", stripped, flags=re.IGNORECASE)
        cleaned = re.sub(r"\s+", " ", cleaned).strip()
        if cleaned:
            return cleaned
        return re.sub(r"\s+", " ", stripped).strip()

    @staticmethod
    def _clean_value(raw: str) -> str:
        """Trim whitespace and dots, then remove one pair of matching outer quotes."""
        cleaned = raw.strip().strip(".").strip()
        if (cleaned.startswith('"') and cleaned.endswith('"')) or (
            cleaned.startswith("'") and cleaned.endswith("'")
        ):
            return cleaned[1:-1].strip()
        return cleaned
class SafeMathEvaluator:
    """Evaluate arithmetic expressions over stored facts without using ``eval``.

    The expression is parsed with :mod:`ast` and walked node by node. Only
    numeric literals, names (looked up in the fact map), the binary operators
    in ``_BINARY_OPERATORS`` and unary ``+``/``-`` are accepted; every other
    construct is rejected.
    """

    _BINARY_OPERATORS: dict[type[ast.operator], Callable[[float, float], float]] = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.Mod: operator.mod,
        ast.FloorDiv: operator.floordiv,
    }
    _UNARY_OPERATORS: dict[type[ast.unaryop], Callable[[float], float]] = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
    }

    def evaluate(self, expression: str, facts: dict[str, str]) -> float:
        """Evaluate ``expression`` and return the result as a float.

        Args:
            expression: Python-syntax arithmetic (``**`` is exponentiation).
            facts: Fact map whose keys are compared against
                ``FaissMemoryStore.normalize_key(name)`` for each name used.

        Raises:
            ValueError: For every expression that cannot be evaluated, including
                malformed syntax, unsupported constructs, unknown or non-numeric
                variables, division by zero, overflow and non-real results. The
                caller relies on this single exception type to fall back
                gracefully.
        """
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as error:
            raise ValueError(f"Invalid expression: {expression!r}") from error
        try:
            return float(self._evaluate_node(tree.body, facts))
        except ArithmeticError as error:
            # ZeroDivisionError (x / 0, x % 0, 0 ** -1) and OverflowError
            # (float overflow in **, or a huge integer literal in float()).
            raise ValueError(f"Cannot evaluate expression: {error}") from error
        except RecursionError as error:
            raise ValueError("Expression is nested too deeply.") from error

    def _evaluate_node(self, node: ast.AST, facts: dict[str, str]) -> float:
        """Recursively evaluate one AST node; raise ValueError if it is unsupported."""
        if isinstance(node, ast.Constant):
            # bool is a subclass of int; reject True/False instead of silently
            # treating them as 1 and 0.
            if isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
                return float(node.value)
            raise ValueError("Unsupported expression.")
        if isinstance(node, ast.Name):
            key = FaissMemoryStore.normalize_key(node.id)
            if key not in facts:
                raise ValueError(f"Unknown variable: {node.id}")
            try:
                return float(facts[key])
            except (TypeError, ValueError) as error:
                raise ValueError(f"Variable '{node.id}' is not numeric.") from error
        if isinstance(node, ast.BinOp):
            op_type = type(node.op)
            if op_type not in self._BINARY_OPERATORS:
                raise ValueError("Unsupported operator in expression.")
            left = self._evaluate_node(node.left, facts)
            right = self._evaluate_node(node.right, facts)
            result = self._BINARY_OPERATORS[op_type](left, right)
            if isinstance(result, complex):
                # A negative base with a fractional exponent, e.g. (-8) ** 0.5,
                # yields a complex number that float() cannot represent.
                raise ValueError("Expression does not have a real-valued result.")
            return result
        if isinstance(node, ast.UnaryOp):
            op_type = type(node.op)
            if op_type not in self._UNARY_OPERATORS:
                raise ValueError("Unsupported unary operator in expression.")
            operand = self._evaluate_node(node.operand, facts)
            return self._UNARY_OPERATORS[op_type](operand)
        raise ValueError("Unsupported expression.")
class KnowledgeService:
    """Store user knowledge per namespace and answer questions from it.

    Questions containing an arithmetic expression are answered by evaluating
    the expression against the namespace's stored facts. Every other question,
    and every expression that cannot be evaluated, goes through hybrid
    retrieval followed by an LLM call. ``FALLBACK_MESSAGE`` is returned when
    nothing is retrieved or the model replies with empty text.
    """

    # An operand is an identifier, a number with optional decimals, or a
    # simple parenthesised group. A fragment is an operand followed by one or
    # more "<operator> <operand>" pairs.
    _EXPRESSION_FRAGMENT = re.compile(
        r"((?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*"
        r"(?:[+\-*/^%]\s*(?:[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|\([^)]+\))\s*)+)",
        flags=re.IGNORECASE,
    )
    _OPERATOR_CHARS = re.compile(r"[+\-*/^%()]")
    _QUOTED_TEXT = re.compile(r'["\']([^"\']+)["\']')
    _INLINE_EXPRESSION = re.compile(
        r"(?:calculate|equation|expression)\s*[:\-,]?\s*([A-Za-z0-9_+\-*/^%(). ]+)",
        flags=re.IGNORECASE,
    )
    # Identifier-shaped tokens. The leading \b keeps number suffixes such as
    # the "e5" in "1e5" from being mistaken for names.
    _IDENTIFIER = re.compile(r"\b[^\W\d]\w*")

    def __init__(
        self,
        store: FaissMemoryStore,
        retriever: HybridRetriever,
        llm: ChatMistralAI,
        top_k: int,
        user_prompt_template: str,
    ) -> None:
        """Wire the service to its collaborators.

        Args:
            store: Persists knowledge entries and serves the per-namespace fact map.
            retriever: Finds stored entries relevant to a question.
            llm: Chat model that writes answers from the retrieved context.
            top_k: Number of entries requested from ``retriever`` per question.
            user_prompt_template: ``str.format`` template for the LLM prompt.
                It receives the ``fallback_message``, ``context`` and
                ``question`` fields.
        """
        self._store = store
        self._retriever = retriever
        self._llm = llm
        self._top_k = top_k
        self._user_prompt_template = user_prompt_template
        self._fact_parser = FactParser()
        self._math_evaluator = SafeMathEvaluator()

    def store_knowledge(
        self,
        namespace: str,
        knowledge_text: str,
        fact_key: str | None = None,
        fact_value: str | None = None,
    ) -> str:
        """Persist ``knowledge_text`` in ``namespace`` and acknowledge it.

        The key and value come from the explicit arguments when given and
        non-empty; each missing one is taken from ``knowledge_text`` parsed by
        ``FactParser``. When both are available, the stored content is
        normalised to ``"<key> = <value>"``; otherwise the raw text is stored.

        Returns:
            ``UNDERSTOOD_MESSAGE``.
        """
        parsed_fact = self._fact_parser.parse(knowledge_text)
        key = fact_key or (parsed_fact.key if parsed_fact else None)
        value = fact_value or (parsed_fact.value if parsed_fact else None)
        content = f"{key} = {value}" if key and value else knowledge_text
        # NOTE(review): a key without a value (or vice versa) is still
        # forwarded as fact_key/fact_value; confirm the store accepts that.
        self._store.upsert_knowledge(
            namespace=namespace,
            content=content,
            fact_key=key,
            fact_value=value,
        )
        return UNDERSTOOD_MESSAGE

    def answer_from_knowledge(self, namespace: str, question: str) -> str:
        """Answer ``question`` using only knowledge stored in ``namespace``.

        Returns:
            One of three things:
            * a ``"Based on stored knowledge: ..."`` line, when the question
              holds an expression that evaluates against the stored facts;
            * otherwise, the LLM's answer over the retrieved context;
            * ``FALLBACK_MESSAGE``, when nothing is retrieved or the reply is
              empty.
        """
        facts = self._store.fetch_fact_map(namespace=namespace)
        computed = self._answer_by_calculation(question=question, facts=facts)
        if computed is not None:
            return computed
        return self._answer_by_retrieval(namespace=namespace, question=question)

    def _answer_by_calculation(self, question: str, facts: dict[str, str]) -> str | None:
        """Evaluate an expression found in ``question``.

        Returns ``None`` when the question contains no expression or the
        expression cannot be evaluated.
        """
        expression = self._extract_expression(question=question)
        if not expression:
            return None
        try:
            value = self._math_evaluator.evaluate(expression=expression, facts=facts)
        except (ArithmeticError, RecursionError, SyntaxError, TypeError, ValueError):
            # Extraction is heuristic: dates such as "12/05/2024" or phone
            # numbers also look like expressions. Any evaluation failure hands
            # the question to retrieval instead of crashing.
            return None
        formatted_value = int(value) if value.is_integer() else round(value, 6)
        substituted = self._substitute_expression(expression=expression, facts=facts)
        return f"Based on stored knowledge: {substituted} = {formatted_value}."

    def _answer_by_retrieval(self, namespace: str, question: str) -> str:
        """Answer ``question`` with the LLM, grounded in the top-k retrieved entries."""
        results = self._retriever.retrieve(namespace=namespace, query=question, k=self._top_k)
        if not results:
            return FALLBACK_MESSAGE
        # Results are unpacked as (document, _) pairs; only the document is used.
        context = "\n\n".join(
            f"- {doc.page_content}\n metadata={doc.metadata}"
            for doc, _ in results
        )
        prompt = self._user_prompt_template.format(
            fallback_message=FALLBACK_MESSAGE,
            context=context,
            question=question,
        )
        response = self._llm.invoke(
            [
                SystemMessage(content="You are a memory-only answer tool."),
                HumanMessage(content=prompt),
            ]
        )
        content = str(response.content).strip()
        return content or FALLBACK_MESSAGE

    @classmethod
    def _extract_expression(cls, question: str) -> str | None:
        """Pull an arithmetic expression out of ``question``, or return ``None``.

        Strategies are tried in order, and the first hit wins:
        1. a quoted span that contains an operator or parenthesis;
        2. the text after "calculate", "equation" or "expression": its first
           operand/operator fragment, or else the whole tail if that tail
           contains an operator or parenthesis;
        3. the first operand/operator fragment anywhere in the question.

        ``^`` is rewritten to ``**`` so that it means exponentiation.
        """
        for candidate in cls._QUOTED_TEXT.findall(question):
            if cls._OPERATOR_CHARS.search(candidate):
                return candidate.replace("^", "**")

        inline = cls._INLINE_EXPRESSION.search(question)
        if inline:
            expression = inline.group(1).strip().rstrip(".")
            fragment = cls._EXPRESSION_FRAGMENT.search(expression)
            if fragment:
                return fragment.group(1).strip().replace("^", "**")
            if cls._OPERATOR_CHARS.search(expression):
                return expression.replace("^", "**")

        fragment = cls._EXPRESSION_FRAGMENT.search(question)
        if fragment:
            return fragment.group(1).strip().replace("^", "**")
        return None

    @classmethod
    def _substitute_expression(cls, expression: str, facts: dict[str, str]) -> str:
        """Return ``expression`` with each fact-backed identifier replaced by its value.

        Each identifier is looked up through ``FaissMemoryStore.normalize_key``,
        the same lookup the math evaluator uses, so exactly the evaluated names
        are substituted. The expression is scanned once, so an inserted value
        is never re-scanned for further keys.
        """
        if not facts:
            return expression

        def _replace(match: re.Match[str]) -> str:
            name = match.group(0)
            # A callback (not a replacement string) inserts the value verbatim.
            # re.sub would parse a string value as a template, and a value
            # containing a backslash, such as a Windows path, raised re.error.
            return facts.get(FaissMemoryStore.normalize_key(name), name)

        return cls._IDENTIFIER.sub(_replace, expression)