| """Session-aware memory retrieval for text generation.""" |
|
|
from __future__ import annotations

import json
import logging
import os
import re
from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from threading import RLock
from typing import Any

from maris_core.utils.env import get_env_any
|
|
| logger = logging.getLogger(__name__) |
|
|
| _TOKEN_RE = re.compile(r"\w+", flags=re.UNICODE) |
| _TRIGRAM_WINDOW = 3 |
| _GENERIC_MEMORY_TOKENS = { |
| "tas", |
| "tad", |
| "par", |
| "vai", |
| "bet", |
| "kur", |
| "kad", |
| "man", |
| "tev", |
| "jau", |
| "ari", |
| "arī", |
| "nav", |
| "bija", |
| "būt", |
| "this", |
| "that", |
| "with", |
| "from", |
| } |
| _SOURCE_BONUS = { |
| "live": 0.16, |
| "history": 0.12, |
| "vision_context": 0.14, |
| "voice_stt": 0.11, |
| "voice_tts": 0.11, |
| "autonomous_goal": 0.14, |
| } |
| |
| |
| _USER_FOCUS_MARKERS = ( |
| "es gribu", |
| "es vēlos", |
| "man vajag", |
| "mans mērķis", |
| "man svarīgi", |
| "esmu", |
| "strādāju", |
| "būvēju", |
| "veidoju", |
| "mēs būvējam", |
| "mēs veidojam", |
| "i want", |
| "i need", |
| "my goal", |
| "important to me", |
| "we are building", |
| ) |
| _USER_GOAL_MARKERS = ("gribu", "vēlos", "vajag", "mērķ", "want", "need") |
| _USER_PREFERENCE_MARKERS = ("svarīgi", "prefer", "patīk", "important") |
| _USER_FOCUS_QUERY_OVERLAP_WEIGHT = 0.55 |
| _USER_FOCUS_MARKER_BONUS = 0.24 |
| _USER_FOCUS_RECENCY_WEIGHT = 0.2 |
| _ACTIVE_THREAD_MAX_WORDS = 18 |
| |
| |
| _ACTIVE_THREAD_MARKERS = ( |
| "?", |
| "palīdzi", |
| "izveido", |
| "uztaisi", |
| "turpin", |
| "nākam", |
| "priorit", |
| "vajag", |
| "need", |
| "help", |
| "next step", |
| "continue", |
| ) |
| _ACTIVE_THREAD_QUERY_OVERLAP_WEIGHT = 0.55 |
| _ACTIVE_THREAD_MARKER_BONUS = 0.2 |
| _ACTIVE_THREAD_RECENCY_WEIGHT = 0.25 |
| _CONTINUATION_QUERY_BONUS = 0.24 |
| _CONTINUATION_QUERY_MARKERS = ( |
| "turpin", |
| "iepriekš", |
| "šo pašu", |
| "šajā pašā", |
| "nākam", |
| "same context", |
| "same thread", |
| "continue", |
| "pick up", |
| "previous context", |
| ) |
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class MemoryMatch: |
| role: str |
| content: str |
| score: float |
| source: str |
|
|
|
|
| class ConversationMemoryStore: |
| """In-memory conversational memory with lightweight relevance scoring.""" |
|
|
| def __init__(self, max_entries_per_session: int = 120, storage_path: str | None = None) -> None: |
| self._max_entries_per_session = max_entries_per_session |
| |
| |
| self._lock = RLock() |
| self._sessions: defaultdict[str, deque[dict[str, Any]]] = defaultdict( |
| lambda: deque(maxlen=max_entries_per_session) |
| ) |
| self._global: deque[dict[str, Any]] = deque(maxlen=max_entries_per_session * 4) |
| self._storage_path = ( |
| Path(storage_path.strip()).expanduser() |
| if storage_path and storage_path.strip() |
| else None |
| ) |
| self._load_from_disk() |
|
|
| def remember_message( |
| self, |
| session_id: str, |
| role: str, |
| content: str, |
| *, |
| source: str = "live", |
| ) -> None: |
| normalized_role = role.strip().lower() |
| normalized_content = content.strip() |
| if normalized_role not in {"user", "assistant"} or not normalized_content: |
| return |
|
|
| normalized_session_id = session_id.strip() or "default" |
| entry = { |
| "role": normalized_role, |
| "content": normalized_content, |
| "timestamp": datetime.now(tz=UTC).isoformat(), |
| "source": source, |
| } |
| with self._lock: |
| if self._is_duplicate(self._sessions[normalized_session_id], entry): |
| return |
| self._sessions[normalized_session_id].append(entry) |
| self._global.append({**entry, "session_id": normalized_session_id}) |
| self._persist_to_disk() |
|
|
| def seed_history(self, session_id: str, history: list[dict[str, str]]) -> None: |
| for item in history: |
| self.remember_message( |
| session_id, |
| item.get("role", ""), |
| item.get("content", ""), |
| source="history", |
| ) |
|
|
| def retrieve_relevant_context( |
| self, session_id: str, query: str, *, limit: int = 4 |
| ) -> list[MemoryMatch]: |
| query_token_sequence = _extract_semantic_token_sequence(query) |
| query_tokens = set(query_token_sequence) |
| query_text = query.strip().lower() |
| continuation_query = _looks_like_continuation_query(query_text) |
| if not query_tokens and not continuation_query: |
| return [] |
|
|
| normalized_session_id = session_id.strip() or "default" |
| query_phrases = _extract_phrases(query_token_sequence) |
| query_trigrams = _char_trigrams(query_text) |
| with self._lock: |
| candidates = list(self._sessions.get(normalized_session_id, ())) + list(self._global) |
| ranked: list[MemoryMatch] = [] |
| seen: set[tuple[str, str]] = set() |
| total = max(len(candidates), 1) |
|
|
| for index, candidate in enumerate(candidates): |
| content = str(candidate.get("content", "")).strip() |
| role = str(candidate.get("role", "")).strip().lower() |
| if not content or content == query or role not in {"user", "assistant"}: |
| continue |
|
|
| content_text = content.lower() |
| candidate_token_sequence = _extract_semantic_token_sequence(content_text) |
| candidate_tokens = set(candidate_token_sequence) |
| if not candidate_tokens: |
| continue |
|
|
| overlap = _jaccard_similarity(query_tokens, candidate_tokens) |
| phrase_overlap = _jaccard_similarity( |
| query_phrases, |
| _extract_phrases(candidate_token_sequence), |
| ) |
| trigram_overlap = _jaccard_similarity(query_trigrams, _char_trigrams(content_text)) |
| if ( |
| overlap <= 0 |
| and phrase_overlap <= 0 |
| and trigram_overlap < 0.12 |
| and query_text not in content_text |
| and content_text not in query_text |
| and not continuation_query |
| ): |
| continue |
|
|
| recency_bonus = (index + 1) / total * 0.22 |
| session_bonus = ( |
| 0.24 |
| if candidate.get("session_id", normalized_session_id) == normalized_session_id |
| else 0.04 |
| ) |
| substring_bonus = 0.18 if query_text in content_text else 0.0 |
| source_bonus = _SOURCE_BONUS.get(str(candidate.get("source", "memory")), 0.08) |
| continuation_bonus = ( |
| _CONTINUATION_QUERY_BONUS |
| if continuation_query |
| and candidate.get("session_id", normalized_session_id) == normalized_session_id |
| else 0.0 |
| ) |
| score = ( |
| overlap * 0.55 |
| + phrase_overlap * 0.25 |
| + trigram_overlap * 0.20 |
| + recency_bonus |
| + session_bonus |
| + substring_bonus |
| + source_bonus |
| + continuation_bonus |
| ) |
|
|
| dedupe_key = (role, content) |
| if dedupe_key in seen: |
| continue |
| seen.add(dedupe_key) |
| ranked.append( |
| MemoryMatch( |
| role=role, |
| content=content, |
| score=score, |
| source=str(candidate.get("source", "memory")), |
| ) |
| ) |
|
|
| ranked.sort(key=lambda item: item.score, reverse=True) |
| return ranked[:limit] |
|
|
| def summarize_session(self, session_id: str, *, limit: int = 4) -> list[str]: |
| normalized_session_id = session_id.strip() or "default" |
| with self._lock: |
| entries = list(self._sessions.get(normalized_session_id, ())) |
|
|
| if not entries: |
| return [] |
|
|
| summaries: list[str] = [] |
| seen: set[str] = set() |
| recent_entries = reversed(entries[-min(len(entries), limit * 3) :]) |
| for entry in recent_entries: |
| content = str(entry.get("content", "")).strip() |
| if not content: |
| continue |
| concise = _compact_summary_text(content) |
| if not concise: |
| continue |
| lowered = concise.lower() |
| if lowered in seen: |
| continue |
| seen.add(lowered) |
| role = str(entry.get("role", "")).strip().lower() or "assistant" |
| prefix = "Lietotājs" if role == "user" else "Maris" |
| summaries.append(f"{prefix}: {concise}") |
| if len(summaries) >= limit: |
| break |
|
|
| summaries.reverse() |
| return summaries |
|
|
| def summarize_user_focus( |
| self, |
| session_id: str, |
| *, |
| query: str = "", |
| limit: int = 4, |
| ) -> list[str]: |
| normalized_session_id = session_id.strip() or "default" |
| query_tokens = set(_extract_semantic_token_sequence(query)) |
| with self._lock: |
| entries = list(self._sessions.get(normalized_session_id, ())) |
|
|
| if not entries: |
| return [] |
|
|
| candidates: list[tuple[float, int, str]] = [] |
| total = len(entries) |
| for index, entry in enumerate(entries): |
| if str(entry.get("role", "")).strip().lower() != "user": |
| continue |
| content = str(entry.get("content", "")).strip() |
| if not content: |
| continue |
| for candidate in _extract_user_focus_candidates(content): |
| candidate_tokens = set(_extract_semantic_token_sequence(candidate)) |
| overlap = ( |
| _jaccard_similarity(query_tokens, candidate_tokens) if query_tokens else 0.0 |
| ) |
| marker_bonus = ( |
| _USER_FOCUS_MARKER_BONUS if _looks_like_user_focus(candidate) else 0.0 |
| ) |
| recency_bonus = ((index + 1) / max(total, 1)) * _USER_FOCUS_RECENCY_WEIGHT |
| score = overlap * _USER_FOCUS_QUERY_OVERLAP_WEIGHT + marker_bonus + recency_bonus |
| candidates.append((score, index, candidate)) |
|
|
| if not candidates: |
| return [] |
|
|
| candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) |
| summaries: list[str] = [] |
| seen: set[str] = set() |
| for _score, _index, candidate in candidates: |
| lowered = candidate.lower() |
| if lowered in seen: |
| continue |
| seen.add(lowered) |
| label = _classify_user_focus(candidate, lowered=lowered) |
| summaries.append(f"{label}: {candidate}") |
| if len(summaries) >= limit: |
| break |
| return summaries |
|
|
| def summarize_active_threads( |
| self, |
| session_id: str, |
| *, |
| query: str = "", |
| limit: int = 3, |
| ) -> list[str]: |
| normalized_session_id = session_id.strip() or "default" |
| query_tokens = set(_extract_semantic_token_sequence(query)) |
| with self._lock: |
| entries = list(self._sessions.get(normalized_session_id, ())) |
|
|
| if not entries: |
| return [] |
|
|
| candidates: list[tuple[float, int, str]] = [] |
| total = len(entries) |
| for index, entry in enumerate(entries): |
| if str(entry.get("role", "")).strip().lower() != "user": |
| continue |
| content = str(entry.get("content", "")).strip() |
| if not content: |
| continue |
| for candidate in _extract_active_thread_candidates(content): |
| lowered = candidate.lower() |
| candidate_tokens = set(_extract_semantic_token_sequence(candidate)) |
| overlap = ( |
| _jaccard_similarity(query_tokens, candidate_tokens) if query_tokens else 0.0 |
| ) |
| marker_bonus = ( |
| _ACTIVE_THREAD_MARKER_BONUS |
| if _looks_like_active_thread(candidate, lowered=lowered) |
| else 0.0 |
| ) |
| recency_bonus = ((index + 1) / max(total, 1)) * _ACTIVE_THREAD_RECENCY_WEIGHT |
| score = overlap * _ACTIVE_THREAD_QUERY_OVERLAP_WEIGHT + marker_bonus + recency_bonus |
| candidates.append((score, index, candidate)) |
|
|
| if not candidates: |
| return [] |
|
|
| candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) |
| summaries: list[str] = [] |
| seen: set[str] = set() |
| for _score, _index, candidate in candidates: |
| lowered = candidate.lower() |
| if lowered in seen: |
| continue |
| seen.add(lowered) |
| label = _classify_active_thread(candidate, lowered=lowered) |
| summaries.append(f"{label}: {candidate}") |
| if len(summaries) >= limit: |
| break |
| return summaries |
|
|
| def clear(self) -> None: |
| with self._lock: |
| self._sessions.clear() |
| self._global.clear() |
| self._persist_to_disk() |
|
|
| @staticmethod |
| def _is_duplicate(buffer: deque[dict[str, Any]], entry: dict[str, Any]) -> bool: |
| if not buffer: |
| return False |
| latest = buffer[-1] |
| return latest.get("role") == entry["role"] and latest.get("content") == entry["content"] |
|
|
| def _load_from_disk(self) -> None: |
| if self._storage_path is None or not self._storage_path.exists(): |
| return |
|
|
| try: |
| payload = json.loads(self._storage_path.read_text(encoding="utf-8")) |
| except Exception as exc: |
| logger.warning("Neizdevās ielādēt sarunu atmiņu no %s: %s", self._storage_path, exc) |
| return |
|
|
| sessions = payload.get("sessions", {}) |
| if not isinstance(sessions, dict): |
| return |
|
|
| for session_id, entries in sessions.items(): |
| if not isinstance(session_id, str) or not isinstance(entries, list): |
| continue |
| for entry in entries: |
| if not isinstance(entry, dict): |
| continue |
| role = str(entry.get("role", "")).strip().lower() |
| content = str(entry.get("content", "")).strip() |
| timestamp = ( |
| str(entry.get("timestamp", "")).strip() or datetime.now(tz=UTC).isoformat() |
| ) |
| source = str(entry.get("source", "disk")).strip() or "disk" |
| if role not in {"user", "assistant"} or not content: |
| continue |
| normalized_entry = { |
| "role": role, |
| "content": content, |
| "timestamp": timestamp, |
| "source": source, |
| } |
| if self._is_duplicate(self._sessions[session_id], normalized_entry): |
| continue |
| self._sessions[session_id].append(normalized_entry) |
| self._global.append({**normalized_entry, "session_id": session_id}) |
|
|
| def _persist_to_disk(self) -> None: |
| if self._storage_path is None: |
| return |
|
|
| try: |
| self._storage_path.parent.mkdir(parents=True, exist_ok=True) |
| payload = { |
| "sessions": { |
| session_id: list(entries) for session_id, entries in self._sessions.items() |
| } |
| } |
| tmp_path = self._storage_path.with_name(f"{self._storage_path.name}.tmp") |
| tmp_path.write_text(json.dumps(payload, ensure_ascii=False), encoding="utf-8") |
| os.replace(tmp_path, self._storage_path) |
| except Exception as exc: |
| logger.warning("Neizdevās saglabāt sarunu atmiņu uz %s: %s", self._storage_path, exc) |
|
|
|
|
# Shared module-level store. Constructing it reads the configured storage file
# if one exists. The path comes from MARIS_MEMORY_STORE_PATH or
# MARIS_CONVERSATION_MEMORY_PATH, falling back to the default below. The
# precedence between the two variables is presumably first-set-wins; confirm
# in maris_core.utils.env.get_env_any.
memory_store = ConversationMemoryStore(
    storage_path=get_env_any(
        "MARIS_MEMORY_STORE_PATH",
        "MARIS_CONVERSATION_MEMORY_PATH",
        default="~/.maris/conversation-memory.json",
    ),
)
|
|
|
|
| def _jaccard_similarity(set_a: set[str], set_b: set[str]) -> float: |
| union = set_a | set_b |
| if not union: |
| return 0.0 |
| return len(set_a & set_b) / len(union) |
|
|
|
|
def _extract_semantic_tokens(text: str) -> set[str]:
    """Return the distinct normalized tokens of *text* (order and repeats discarded)."""
    sequence = _extract_semantic_token_sequence(text)
    return {*sequence}
