Socrates_docker / classify_memory_reference.py
AlessandroAmodioNGI's picture
feat: long-term chat history — time-based window + multi-vector FAISS db5
1775907
Raw
History Blame Contribute Delete
2.73 kB
"""
classify_memory_reference.py
Detects whether a user message is explicitly recalling a past conversation.
Only when this fires does the system query db5 (long-term FAISS history).
Returns:
{
"recalls_past": bool,
"time_hint": str | None, e.g. "last week", "in March", "a while ago"
"character_hint": str | None, e.g. "socrates", "diogenes"
}
Uses Tier C (nano) — binary classifier, minimal tokens.
"""
import json
import re
from typing import Optional
# System prompt for the classifier call. It asks the model to answer with a
# bare JSON object carrying the three keys that classify_memory_reference()
# reads back: "recalls_past", "time_hint" and "character_hint".
# NOTE: the dash after "PAST conversation" is a real em dash (U+2014); it had
# previously been mis-decoded into the three-character sequence "β€”".
_SYSTEM_PROMPT = """You are a memory-reference detector for a chat application.
Decide if the user's message explicitly tries to recall a PAST conversation —
phrases like "do you remember", "we talked about", "you told me", "last time",
"a while ago you said", "didn't we discuss", "remember when", etc.
Respond ONLY with a JSON object, no prose:
{
"recalls_past": true | false,
"time_hint": "<time expression from the message, or null>",
"character_hint": "<character name mentioned (socrates/diogenes/nietzsche/camus/schopenhauer), or null>"
}"""
def classify_memory_reference(user_msg: str) -> dict:
"""
Returns {"recalls_past": bool, "time_hint": str|None, "character_hint": str|None}.
Fails safe: returns recalls_past=False on any error.
"""
default = {"recalls_past": False, "time_hint": None, "character_hint": None}
if not user_msg or not user_msg.strip():
return default
# Fast path: skip LLM if no memory-signal words present
_SIGNALS = (
"remember", "recall", "we talked", "you told", "you said",
"last time", "last week", "last month", "a while ago", "before",
"didn't we", "we discussed", "you mentioned",
)
msg_lower = user_msg.lower()
if not any(s in msg_lower for s in _SIGNALS):
return default
try:
from llm_client import client, get_model
resp = client.chat.completions.create(
model=get_model("C"),
messages=[
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": user_msg},
],
temperature=0.0,
max_tokens=80,
)
raw = resp.choices[0].message.content.strip()
# strip markdown fences if present
raw = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw, flags=re.DOTALL).strip()
result = json.loads(raw)
return {
"recalls_past": bool(result.get("recalls_past", False)),
"time_hint": result.get("time_hint") or None,
"character_hint": (result.get("character_hint") or "").lower() or None,
}
except Exception as e:
print(f"[classify_memory_reference] error: {e}")
return default