|
|
|
|
def _extract_semantic_token_sequence(text: str) -> list[str]:
    """Return the normalized content tokens of *text* in their original order.

    Word runs of the lower-cased text are passed through ``_normalize_token``.
    Results shorter than three characters or listed in
    ``_GENERIC_MEMORY_TOKENS`` are dropped; repeats are kept.
    """
    normalized = (_normalize_token(word) for word in _TOKEN_RE.findall(text.lower()))
    return [
        token
        for token in normalized
        if len(token) >= 3 and token not in _GENERIC_MEMORY_TOKENS
    ]
|
|
|
|
| def _normalize_token(token: str) -> str: |
| normalized = token.lower().strip("-_ ") |
| for suffix in ( |
| "ajiem", |
| "ajām", |
| "ajai", |
| "ajos", |
| "ajās", |
| "ības", |
| "iem", |
| "ām", |
| "ais", |
| "ajā", |
| "ing", |
| "ers", |
| "ies", |
| "us", |
| "as", |
| "es", |
| "am", |
| "em", |
| "ai", |
| "ei", |
| "u", |
| "a", |
| "i", |
| "s", |
| ): |
| if normalized.endswith(suffix) and len(normalized) - len(suffix) >= 4: |
| return normalized[: -len(suffix)] |
| return normalized |
|
|
|
|
| def _extract_phrases(tokens: list[str]) -> set[str]: |
| if len(tokens) < 2: |
| return set() |
| return {f"{tokens[index]} {tokens[index + 1]}" for index in range(len(tokens) - 1)} |
|
|
|
|
def _char_trigrams(text: str) -> set[str]:
    """Return the character n-grams of *text*, with n = ``_TRIGRAM_WINDOW``.

    The text is stripped, lower-cased and whitespace runs are collapsed first.
    Text shorter than the window yields itself as the only gram; empty text
    yields an empty set.
    """
    compact = re.sub(r"\s+", " ", text.strip().lower())
    window = _TRIGRAM_WINDOW
    if len(compact) >= window:
        last_start = len(compact) - window
        return {compact[start : start + window] for start in range(last_start + 1)}
    return {compact} if compact else set()
|
|
|
|
| def _compact_summary_text(text: str, *, max_words: int = 18) -> str: |
| cleaned = re.sub(r"\s+", " ", text.strip()) |
| if not cleaned: |
| return "" |
| words = cleaned.split(" ") |
| if len(words) <= max_words: |
| return cleaned |
| return " ".join(words[:max_words]).rstrip(" ,;:.") + "…" |
|
|
|
|
def _extract_user_focus_candidates(text: str) -> list[str]:
    """Return compact user-focus statements found in *text*.

    The text is split after ``.``/``!``/``?`` followed by whitespace and at
    newlines, and each segment that qualifies is kept. If no segment
    qualifies, the whole text is tested once more; its whitespace collapsing
    can rejoin a marker that a newline had split.
    """
    segments = re.split(r"(?<=[.!?])\s+|\n+", text)
    found = [candidate for candidate in map(_build_user_focus_candidate, segments) if candidate]
    if found:
        return found
    whole = _build_user_focus_candidate(text)
    return [whole] if whole else []
|
|
|
|
def _build_user_focus_candidate(text: str) -> str:
    """Return *text* compacted to 20 words if it contains a user-focus marker, else ""."""
    compact = _compact_summary_text(text, max_words=20)
    if compact and _looks_like_user_focus(compact, lowered=compact.lower()):
        return compact
    return ""
|
|
|
|
def _extract_active_thread_candidates(text: str) -> list[str]:
    """Return compact active-thread snippets (open questions/requests) found in *text*.

    The text is split after ``.``/``!``/``?`` followed by whitespace and at
    newlines, and each segment that qualifies is kept. If none qualifies,
    the whole text is tested once more as a fallback.
    """
    segments = re.split(r"(?<=[.!?])\s+|\n+", text)
    found = [candidate for candidate in map(_build_active_thread_candidate, segments) if candidate]
    if found:
        return found
    whole = _build_active_thread_candidate(text)
    return [whole] if whole else []
|
|
|
|
def _build_active_thread_candidate(text: str) -> str:
    """Return *text* compacted to ``_ACTIVE_THREAD_MAX_WORDS`` words if it looks active, else ""."""
    compact = _compact_summary_text(text, max_words=_ACTIVE_THREAD_MAX_WORDS)
    if compact and _looks_like_active_thread(compact, lowered=compact.lower()):
        return compact
    return ""
|
|
|
|
def _looks_like_user_focus(text: str, *, lowered: str | None = None) -> bool:
    """Return True if any ``_USER_FOCUS_MARKERS`` phrase occurs in *text* (case-insensitive).

    *lowered* may carry a pre-lowered copy of *text* to avoid re-lowering.
    """
    haystack = text.lower() if lowered is None else lowered
    for marker in _USER_FOCUS_MARKERS:
        if marker in haystack:
            return True
    return False
|
|
|
|
def _classify_user_focus(text: str, *, lowered: str | None = None) -> str:
    """Label a focus statement: "Mērķis" (goal), else "Priekšroka" (preference), else "Konteksts".

    Goal markers are checked before preference markers; both are substring tests.
    """
    haystack = text.lower() if lowered is None else lowered
    labelled_markers = (
        (_USER_GOAL_MARKERS, "Mērķis"),
        (_USER_PREFERENCE_MARKERS, "Priekšroka"),
    )
    for markers, label in labelled_markers:
        if any(marker in haystack for marker in markers):
            return label
    return "Konteksts"
|
|
|
|
def _looks_like_active_thread(text: str, *, lowered: str | None = None) -> bool:
    """Return True if *text* contains any ``_ACTIVE_THREAD_MARKERS`` entry.

    The ``"?"`` entry stands for "has a question signal" (see
    ``_contains_question_signal``); every other entry is a substring test on
    the lower-cased text.
    """
    haystack = text.lower() if lowered is None else lowered
    question = _contains_question_signal(text, lowered=haystack)
    return any(
        (marker in haystack) if marker != "?" else question
        for marker in _ACTIVE_THREAD_MARKERS
    )
|
|
|
|
def _classify_active_thread(text: str, *, lowered: str | None = None) -> str:
    """Label a snippet "Atvērtais jautājums" (open question) or "Aktīvais virziens" (active direction)."""
    haystack = text.lower() if lowered is None else lowered
    is_question = _contains_question_signal(text, lowered=haystack)
    return "Atvērtais jautājums" if is_question else "Aktīvais virziens"
|
|
|
|
| def _contains_question_signal(text: str, *, lowered: str | None = None) -> bool: |
| lowered = lowered if lowered is not None else text.lower() |
| return "?" in text or any( |
| marker in lowered for marker in ("kā", "kas", "kur", "kad", "why", "how") |
| ) |
|
|
|
|
def _looks_like_continuation_query(text: str) -> bool:
    """Return True if *text* asks to continue or revisit earlier context.

    This is a case-insensitive substring search for any entry of
    ``_CONTINUATION_QUERY_MARKERS``. Some entries are stems (e.g. "turpin")
    that match inflected forms.
    """
    haystack = text.lower()
    for marker in _CONTINUATION_QUERY_MARKERS:
        if marker in haystack:
            return True
    return False
|
|