diff --git "a/code/memory_adapters.py" "b/code/memory_adapters.py" new file mode 100644--- /dev/null +++ "b/code/memory_adapters.py" @@ -0,0 +1,9190 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass, field +import json +import math +import os +from pathlib import Path +import re +import tempfile +import time +import uuid +from typing import Any, Dict, Iterable, List, Mapping, Sequence + +from core.session_memory import SessionMemoryExtractor +from experiments.replacement.memory_graph import ( + SessionMemoryGraphV2, + SessionMemoryEdgeV2, + SessionMemoryRecordV2, + SQLiteSessionMemoryStore, + _clean_text, + _estimate_tokens, + _normalize, + _public_query_subject, + _public_slot_root, + _public_subject_signature, + stable_slot_key, + _tokenize, + guess_slot_key, + infer_category_hints, +) +from experiments.replacement.node_memory import ( + _call_with_supported_kwargs, + DEFAULT_MATRIX_EVENT_TOP_K, + LoadedNodeMemoryScorer, + MEMORY_ROUTER_LAYERS, + build_default_path_templates, + extract_question_features, +) +from experiments.replacement.public_event_signature import compute_public_event_signature +from experiments.replacement.memory_profiles import TMCRAProfile +from experiments.replacement.typed_tunnel_augmentation import ( + annotate_memory_record, + merge_typed_metadata, + typed_edge_tags_between, + typed_tunnel_signature_text, +) +from experiments.replacement.profile_layer import infer_profile_query_intent, is_profile_layer_record, profile_query_score_delta +from experiments.replacement import injection_planner as injection_planner_runtime +from experiments.replacement.temporal_modeling_types import TemporalFrame, TemporalQueryPlan +from experiments.replacement.temporal_organizer import TemporalOrganizer +from experiments.replacement.temporal_query_planner import TemporalQueryPlanner +from experiments.replacement.temporal_router_runtime import LoadedTemporalRouter +from experiments.replacement.timeline_evidence_pack import TimelineEvidencePackBuilder +from experiments.replacement.timeline_state_layer import TimelineStateLayer + +from .base import MemoryAdapter, MemoryHit, MemoryRetrieval + + +def _dedupe(items: Iterable[Any], *, max_items: int | None = None) -> List[str]: + values: List[str] = [] + seen = set() + for item in items: + text = _clean_text(item) + if not text: + continue + key = _normalize(text) + if key in seen: + continue + seen.add(key) + values.append(text) + if max_items is not None and len(values) >= max_items: + break + return values + + +def _apply_typed_tunnel_annotations(records: List[SessionMemoryRecordV2], *, source_text: str = "") -> List[SessionMemoryRecordV2]: + for record in records: + annotate_memory_record(record, source_text=source_text) + return records + + +def _estimate_tokens_from_hits(hits: Sequence[MemoryHit]) -> int: + total = 0 + for hit in hits: + total += _estimate_tokens(hit.value) + total += sum(_estimate_tokens(anchor) for anchor in hit.anchors) + return total + + +def _float_env(name: str, default: float) -> float: + raw = _clean_text(os.getenv(name, "")) + if not raw: + return float(default) + try: + return float(raw) + except ValueError: + return float(default) + + +_GRAPH_PROMPT_MAX_CHARS = 12_000 +_GRAPH_PROMPT_MAX_HITS = 8 +_GRAPH_PROMPT_MAX_ACTIVE_SLOTS = 12 +_GRAPH_PROMPT_MAX_RELATIONS = 10 +_HYBRID_SELECTED_EVENT_FLOOR = 8 +_HYBRID_SELECTED_PATH_CAP = 6 +_HYBRID_TEMPORAL_PATH_CAP = 2 +_HYBRID_PROFILE_PATH_CAP = 3 +_MEMORY_ROUTER_GUIDED_MODES = {"guided", "route", "routing", "enforce"} +_MEMORY_ROUTER_FORCE_MODES = {"force", "forced"} +_MEMORY_ROUTER_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"} +_MEMORY_ROUTER_OBSERVE_MODES = {"observe", "observer", "telemetry", "shadow"} +_MEMORY_ROUTER_DEFAULT_THRESHOLD = 0.55 +_MEMORY_ROUTER_DEFAULT_MARGIN = 0.08 +_INJECTION_PLANNER_GUIDED_MODES = {"guided", "route", "routing", "enforce"} +_INJECTION_PLANNER_FORCE_MODES = {"force", "forced"} +_INJECTION_PLANNER_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"} +_INJECTION_PLANNER_OBSERVE_MODES = {"observe", "observer", "telemetry", "shadow"} +_TEMPORAL_LAYER_DISABLED_MODES = {"off", "disabled", "none", "false", "0"} +_TEMPORAL_ROUTER_DEFAULT_WRITER_MIN_CONFIDENCE = 0.72 +_TEMPORAL_ROUTER_DEFAULT_QUERY_MIN_CONFIDENCE = 0.85 +_TEMPORAL_ROUTER_DEFAULT_QUERY_INTENT_MIN_CONFIDENCE = 0.60 +_EMBEDDER_INDEX_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"} +_EMBEDDER_INDEX_BGE_M3_MODES = {"bge", "bge_m3", "bge-m3", "baai_bge_m3", "baai/bge-m3"} +_EMBEDDER_INDEX_VERSION = "write_hash_sparse_v1" +_EMBEDDER_INDEX_BGE_M3_VERSION = "write_bge_m3_dense_sparse_v1" +_EMBEDDER_MODEL_CACHE: Dict[str, Any] = {} +_HYBRID_SYMBOLIC_STOPWORDS = { + "a", + "an", + "and", + "are", + "as", + "at", + "be", + "did", + "do", + "does", + "event", + "for", + "from", + "had", + "has", + "have", + "how", + "in", + "is", + "of", + "on", + "or", + "should", + "the", + "to", + "turn", + "was", + "were", + "what", + "when", + "where", + "which", + "who", + "why", + "will", + "with", +} + + +def _coerce_memory_router_scores(payload: Mapping[str, Any]) -> Dict[str, float]: + raw_scores = dict(payload.get("memory_router_scores", {}) or {}) + scores: Dict[str, float] = {} + for layer in MEMORY_ROUTER_LAYERS: + try: + scores[layer] = float(raw_scores.get(layer, 0.0)) + except (TypeError, ValueError): + scores[layer] = 0.0 + return scores + + +def _memory_router_decision( + payload: Mapping[str, Any], + *, + mode: str, + threshold: float, + margin: float, +) -> Dict[str, Any]: + normalized_mode = _normalize(mode) + scores = _coerce_memory_router_scores(payload) + has_scores = bool(scores) and any(layer in dict(payload.get("memory_router_scores", {}) or {}) for layer in MEMORY_ROUTER_LAYERS) + if not has_scores: + return { + "memory_router_enabled": False, + "memory_router_guided": False, + "memory_router_reason": "no_router_scores", + "memory_router_scores": {}, + "memory_router_top_layers": [], + "memory_router_active_layers": [], + "memory_router_score_spread": 0.0, + "memory_router_confidence": 0.0, + } + ranked_layers = [ + layer + for layer, _ in sorted( + scores.items(), + key=lambda item: (-float(item[1]), item[0]), + ) + ] + score_values = list(scores.values()) + score_spread = max(score_values) - min(score_values) + confidence = max(abs(float(score) - 0.5) for score in score_values) + resolved_threshold = max(0.0, min(1.0, float(threshold or _MEMORY_ROUTER_DEFAULT_THRESHOLD))) + resolved_margin = max(0.0, min(0.5, float(margin or _MEMORY_ROUTER_DEFAULT_MARGIN))) + active_layers = [ + layer + for layer in ranked_layers + if float(scores.get(layer, 0.0)) >= resolved_threshold + ] + confident = bool(active_layers) and (score_spread >= resolved_margin or confidence >= resolved_margin) + guided_requested = normalized_mode in _MEMORY_ROUTER_GUIDED_MODES + forced = normalized_mode in _MEMORY_ROUTER_FORCE_MODES + guided = bool(forced or (guided_requested and confident)) + if forced and not active_layers and ranked_layers: + active_layers = [ranked_layers[0]] + if active_layers and "event" not in active_layers: + active_layers = ["event", *active_layers] + reason = "observe" + if normalized_mode in _MEMORY_ROUTER_DISABLED_MODES: + reason = "disabled" + elif normalized_mode in _MEMORY_ROUTER_OBSERVE_MODES: + reason = "observe" + elif forced: + reason = "forced" + elif guided_requested and not confident: + reason = "low_confidence" + elif guided: + reason = "guided" + return { + "memory_router_enabled": True, + "memory_router_guided": bool(guided and normalized_mode not in _MEMORY_ROUTER_DISABLED_MODES), + "memory_router_reason": reason, + "memory_router_scores": scores, + "memory_router_top_layers": ranked_layers, + "memory_router_active_layers": active_layers, + "memory_router_score_spread": round(float(score_spread), 6), + "memory_router_confidence": round(float(confidence), 6), + "memory_router_threshold": round(float(resolved_threshold), 6), + "memory_router_margin": round(float(resolved_margin), 6), + "memory_router_mode": normalized_mode or "observe", + } + + +def _memory_router_allows(decision: Mapping[str, Any], *layers: str) -> bool: + if not bool(decision.get("memory_router_guided")): + return True + requested = {_normalize(layer) for layer in layers if _normalize(layer)} + if not requested: + return True + if "event" in requested: + return True + active = {_normalize(layer) for layer in list(decision.get("memory_router_active_layers", []) or [])} + return bool(active & requested) + + +def _hybrid_symbolic_tokens(value: Any) -> List[str]: + return [ + token + for token in _tokenize(value) + if token and token not in _HYBRID_SYMBOLIC_STOPWORDS and not re.fullmatch(r"\d+", token) + ] + + +_PATH_UTILITY_STOPWORDS = set(_HYBRID_SYMBOLIC_STOPWORDS) | { + "actually", + "also", + "bit", + "feel", + "feels", + "get", + "guess", + "i", + "im", + "it", + "its", + "just", + "kind", + "like", + "maybe", + "me", + "more", + "much", + "my", + "really", + "something", + "still", + "think", + "thats", + "there", + "this", + "way", + "would", +} + + +def _path_utility_tokens(value: Any) -> List[str]: + return [ + token + for token in _hybrid_symbolic_tokens(value) + if token and token not in _PATH_UTILITY_STOPWORDS + ] + + +_PROFILE_QUERY_ALIAS_GROUPS: tuple[set[str], ...] = ( + { + "accessory", + "accessories", + "bag", + "camera", + "cameras", + "equipment", + "flash", + "gear", + "lens", + "lenses", + "photo", + "photography", + "sony", + "tripod", + }, + { + "app", + "apps", + "dashboard", + "interface", + "layout", + "panel", + "software", + "tool", + "tools", + "ui", + "workflow", + }, + { + "diet", + "drink", + "food", + "meal", + "restaurant", + "snack", + "taste", + }, + { + "background", + "career", + "job", + "occupation", + "position", + "previous", + "profession", + "role", + "worked", + "work", + }, +) + + +_PROFILE_QUERY_GENERIC_TOKENS = set(_PATH_UTILITY_STOPWORDS) | { + "able", + "about", + "any", + "anything", + "based", + "best", + "can", + "complement", + "could", + "current", + "give", + "help", + "information", + "looking", + "make", + "need", + "please", + "recommend", + "recommendation", + "recommendations", + "should", + "some", + "suggest", + "suggestion", + "suggestions", + "tell", + "that", + "using", + "you", +} + + +def _profile_query_expanded_tokens(value: Any) -> set[str]: + tokens = set(_path_utility_tokens(value)) + expanded = set(tokens) + for group in _PROFILE_QUERY_ALIAS_GROUPS: + if tokens & group: + expanded.update(group) + return expanded + + +def _profile_specific_tokens(tokens: Iterable[Any]) -> set[str]: + return { + _normalize(token) + for token in tokens + if _clean_text(token) + and _normalize(token) not in _PROFILE_QUERY_GENERIC_TOKENS + and not re.fullmatch(r"\d+", _normalize(token)) + } + + +def _profile_hit_match_score(query_tokens: set[str], expanded_query_tokens: set[str], hit: MemoryHit) -> tuple[float, List[str], List[str]]: + metadata = dict(hit.metadata or {}) + payload_parts: List[Any] = [ + hit.category, + hit.relation, + hit.slot_key, + hit.value, + *list(hit.anchors or []), + metadata.get("profile_summary", ""), + metadata.get("profile_value", ""), + metadata.get("profile_type", ""), + metadata.get("profile_domain", ""), + metadata.get("profile_domain_label", ""), + metadata.get("semantic_slot", ""), + metadata.get("subject", ""), + metadata.get("extracted_subject", ""), + metadata.get("profile_cluster_domains", []), + metadata.get("profile_cluster_types", []), + metadata.get("profile_cluster_route_terms", []), + metadata.get("profile_route_terms", []), + metadata.get("profile_support_values", []), + ] + payload_text = " ".join(str(item) for item in payload_parts) + record_raw_tokens = set(_path_utility_tokens(payload_text)) + record_tokens = _profile_query_expanded_tokens(payload_text) + specific_query_tokens = _profile_specific_tokens(query_tokens) + raw_overlap_tokens = sorted(specific_query_tokens & _profile_specific_tokens(record_raw_tokens)) + expanded_overlap_tokens = sorted(_profile_specific_tokens(expanded_query_tokens) & _profile_specific_tokens(record_tokens)) + overlap_tokens = raw_overlap_tokens or expanded_overlap_tokens + overlap_ratio = float(len(overlap_tokens)) / float(max(1, len(specific_query_tokens))) + profile_type = _normalize(metadata.get("profile_type", "")) + source_kind = _normalize(hit.source_kind) + source_bonus = 0.0 + if source_kind in {"public_dialog_preference", "public_dialog_goal", "public_dialog_profile"}: + source_bonus += 0.08 + if bool(metadata.get("profile_candidate_status") == "consolidated"): + source_bonus += 0.04 + if bool(metadata.get("profile_cluster_node")): + source_bonus -= 0.16 + type_bonus = 0.0 + if profile_type in {"preference", "goal", "setup", "usage_context"} and ( + {"recommend", "suggest", "suited", "accessory", "accessories", "gear", "equipment", "setup", "current"} & query_tokens + ): + type_bonus += 0.10 + if profile_type in {"setup", "usage_context"} and {"current", "setup", "profile"} & query_tokens: + type_bonus += 0.08 + raw_bonus = 0.18 * len(raw_overlap_tokens) + expanded_bonus = 0.08 * max(0, len(expanded_overlap_tokens) - len(raw_overlap_tokens)) + score = raw_bonus + expanded_bonus + (0.58 * overlap_ratio) + source_bonus + type_bonus + return round(score, 6), overlap_tokens, raw_overlap_tokens + + +def _bounded_event_id_union(*groups: Iterable[Any], max_items: int) -> List[str]: + return _dedupe((item for group in groups for item in group), max_items=max(1, int(max_items))) + + +def _symbolic_recall_event_ids( + query: str, + runtime_graph: Mapping[str, Any], + *, + grouped_hits: Mapping[str, Sequence[MemoryHit]], + limit: int, +) -> List[str]: + question_features = extract_question_features(query) + query_tokens = set(_hybrid_symbolic_tokens(question_features.get("question_anchor_tokens", []) or query)) + if not query_tokens: + return [] + nodes_by_id = { + _clean_text(node.get("id", "")): dict(node) + for node in list(runtime_graph.get("nodes", []) or []) + if _clean_text(node.get("id", "")) + } + event_payloads: Dict[str, List[str]] = {} + for node_id, node in nodes_by_id.items(): + if _clean_text(node.get("type", "")) == "event": + metadata = dict(node.get("metadata", {}) or {}) + teacher_fields = dict(node.get("teacher_fields", {}) or {}) + event_payloads.setdefault(node_id, []).extend( + [ + node.get("text", ""), + node.get("speaker", ""), + node.get("slot_key", ""), + node.get("target_status", ""), + node.get("profile_type", ""), + node.get("profile_value", ""), + *teacher_fields.values(), + *metadata.values(), + ] + ) + for path in list(runtime_graph.get("paths", []) or []): + event_id = _clean_text(path.get("event_id", "")) + support_node = nodes_by_id.get(_clean_text(path.get("target", "")), {}) + if event_id and support_node: + event_payloads.setdefault(event_id, []).append(support_node.get("text", "")) + for event_id, group_hits in grouped_hits.items(): + payloads = event_payloads.setdefault(_clean_text(event_id), []) + for hit in list(group_hits or []): + metadata = dict(hit.metadata or {}) + payloads.extend([hit.value, hit.slot_key, hit.category, hit.relation, *hit.anchors, *metadata.values()]) + + scored_events: List[tuple[str, float]] = [] + for event_id, payloads in event_payloads.items(): + event_tokens = set(_hybrid_symbolic_tokens(payloads)) + overlap = query_tokens & event_tokens + if not overlap: + continue + overlap_ratio = float(len(overlap)) / float(max(1, len(query_tokens))) + event_node = nodes_by_id.get(event_id, {}) + turn_index = int(event_node.get("turn_index", 0) or 0) + scored_events.append((event_id, (len(overlap) * 4.0) + overlap_ratio + min(turn_index, 1000) * 0.000001)) + return [ + event_id + for event_id, _ in sorted(scored_events, key=lambda item: (-float(item[1]), item[0])) + ][: max(1, int(limit))] + + +_EMBEDDER_INDEX_METADATA_TEXT_KEYS = ( + "raw_text", + "source_turn_text", + "source_span", + "event_phrase", + "event_summary", + "profile_value", + "profile_summary", + "profile_type", + "profile_domain", + "profile_domain_label", + "semantic_slot", + "target_status", + "subject", + "subject_signature", + "canonical_slot_key", + "resource_key", + "resolved_date", + "resolved_time_value", + "time_value", + "time_display_value", + "time_granularity", + "speaker", + "session_name", + "topic_label", + "topic_bucket_id", + "origin_query", + "writeback_class", + "depth_layer", + "memory_chain_depth_layer", +) +_EMBEDDER_INDEX_METADATA_LIST_KEYS = ( + "topic_keywords", + "profile_route_terms", + "profile_cluster_route_terms", + "profile_support_values", + "evidence_anchors", + "support_memory_ids", + "support_fact_refs", + "support_path_refs", +) + + +def _embedder_index_enabled(mode: Any) -> bool: + return _normalize(mode) not in _EMBEDDER_INDEX_DISABLED_MODES + + +def _embedder_index_uses_bge_m3(mode: Any) -> bool: + return _normalize(mode).replace("-", "_") in {item.replace("-", "_") for item in _EMBEDDER_INDEX_BGE_M3_MODES} + + +def _embedder_index_version_for_mode(mode: Any) -> str: + return _EMBEDDER_INDEX_BGE_M3_VERSION if _embedder_index_uses_bge_m3(mode) else _EMBEDDER_INDEX_VERSION + + +def _embedder_index_text_items(value: Any, *, max_items: int = 64) -> List[str]: + items: List[str] = [] + + def visit(item: Any) -> None: + if len(items) >= max_items: + return + if item is None: + return + if isinstance(item, Mapping): + for key, nested in list(item.items()): + if len(items) >= max_items: + break + key_text = _clean_text(key) + if isinstance(nested, (str, int, float, bool)): + value_text = _clean_text(nested)[:800] + if value_text: + items.append(f"{key_text} {value_text}".strip()) + else: + visit(nested) + return + if isinstance(item, (list, tuple, set)): + for nested in list(item): + if len(items) >= max_items: + break + visit(nested) + return + text = _clean_text(item)[:800] + if text: + items.append(text) + + visit(value) + return items[:max_items] + + +def _embedder_index_term_weights(value: Any, *, max_terms: int = 96) -> Dict[str, float]: + text = _clean_text(value) + if not text: + return {} + counts: Dict[str, float] = {} + for token in _path_utility_tokens(text): + token = _normalize(token) + if not token or len(token) > 64: + continue + counts[token] = counts.get(token, 0.0) + 1.0 + normalized_text = _normalize(text) + cjk_chars = [char for char in normalized_text if "\u4e00" <= char <= "\u9fff"] + for width, weight in ((2, 1.35), (3, 1.15)): + if len(cjk_chars) < width: + continue + for index in range(0, len(cjk_chars) - width + 1): + gram = "".join(cjk_chars[index : index + width]) + if gram: + counts[gram] = counts.get(gram, 0.0) + weight + if not counts: + return {} + ranked = sorted(counts.items(), key=lambda item: (-float(item[1]), item[0]))[: max(1, int(max_terms or 1))] + norm = math.sqrt(sum(float(weight) * float(weight) for _, weight in ranked)) or 1.0 + return {term: round(float(weight) / norm, 6) for term, weight in ranked} + + +def _embedder_dense_vectors_for_texts(texts: Sequence[str], *, mode: str) -> tuple[List[List[float]], Dict[str, Any]]: + normalized_mode = _normalize(mode) + if not _embedder_index_uses_bge_m3(normalized_mode): + return [[] for _ in texts], {"write_embedder_dense_enabled": False} + clean_texts = [_clean_text(text) for text in texts] + metadata: Dict[str, Any] = { + "write_embedder_dense_enabled": False, + "write_embedder_dense_backend": "bge_m3_transformers", + "write_embedder_dense_model": _clean_text(os.getenv("TMCRA_EMBEDDER_MODEL_PATH", "")) or "BAAI/bge-m3", + } + if not any(clean_texts): + metadata["write_embedder_dense_error"] = "empty_texts" + return [[] for _ in texts], metadata + try: + import torch # type: ignore + from transformers import AutoModel, AutoTokenizer # type: ignore + except Exception as exc: + metadata["write_embedder_dense_error"] = f"dependency_unavailable:{exc.__class__.__name__}" + return [[] for _ in texts], metadata + model_name = metadata["write_embedder_dense_model"] + device = _clean_text(os.getenv("TMCRA_EMBEDDER_DEVICE", "")) + if not device: + device = "cuda" if bool(getattr(torch, "cuda", None) and torch.cuda.is_available()) else "cpu" + try: + max_length = max(64, int(os.getenv("TMCRA_EMBEDDER_MODEL_MAX_LENGTH", "512") or 512)) + except (TypeError, ValueError): + max_length = 512 + cache_key = f"bge_m3::{model_name}::{device}::{max_length}" + try: + cached = _EMBEDDER_MODEL_CACHE.get(cache_key) + if cached is None: + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = AutoModel.from_pretrained(model_name, trust_remote_code=True) + model.to(device) + model.eval() + cached = (tokenizer, model) + _EMBEDDER_MODEL_CACHE[cache_key] = cached + tokenizer, model = cached + encoded = tokenizer( + clean_texts, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + encoded = {key: value.to(device) for key, value in encoded.items()} + with torch.no_grad(): + output = model(**encoded) + hidden = output.last_hidden_state + mask = encoded.get("attention_mask") + if mask is not None: + mask = mask.unsqueeze(-1).expand(hidden.size()).float() + pooled = torch.sum(hidden * mask, dim=1) / torch.clamp(mask.sum(dim=1), min=1e-9) + else: + pooled = hidden[:, 0] + pooled = torch.nn.functional.normalize(pooled, p=2, dim=1) + vectors = [ + [round(float(value), 6) for value in row.detach().cpu().tolist()] + for row in pooled + ] + metadata.update( + { + "write_embedder_dense_enabled": True, + "write_embedder_dense_device": device, + "write_embedder_dense_dim": int(len(vectors[0]) if vectors else 0), + "write_embedder_dense_max_length": int(max_length), + } + ) + return vectors, metadata + except Exception as exc: + metadata["write_embedder_dense_error"] = f"{exc.__class__.__name__}:{_clean_text(exc)[:240]}" + return [[] for _ in texts], metadata + + +def _prewarm_embedder_dense_if_requested(*, mode: str) -> Dict[str, Any]: + flag = _normalize(os.getenv("TMCRA_EMBEDDER_PREWARM", "")) + if flag in _EMBEDDER_INDEX_DISABLED_MODES or flag not in {"1", "true", "yes", "on", "auto"}: + return {"embedder_prewarm_enabled": False} + normalized_mode = _normalize(mode) + if not _embedder_index_uses_bge_m3(normalized_mode): + return { + "embedder_prewarm_enabled": False, + "embedder_prewarm_reason": "mode_not_dense", + "embedder_prewarm_mode": normalized_mode or "off", + } + warmup_text = _clean_text(os.getenv("TMCRA_EMBEDDER_PREWARM_TEXT", "")) or "tmcra memory retrieval warmup" + vectors, metadata = _embedder_dense_vectors_for_texts([warmup_text], mode=normalized_mode) + return { + "embedder_prewarm_enabled": bool(vectors and vectors[0]), + "embedder_prewarm_mode": normalized_mode, + "embedder_prewarm_dense_enabled": bool(metadata.get("write_embedder_dense_enabled")), + "embedder_prewarm_dense_device": metadata.get("write_embedder_dense_device", ""), + "embedder_prewarm_dense_error": metadata.get("write_embedder_dense_error", ""), + } + + +def _embedder_index_record_text(record: SessionMemoryRecordV2, *, turn_text: str = "") -> str: + metadata = dict(record.metadata or {}) + parts: List[str] = [ + record.category, + record.slot_key, + record.value, + record.value, + record.value, + record.relation, + *list(record.anchor_concepts or []), + *list(record.anchor_concepts or []), + *list(record.evidence_anchors or []), + ] + has_record_evidence_text = any( + _clean_text(metadata.get(key, "")) + for key in ("raw_text", "source_turn_text", "source_span", "event_phrase", "profile_value", "profile_summary") + ) + if turn_text and not has_record_evidence_text: + parts.append(_clean_text(turn_text)[:1000]) + for key in _EMBEDDER_INDEX_METADATA_TEXT_KEYS: + value = metadata.get(key) + if value: + parts.append(f"{key} {' '.join(_embedder_index_text_items(value, max_items=12))}".strip()) + for key in _EMBEDDER_INDEX_METADATA_LIST_KEYS: + value = metadata.get(key) + if value: + parts.append(f"{key} {' '.join(_embedder_index_text_items(value, max_items=24))}".strip()) + return _clean_text(" ".join(_clean_text(part) for part in parts if _clean_text(part))) + + +def _apply_write_embedder_index_to_graph( + graph: SessionMemoryGraphV2, + *, + stored_ids: Sequence[str], + turn_text: str, + turn_index: int, + mode: str, + max_terms: int, +) -> Dict[str, Any]: + normalized_mode = _normalize(mode) + index_version = _embedder_index_version_for_mode(normalized_mode) + metadata: Dict[str, Any] = { + "write_embedder_index_enabled": False, + "write_embedder_index_mode": normalized_mode or "off", + "write_embedder_index_version": index_version, + "write_embedder_index_record_count": 0, + } + if not _embedder_index_enabled(normalized_mode) or not stored_ids: + return metadata + indexed_ids: List[str] = [] + index_rows: List[tuple[SessionMemoryRecordV2, str, Dict[str, float]]] = [] + for memory_id in _dedupe(stored_ids): + record = getattr(graph, "records_by_id", {}).get(memory_id) + if record is None: + continue + index_text = _embedder_index_record_text(record, turn_text=turn_text) + terms = _embedder_index_term_weights( + index_text, + max_terms=max_terms, + ) + if not terms: + continue + index_rows.append((record, index_text, terms)) + dense_vectors, dense_metadata = _embedder_dense_vectors_for_texts( + [index_text for _, index_text, _ in index_rows], + mode=normalized_mode, + ) + metadata.update(dense_metadata) + for row_index, (record, _, terms) in enumerate(index_rows): + dense_vector = dense_vectors[row_index] if row_index < len(dense_vectors) else [] + record_metadata = dict(record.metadata or {}) + record_metadata.update( + { + "write_embedder_index_enabled": True, + "write_embedder_index_mode": normalized_mode, + "write_embedder_index_version": index_version, + "write_embedder_index_source_turn": int(turn_index), + "write_embedder_index_term_count": int(len(terms)), + "write_embedder_index_terms": dict(terms), + "write_embedder_index_top_terms": list(terms.keys())[:24], + } + ) + if dense_vector: + record_metadata.update( + { + "write_embedder_dense_enabled": True, + "write_embedder_dense_backend": dense_metadata.get("write_embedder_dense_backend", ""), + "write_embedder_dense_model": dense_metadata.get("write_embedder_dense_model", ""), + "write_embedder_dense_dim": int(len(dense_vector)), + "write_embedder_dense_vector": list(dense_vector), + } + ) + elif _embedder_index_uses_bge_m3(normalized_mode): + record_metadata.update( + { + "write_embedder_dense_enabled": False, + "write_embedder_dense_error": dense_metadata.get("write_embedder_dense_error", "dense_vector_unavailable"), + } + ) + record.metadata = record_metadata + indexed_ids.append(record.memory_id) + metadata.update( + { + "write_embedder_index_enabled": bool(indexed_ids), + "write_embedder_index_record_count": int(len(indexed_ids)), + "write_embedder_index_record_ids": list(indexed_ids[:24]), + } + ) + return metadata + + +def _embedder_index_recall_event_ids( + query: str, + *, + grouped_hits: Mapping[str, Sequence[MemoryHit]], + mode: str, + limit: int, + max_terms: int, +) -> Dict[str, Any]: + normalized_mode = _normalize(mode) + index_version = _embedder_index_version_for_mode(normalized_mode) + metadata: Dict[str, Any] = { + "embedder_index_recall_enabled": False, + "embedder_index_recall_mode": normalized_mode or "off", + "embedder_index_recall_version": index_version, + "embedder_index_event_ids": [], + "embedder_index_event_scores": {}, + "embedder_index_record_count": 0, + } + if not _embedder_index_enabled(normalized_mode): + return {"event_ids": [], "metadata": metadata} + query_terms = _embedder_index_term_weights(query, max_terms=max_terms) + query_vectors, query_dense_metadata = _embedder_dense_vectors_for_texts([query], mode=normalized_mode) + query_vector = query_vectors[0] if query_vectors else [] + metadata.update( + { + "embedder_dense_recall_enabled": bool(query_vector), + "embedder_dense_recall_backend": query_dense_metadata.get("write_embedder_dense_backend", ""), + "embedder_dense_recall_model": query_dense_metadata.get("write_embedder_dense_model", ""), + "embedder_dense_recall_error": query_dense_metadata.get("write_embedder_dense_error", ""), + } + ) + if not query_terms and not query_vector: + metadata["embedder_index_recall_reason"] = "empty_query_terms" + return {"event_ids": [], "metadata": metadata} + scored_events: List[tuple[str, float, int]] = [] + indexed_record_count = 0 + dense_record_count = 0 + for event_id, group_hits in grouped_hits.items(): + event_score = 0.0 + event_turn = 0 + for hit in list(group_hits or []): + hit_metadata = dict(hit.metadata or {}) + raw_terms = hit_metadata.get("write_embedder_index_terms") + raw_vector = hit_metadata.get("write_embedder_dense_vector") + if not isinstance(raw_terms, Mapping) and not isinstance(raw_vector, list): + continue + indexed_record_count += int(isinstance(raw_terms, Mapping)) + dense_record_count += int(isinstance(raw_vector, list) and bool(raw_vector)) + sparse_score = 0.0 + if isinstance(raw_terms, Mapping): + for term, query_weight in query_terms.items(): + try: + sparse_score += float(query_weight) * float(raw_terms.get(term, 0.0) or 0.0) + except (TypeError, ValueError): + continue + dense_score = 0.0 + if query_vector and isinstance(raw_vector, list) and raw_vector: + for query_value, record_value in zip(query_vector, raw_vector): + try: + dense_score += float(query_value) * float(record_value) + except (TypeError, ValueError): + continue + score = max(float(sparse_score), float(dense_score) + (0.15 * float(sparse_score) if dense_score > 0.0 else 0.0)) + if score <= 0.0: + continue + event_score = max(event_score, float(score)) + event_turn = max(event_turn, int(hit.turn_index or 0)) + if event_score > 0.0: + scored_events.append((_clean_text(event_id), round(event_score, 6), event_turn)) + scored_events.sort(key=lambda item: (-float(item[1]), -int(item[2]), item[0])) + selected = scored_events[: max(1, int(limit or 1))] + event_ids = [event_id for event_id, _, _ in selected if event_id] + event_scores = {event_id: score for event_id, score, _ in selected if event_id} + metadata.update( + { + "embedder_index_recall_enabled": True, + "embedder_index_event_ids": list(event_ids), + "embedder_index_event_scores": dict(event_scores), + "embedder_index_query_terms": list(query_terms.keys())[:32], + "embedder_index_record_count": int(indexed_record_count), + "embedder_dense_record_count": int(dense_record_count), + } + ) + return {"event_ids": event_ids, "metadata": metadata} + + +_IDENTIFIER_GENERIC_TOKENS = { + "api", + "agent", + "code", + "codename", + "context", + "debug", + "goal", + "memory", + "model", + "name", + "project", + "retrieval", + "runtime", + "session", + "target", + "test", + "turn", +} +_IDENTIFIER_REQUEST_RE = re.compile( + r"\b(code\s*name|codename|identifier|alias|project\s+name|model\s+name|api\s+name)\b|" + r"(代号|编号|标识符|名称|名字|别名|项目名|模型名|接口名)", + flags=re.IGNORECASE, +) +_IDENTIFIER_TOKEN_RE = re.compile(r"\b[A-Za-z][A-Za-z0-9_-]{2,63}\b") + + +def _query_identifier_tokens(query: Any) -> List[str]: + text = _clean_text(query) + tokens: List[str] = [] + for match in _IDENTIFIER_TOKEN_RE.finditer(text): + token = match.group(0) + lowered = token.lower() + if lowered in _IDENTIFIER_GENERIC_TOKENS: + continue + has_inner_upper = any(ch.isupper() for ch in token[1:]) + has_digit = any(ch.isdigit() for ch in token) + has_joiner = "-" in token or "_" in token + if has_inner_upper or has_digit or has_joiner: + tokens.append(token) + return _dedupe(tokens, max_items=8) + + +def _query_requests_identifier_fact(query: Any) -> bool: + return bool(_IDENTIFIER_REQUEST_RE.search(_clean_text(query))) + + +def _hit_text_for_identifier_match(hit: MemoryHit) -> str: + metadata = dict(hit.metadata or {}) + values = [ + hit.memory_id, + hit.category, + hit.value, + hit.relation, + hit.slot_key, + hit.source_kind, + *list(hit.anchors or []), + metadata.get("raw_text", ""), + metadata.get("source_turn_text", ""), + metadata.get("source_span", ""), + metadata.get("event_phrase", ""), + metadata.get("profile_value", ""), + metadata.get("target_status", ""), + ] + return " ".join(_clean_text(value) for value in values if _clean_text(value)) + + +def _copy_hit_with_identifier_boost(hit: MemoryHit, *, score: float, reasons: Sequence[str], matched_tokens: Sequence[str]) -> MemoryHit: + metadata = dict(hit.metadata or {}) + metadata.update( + { + "identifier_protected": True, + "identifier_protected_score": round(float(score), 6), + "identifier_protected_reasons": list(reasons)[:8], + "identifier_protected_matched_tokens": list(matched_tokens)[:8], + "identifier_protected_original_score": round(float(hit.score), 6), + } + ) + return MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=max(float(hit.score), round(1.0 + float(score), 6)), + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=metadata, + ) + + +def _identifier_protected_hits( + *, + query: str, + final_hits: Sequence[MemoryHit], + candidate_hits: Sequence[MemoryHit], + top_k: int, +) -> Dict[str, Any]: + identifier_tokens = _query_identifier_tokens(query) + identifier_request = _query_requests_identifier_fact(query) + if not identifier_tokens and not identifier_request: + return {"enabled": False, "hits": list(final_hits), "promoted_hits": [], "metadata": {"identifier_protected_enabled": False}} + + query_text = _clean_text(query) + pool: Dict[str, MemoryHit] = {} + for hit in list(final_hits) + list(candidate_hits): + key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" + if key and key not in pool: + pool[key] = hit + + scored: List[tuple[float, MemoryHit, List[str], List[str]]] = [] + for hit in pool.values(): + hit_text = _hit_text_for_identifier_match(hit) + hit_lower = hit_text.lower() + matched_tokens = [token for token in identifier_tokens if token.lower() in hit_lower] + reasons: List[str] = [] + score = 0.0 + if matched_tokens: + score += 12.0 + float(len(matched_tokens)) + reasons.append("exact_identifier_match") + if identifier_request: + if any(term in hit_lower for term in ("codename", "code name", "identifier", "alias", "project_codename")): + score += 8.0 + reasons.append("identifier_field_match") + if any(term in hit_text for term in ("代号", "编号", "标识符", "名称", "名字", "别名")): + score += 8.0 + reasons.append("identifier_cjk_field_match") + if "项目" in query_text and "项目" in hit_text: + score += 1.5 + reasons.append("project_term_match") + if score <= 0: + continue + scored.append((score, hit, matched_tokens, reasons)) + + scored.sort(key=lambda item: (item[0], float(item[1].score), int(item[1].turn_index or 0)), reverse=True) + promoted = [ + _copy_hit_with_identifier_boost(hit, score=score, reasons=reasons, matched_tokens=matched_tokens) + for score, hit, matched_tokens, reasons in scored[:2] + ] + if not promoted: + return { + "enabled": True, + "hits": list(final_hits), + "promoted_hits": [], + "metadata": { + "identifier_protected_enabled": True, + "identifier_query_tokens": identifier_tokens, + "identifier_request": bool(identifier_request), + "identifier_promoted_count": 0, + }, + } + + promoted_keys = {hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" for hit in promoted} + merged = list(promoted) + for hit in final_hits: + key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" + if key in promoted_keys: + continue + merged.append(hit) + if len(merged) >= max(1, int(top_k)): + break + return { + "enabled": True, + "hits": merged[: max(1, int(top_k))], + "promoted_hits": promoted, + "metadata": { + "identifier_protected_enabled": True, + "identifier_query_tokens": identifier_tokens, + "identifier_request": bool(identifier_request), + "identifier_promoted_count": len(promoted), + "identifier_promoted_ids": [hit.memory_id for hit in promoted], + }, + } + + +def _trim_prompt_text(value: Any, *, max_chars: int = 220) -> str: + text = _clean_text(value) + if len(text) <= max_chars: + return text + return f"{text[: max(0, max_chars - 3)].rstrip()}..." + + +def _prompt_hit_payload(hit: MemoryHit) -> Dict[str, Any]: + return { + "memory_id": hit.memory_id, + "slot_key": hit.slot_key, + "category": hit.category, + "value": _trim_prompt_text(hit.value), + "relation": hit.relation, + "anchors": [_trim_prompt_text(anchor, max_chars=80) for anchor in list(hit.anchors)[:4]], + "score": round(float(hit.score), 6), + "state": hit.state, + "turn_index": int(hit.turn_index), + "source_kind": hit.source_kind, + } + + +def _prompt_record_payload(record: SessionMemoryRecordV2) -> Dict[str, Any]: + metadata = dict(record.metadata or {}) + return { + "slot_key": record.slot_key, + "category": record.category, + "value": _trim_prompt_text(record.value), + "state": record.state, + "turn_index": int(record.turn_index), + "anchors": [_trim_prompt_text(anchor, max_chars=80) for anchor in list(record.anchor_concepts)[:4]], + "memory_role": _clean_text(metadata.get("memory_role", "")), + "authority": _clean_text(metadata.get("authority", "")), + } + + +def _graph_prompt_state_summary(graph: SessionMemoryGraphV2, retrieval: MemoryRetrieval) -> Dict[str, Any]: + active_slots: List[Dict[str, Any]] = [] + for slot_key, record_id in list(graph.slot_heads.items())[:_GRAPH_PROMPT_MAX_ACTIVE_SLOTS]: + record = graph.records_by_id.get(record_id) + if record is None: + continue + active_slots.append(_prompt_record_payload(record)) + top_hits = [_prompt_hit_payload(hit) for hit in list(retrieval.hits)[:_GRAPH_PROMPT_MAX_HITS]] + relation_preview = [ + { + "from": _trim_prompt_text(item.get("from", ""), max_chars=72), + "to": _trim_prompt_text(item.get("to", ""), max_chars=72), + "relation": _clean_text(item.get("relation", "")), + } + for item in list(retrieval.relations)[:_GRAPH_PROMPT_MAX_RELATIONS] + ] + summary = { + "records": len(graph.records_by_id), + "active_slots": len(graph.slot_heads), + "turn_index": int(graph.turn_index), + "noise_turn_count": int(graph.noise_turn_count), + "answer_support_events": len(graph.answer_support_log), + "top_hits": top_hits, + "active_slot_records": active_slots, + "relation_preview": relation_preview, + "context_truncated": False, + "truncation_reason": "", + } + truncated = False + truncation_reason = "" + while len(json.dumps(summary, ensure_ascii=False)) > _GRAPH_PROMPT_MAX_CHARS: + if len(summary["top_hits"]) > 1: + summary["top_hits"] = summary["top_hits"][:-1] + truncated = True + truncation_reason = "trimmed_top_hits" + continue + if len(summary["active_slot_records"]) > 1: + summary["active_slot_records"] = summary["active_slot_records"][:-1] + truncated = True + truncation_reason = "trimmed_active_slots" + continue + if len(summary["relation_preview"]) > 2: + summary["relation_preview"] = summary["relation_preview"][:-1] + truncated = True + truncation_reason = "trimmed_relations" + continue + break + summary["context_truncated"] = truncated + summary["truncation_reason"] = truncation_reason + return summary + + +def _state_stats(*, storage_bytes: int, retrieval_context_tokens: int, total_state_tokens: int, **extra: Any) -> Dict[str, Any]: + return { + **extra, + "storage_bytes": int(storage_bytes), + "context_token_estimate": int(retrieval_context_tokens), + "retrieval_context_token_estimate": int(retrieval_context_tokens), + "total_state_token_estimate": int(total_state_tokens), + } + + +def _relation_hit(hit: MemoryHit, *, weight_bias: float = 0.0) -> Dict[str, Any]: + if not hit.anchors: + return {} + anchor = hit.anchors[0] + if not anchor or anchor == hit.value: + return {} + return { + "from": anchor, + "to": hit.value, + "relation": hit.relation, + "weight": round(max(0.25, min(0.98, 0.42 + hit.score * 0.4 + weight_bias)), 6), + "source_kind": hit.source_kind, + "memory_id": hit.memory_id, + } + + +def _raw_hit_to_memory_hit(payload: Dict[str, Any]) -> MemoryHit: + metadata = dict(payload.get("metadata", {}) or {}) + if payload.get("supersedes"): + metadata["supersedes"] = list(payload.get("supersedes", []) or []) + slot_key = stable_slot_key( + category=str(payload.get("category", "")), + value=str(payload.get("value", "")), + anchors=[str(anchor) for anchor in payload.get("anchor_concepts", payload.get("anchors", [])) or [] if _clean_text(anchor)], + slot_key=str(payload.get("slot_key", metadata.get("slot", ""))), + relation=str(payload.get("relation", "related_to")), + metadata=metadata, + ) + return MemoryHit( + memory_id=str(payload.get("memory_id", "")), + category=str(payload.get("category", "")), + value=str(payload.get("value", "")), + relation=str(payload.get("relation", "related_to")), + anchors=[str(anchor) for anchor in payload.get("anchor_concepts", payload.get("anchors", [])) or [] if _clean_text(anchor)], + score=float(payload.get("score", payload.get("relevance", 0.0)) or 0.0), + source_kind=str(payload.get("source_kind", "memory")), + slot_key=slot_key, + state=str(payload.get("state", payload.get("metadata", {}).get("state", "active")) or "active"), + turn_index=int(payload.get("turn_index", 0) or 0), + metadata=metadata, + ) + + +def _restore_hit_scores(hits: List[MemoryHit], scored_lookup: Dict[str, MemoryHit]) -> List[MemoryHit]: + restored: List[MemoryHit] = [] + for hit in hits: + scored = scored_lookup.get(hit.memory_id) + if scored: + hit.score = max(float(hit.score), float(scored.score)) + if not hit.anchors and scored.anchors: + hit.anchors = list(scored.anchors) + restored.append(hit) + return restored + + +def _current_subject_query(query: str) -> bool: + lowered = _normalize(query) + return bool( + _public_query_subject(query) + and any(marker in lowered for marker in ("right now", "current", "currently", "active", "now", "当前", "现在")) + ) + + +def _record_subject_signatures(record: SessionMemoryRecordV2) -> set[str]: + metadata = dict(record.metadata or {}) + signatures = { + _normalize(metadata.get("subject_signature", "")).replace("-", "_"), + _public_subject_signature(metadata.get("subject", "")), + } + canonical_slot_key = _clean_text(metadata.get("canonical_slot_key", "") or record.slot_key) + if ".subject." in canonical_slot_key: + signatures.add(_public_subject_signature(canonical_slot_key.split(".subject.", 1)[-1])) + if ".subject." in record.slot_key: + signatures.add(_public_subject_signature(record.slot_key.split(".subject.", 1)[-1])) + signatures.discard("") + return signatures + + +def _current_subject_protected_hits( + *, + query: str, + graph: SessionMemoryGraphV2, + final_hits: Sequence[MemoryHit], + top_k: int, +) -> Dict[str, Any]: + if not _current_subject_query(query): + return { + "hits": list(final_hits), + "metadata": {"current_subject_resolver_enabled": False}, + } + subject = _public_query_subject(query) + subject_signature = _public_subject_signature(subject) + if not subject_signature: + return { + "hits": list(final_hits), + "metadata": { + "current_subject_resolver_enabled": True, + "current_subject_resolver_reason": "no_subject_signature", + }, + } + def _candidate_record(record: SessionMemoryRecordV2 | None) -> bool: + return bool( + record is not None + and _clean_text(record.source_kind).startswith("public_dialog") + and _normalize(record.category) != "question" + and subject_signature in _record_subject_signatures(record) + ) + + slot_head_candidates = [ + record + for slot_key, memory_id in graph.slot_heads.items() + for record in [graph.records_by_id.get(memory_id)] + if _candidate_record(record) + and subject_signature in _record_subject_signatures(record) + ] + candidates = slot_head_candidates or [ + record + for record in graph.records_by_id.values() + if record.state == "active" and _candidate_record(record) + ] + candidates.sort( + key=lambda record: ( + int( + _normalize((record.metadata or {}).get("target_status", "")) == "current" + or _normalize(record.relation) == "current_subject_value" + ), + int(record.turn_index), + float(record.confidence), + float(record.salience), + ), + reverse=True, + ) + promoted: List[MemoryHit] = [] + for index, record in enumerate(candidates[: max(1, min(2, int(top_k or 1)))], start=1): + metadata = dict(record.metadata or {}) + metadata.update( + { + "current_subject_resolver": True, + "current_subject_resolver_rank": index, + "current_subject_query_subject": subject, + "current_subject_query_signature": subject_signature, + "public_subject_match": True, + "public_subject_overlap": 1.0, + } + ) + promoted.append( + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 1.75), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + ) + if not promoted: + return { + "hits": list(final_hits), + "metadata": { + "current_subject_resolver_enabled": True, + "current_subject_resolver_reason": "no_active_subject_head", + "current_subject_query_subject": subject, + "current_subject_query_signature": subject_signature, + }, + } + promoted_ids = {hit.memory_id for hit in promoted} + merged_tail: List[MemoryHit] = [] + for hit in final_hits: + if hit.memory_id in promoted_ids: + continue + metadata = dict(hit.metadata or {}) + hit_state = _normalize(hit.state) + same_subject = subject_signature in { + _normalize(metadata.get("subject_signature", "")).replace("-", "_"), + _public_subject_signature(metadata.get("subject", "")), + _public_subject_signature(hit.slot_key.split(".subject.", 1)[-1]) if ".subject." in hit.slot_key else "", + _public_subject_signature(_clean_text(metadata.get("canonical_slot_key", "")).split(".subject.", 1)[-1]) + if ".subject." in _clean_text(metadata.get("canonical_slot_key", "")) + else "", + } + if same_subject and hit_state in {"superseded", "evidence", "historical", "stale", "false"}: + continue + merged_tail.append(hit) + merged = [*promoted, *merged_tail] + limit = max(int(top_k or 1), len(promoted)) + return { + "hits": merged[:limit], + "metadata": { + "current_subject_resolver_enabled": True, + "current_subject_resolver_reason": "promoted_active_subject_head", + "current_subject_query_subject": subject, + "current_subject_query_signature": subject_signature, + "current_subject_promoted_memory_ids": [hit.memory_id for hit in promoted], + }, + } + + +def _depth_chain_protected_hits( + *, + query: str, + graph: SessionMemoryGraphV2, + final_hits: Sequence[MemoryHit], + top_k: int, +) -> Dict[str, Any]: + seed_memory_ids = [hit.memory_id for hit in final_hits if hit.memory_id] + chain = graph.depth_chain_for_query( + query, + seed_memory_ids=seed_memory_ids, + top_k=max(3, min(8, int(top_k or 1))), + ) + if not chain.get("enabled") or not chain.get("nodes"): + return { + "hits": list(final_hits), + "metadata": { + "memory_chain_enabled": bool(chain.get("enabled", False)), + "memory_chain_reason": _clean_text(chain.get("reason", "")), + "memory_chain_node_count": 0, + "memory_chain_edge_count": 0, + "memory_chain": chain, + }, + } + seen = {hit.memory_id for hit in final_hits if hit.memory_id} + chain_hits: List[MemoryHit] = [] + for rank, node in enumerate(list(chain.get("nodes", []) or []), start=1): + if not isinstance(node, Mapping): + continue + memory_id = _clean_text(node.get("memory_id", "")) + if not memory_id or memory_id in seen: + continue + payload = dict(node) + metadata = dict(payload.get("metadata", {}) or {}) + metadata.update( + { + "memory_chain_protected": True, + "memory_chain_rank": int(rank), + "memory_chain_subject_signature": _clean_text(chain.get("subject_signature", "")), + "memory_chain_depth_layer": _clean_text(metadata.get("depth_layer", "")) or "core_view", + } + ) + payload["metadata"] = metadata + payload["score"] = max(float(payload.get("score", 0.0) or 0.0), 0.62 - (rank * 0.01)) + chain_hits.append(_raw_hit_to_memory_hit(payload)) + seen.add(memory_id) + limit = max(int(top_k or 1), min(12, int(top_k or 1) + max(0, len(chain_hits)))) + merged = [*list(final_hits), *chain_hits] + return { + "hits": merged[:limit], + "metadata": { + "memory_chain_enabled": True, + "memory_chain_reason": _clean_text(chain.get("reason", "")), + "memory_chain_subject_signature": _clean_text(chain.get("subject_signature", "")), + "memory_chain_node_count": int(chain.get("node_count", 0) or 0), + "memory_chain_edge_count": int(chain.get("edge_count", 0) or 0), + "memory_chain_depth_layers": list(chain.get("depth_layers", []) or []), + "memory_chain": chain, + }, + } + + +def _is_public_dialog_hit(hit: MemoryHit) -> bool: + return _clean_text(hit.source_kind).startswith("public_dialog") + + +def _normalized_runtime_signature(prefix: str, value: str) -> str: + normalized = _normalize(value) + if not normalized: + return "" + return f"{prefix}{normalized.replace('|', '_').replace(':', '_')}" + + +def _runtime_event_key(hit: MemoryHit) -> str: + metadata = dict(hit.metadata or {}) + explicit = _clean_text(metadata.get("event_id", "")) + if explicit: + return explicit + dia_id = _clean_text(metadata.get("dia_id", "")) + if dia_id: + return f"event::{dia_id}" + if not _is_public_dialog_hit(hit): + state_signature = _clean_text(metadata.get("state_signature", "")) + if state_signature: + return _normalized_runtime_signature("event::state::", state_signature) + memory_signature = _clean_text(metadata.get("memory_signature", "")) + if memory_signature: + return _normalized_runtime_signature("event::memory::", memory_signature) + slot_root = _public_slot_root(_clean_text(hit.slot_key)) + if slot_root: + return slot_root + return _clean_text(hit.memory_id) + + +def _runtime_event_turn_index_from_id(event_id: str) -> int: + text = _clean_text(event_id) + if not text: + return 0 + match = re.search(r"(?::|_)(\d+)$", text) + if match: + return int(match.group(1)) + matches = re.findall(r"\d+", text) + return int(matches[-1]) if matches else 0 + + +def _representative_event_hit(group_hits: Sequence[MemoryHit], *, query: str = "") -> MemoryHit | None: + semantic_source_kinds = { + "public_dialog_fact", + "public_dialog_preference", + "public_dialog_goal", + "public_dialog_constraint", + "public_dialog_status", + "public_dialog_profile", + "replacement_memory", + "session_memory", + } + query_tokens = set(_path_utility_tokens(query)) + + def rank(hit: MemoryHit) -> tuple[bool, float, bool, bool, bool, bool, float]: + metadata = dict(hit.metadata or {}) + source_kind = _clean_text(hit.source_kind) + text_parts: List[Any] = [hit.value, hit.slot_key, hit.category] + if source_kind != "public_dialog_turn": + text_parts.extend([metadata.get("source_turn_text", ""), metadata.get("raw_text", "")]) + text = " ".join( + _clean_text(item) + for item in text_parts + if _clean_text(item) + ) + hit_tokens = set(_path_utility_tokens(text)) + overlap = len(query_tokens & hit_tokens) if query_tokens else 0 + has_number = bool(re.search(r"\b\d+\b", text)) + is_semantic = source_kind in semantic_source_kinds or ( + source_kind != "public_dialog_turn" and bool(_clean_text(metadata.get("memory_writer_role", ""))) + ) + direct_semantic_answer = bool(is_semantic and has_number and overlap >= 2) + query_score = float(overlap) + (0.75 if has_number and overlap else 0.0) + return ( + direct_semantic_answer, + query_score, + is_semantic, + source_kind == "public_dialog_event", + source_kind == "public_dialog_turn", + source_kind in {"replacement_memory", "session_memory"}, + float(hit.score), + ) + + ordered = sorted( + list(group_hits), + key=rank, + reverse=True, + ) + return ordered[0] if ordered else None + + +def _event_record_hits_from_graph(graph: SessionMemoryGraphV2, event_id: str) -> List[MemoryHit]: + normalized_event_id = _clean_text(event_id) + if not normalized_event_id: + return [] + hits: List[MemoryHit] = [] + for record in graph.records_by_id.values(): + metadata = dict(record.metadata or {}) + if _clean_text(metadata.get("event_id", "")) != normalized_event_id: + continue + state = _normalize(record.state) + if state not in {"active", "parallel_active", "evidence"}: + continue + hits.append( + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.01), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + ) + return hits + +def _hit_matches_path_support(path_type: str, hit: MemoryHit) -> bool: + metadata = dict(hit.metadata or {}) + source_kind = _clean_text(hit.source_kind) + category = _clean_text(hit.category) + relation = _clean_text(hit.relation) + if path_type == "speaker_event_time": + return bool( + source_kind == "public_dialog_time" + or _clean_text(metadata.get("resolved_time_value", "")) + or _clean_text(metadata.get("resolved_date", "")) + or _clean_text(metadata.get("time_value", "")) + or _clean_text(metadata.get("time_display_value", "")) + or _clean_text(metadata.get("time_granularity", "")) + or relation == "event_date" + or category in {"time", "event_time"} + ) + if path_type == "speaker_event_profile": + semantic_slot = _clean_text(metadata.get("semantic_slot", "")) or _clean_text(metadata.get("profile_type", "")) + return bool( + source_kind == "public_dialog_profile" + or semantic_slot in {"identity", "research_topic", "education", "occupation", "profile"} + or _clean_text(metadata.get("profile_value", "")) + or category == "profile" + ) + if path_type == "speaker_event_status": + return bool( + _clean_text(metadata.get("target_status", "")) + or category in {"status", "stage_state"} + or relation == "status_of" + ) + if path_type == "speaker_event_source_turn": + return bool( + source_kind in {"public_dialog_turn", "public_dialog_text", "public_dialog_auxiliary_evidence"} + or _clean_text(metadata.get("raw_text", "")) + or _clean_text(metadata.get("origin_query", "")) + or _clean_text(metadata.get("source_turn_text", "")) + or not _is_public_dialog_hit(hit) + ) + return False + + +def _support_hit_for_path(path_type: str, group_hits: Sequence[MemoryHit]) -> MemoryHit | None: + matching_hits = [hit for hit in group_hits if _hit_matches_path_support(path_type, hit)] + if matching_hits: + matching_hits.sort(key=lambda item: (float(item.score), int(item.turn_index)), reverse=True) + return matching_hits[0] + representative = _representative_event_hit(group_hits) + return representative + + +def _path_support_node_id(path: Dict[str, Any]) -> str: + node_ids = list(path.get("node_ids", []) or []) + if len(node_ids) < 3: + return "" + return _clean_text(node_ids[2]) + + +def _event_ids_from_hits(hits: Sequence[MemoryHit]) -> List[str]: + return _dedupe( + _clean_text(dict(hit.metadata or {}).get("event_id", "")) + for hit in hits + if _clean_text(dict(hit.metadata or {}).get("event_id", "")) + ) + + +def _dia_ids_from_hits(hits: Sequence[MemoryHit]) -> List[str]: + return _dedupe( + _clean_text(dict(hit.metadata or {}).get("dia_id", "")) + for hit in hits + if _clean_text(dict(hit.metadata or {}).get("dia_id", "")) + ) + + +def _final_hit_role_priority(hit: MemoryHit) -> int: + metadata = dict(hit.metadata or {}) + source_kind = _clean_text(hit.source_kind) + semantic_source_kinds = { + "public_dialog_fact", + "public_dialog_preference", + "public_dialog_goal", + "public_dialog_constraint", + "public_dialog_status", + "public_dialog_profile", + "public_dialog_profile_cluster", + "replacement_memory", + "session_memory", + } + if source_kind in semantic_source_kinds or ( + source_kind != "public_dialog_turn" and bool(_clean_text(metadata.get("memory_writer_role", ""))) + ): + return 0 + if bool(metadata.get("profile_first_source_support")): + return 0 + role = _clean_text(metadata.get("evidence_snippet_role", "")) + if role == "selected_path_support": + return 1 + if role == "selected_event_representative": + return 2 + if role == "selected_path_event": + return 3 + return 4 + + +def _coverage_preserving_final_hits( + final_hits: Sequence[MemoryHit], + *, + selected_event_ids: Sequence[str], + top_k: int, +) -> List[MemoryHit]: + """Keep selected-event coverage before filling the remaining prompt budget. + + Learned selection can emit both path-support and event-representative snippets + for the same event. A pure score sort can then drop another selected event at + the top-k boundary, which hides recall/rerank successes from the answer head. + """ + + budget = max(1, int(top_k or 1)) + hits = list(final_hits) + selected_order = [ + _clean_text(event_id) + for event_id in selected_event_ids + if _clean_text(event_id) + ][:budget] + if not selected_order: + return sorted(hits, key=lambda item: float(item.score), reverse=True)[:budget] + hits_by_event: Dict[str, List[MemoryHit]] = {} + for hit in hits: + event_id = _clean_text(dict(hit.metadata or {}).get("event_id", "")) + if event_id: + hits_by_event.setdefault(event_id, []).append(hit) + selected: List[MemoryHit] = [] + used_memory_ids = set() + for event_id in selected_order: + candidates = [hit for hit in hits_by_event.get(event_id, []) if hit.memory_id not in used_memory_ids] + if not candidates: + continue + candidates.sort(key=lambda item: (_final_hit_role_priority(item), -float(item.score), item.memory_id)) + chosen = candidates[0] + selected.append(chosen) + used_memory_ids.add(chosen.memory_id) + if len(selected) >= budget: + return selected + remaining = [ + hit + for hit in hits + if hit.memory_id not in used_memory_ids + ] + remaining.sort(key=lambda item: (-float(item.score), _final_hit_role_priority(item), item.memory_id)) + for hit in remaining: + selected.append(hit) + if len(selected) >= budget: + break + return selected[:budget] + + +def _dominant_answer_type(question_analysis: Dict[str, Any], answer_type_scores: Dict[str, Any]) -> str: + normalized_scores = { + _clean_text(answer_type): float(value or 0.0) + for answer_type, value in dict(answer_type_scores or {}).items() + if _clean_text(answer_type) + } + if bool(question_analysis.get("is_temporal", False)) and normalized_scores.get("time", 0.0) >= (normalized_scores.get("abstain", 0.0) - 0.05): + return "time" + if not bool(question_analysis.get("is_temporal", False)) and normalized_scores: + non_time_scores = { + answer_type: score + for answer_type, score in normalized_scores.items() + if answer_type not in {"time", "abstain"} + } + if non_time_scores: + return max(non_time_scores.items(), key=lambda item: (float(item[1]), item[0]))[0] + if not normalized_scores: + return "time" if bool(question_analysis.get("is_temporal", False)) else "event_text" + return max(normalized_scores.items(), key=lambda item: (float(item[1]), item[0]))[0] + + +def _answer_type_preferred_path_types(question_analysis: Dict[str, Any], answer_type_scores: Dict[str, Any]) -> List[str]: + dominant_answer_type = _dominant_answer_type(question_analysis, answer_type_scores) + if dominant_answer_type == "time" or bool(question_analysis.get("is_temporal", False)): + return ["speaker_event_time", "speaker_event_source_turn", "speaker_event_status", "speaker_event_profile"] + if dominant_answer_type == "profile": + return ["speaker_event_profile", "speaker_event_source_turn", "speaker_event_status", "speaker_event_time"] + if dominant_answer_type == "multi_evidence": + return ["speaker_event_source_turn", "speaker_event_time", "speaker_event_profile", "speaker_event_status"] + return ["speaker_event_source_turn", "speaker_event_time", "speaker_event_profile", "speaker_event_status"] + + +def _reconciled_focused_answer_type( + question_analysis: Dict[str, Any], + answer_type_scores: Dict[str, Any], + model_answer_type: str, +) -> str: + model_type = _clean_text(model_answer_type) + dominant_type = _dominant_answer_type(question_analysis, answer_type_scores) + if bool(question_analysis.get("is_temporal", False)) and model_type not in {"", "time", "abstain"}: + return "time" + if model_type == "time" and not bool(question_analysis.get("is_temporal", False)): + return dominant_type if dominant_type != "time" else "event_text" + return model_type or dominant_type + + +def _path_type_is_focus_compatible(path_type: str, *, focused_answer_type: str, question_analysis: Dict[str, Any]) -> bool: + normalized_path_type = _clean_text(path_type) + normalized_answer_type = _clean_text(focused_answer_type) + if normalized_answer_type == "time" or bool(question_analysis.get("is_temporal", False)): + return normalized_path_type in {"speaker_event_time", "speaker_event_source_turn"} + if normalized_answer_type == "profile": + return normalized_path_type in {"speaker_event_profile", "speaker_event_source_turn"} + return True + + +def _rank_focus_compatible_path_ids( + *, + runtime_paths: Mapping[str, Dict[str, Any]], + selected_event_ids: Sequence[str], + path_scores: Mapping[str, Any], + event_scores: Mapping[str, Any], + temporal_scores: Mapping[str, Any], + question_analysis: Dict[str, Any], + answer_type_scores: Dict[str, Any], + focused_answer_type: str, +) -> List[str]: + selected_event_id_set = {_clean_text(event_id) for event_id in selected_event_ids if _clean_text(event_id)} + preferred_types = _answer_type_preferred_path_types(question_analysis, answer_type_scores) + if _clean_text(focused_answer_type) == "profile" and "speaker_event_profile" not in preferred_types: + preferred_types = ["speaker_event_profile", *preferred_types] + if (_clean_text(focused_answer_type) == "time" or bool(question_analysis.get("is_temporal", False))) and "speaker_event_time" not in preferred_types: + preferred_types = ["speaker_event_time", "speaker_event_source_turn", *preferred_types] + ranked: List[str] = [] + seen = set() + for preferred_type in preferred_types: + candidates: List[tuple[str, float]] = [] + for path_id, path in runtime_paths.items(): + if _clean_text(path.get("type", "")) != preferred_type: + continue + event_id = _clean_text(path.get("event_id", "")) + if selected_event_id_set and event_id not in selected_event_id_set: + continue + if not _path_type_is_focus_compatible(preferred_type, focused_answer_type=focused_answer_type, question_analysis=question_analysis): + continue + support_node_id = _path_support_node_id(path) + score = ( + float(event_scores.get(event_id, 0.0) or 0.0) + + (0.20 * float(path_scores.get(path_id, 0.0) or 0.0)) + + (0.15 * float(temporal_scores.get(support_node_id, 0.0) or 0.0)) + ) + candidates.append((path_id, score)) + for path_id, _ in sorted(candidates, key=lambda item: (-float(item[1]), item[0])): + if path_id in seen: + continue + seen.add(path_id) + ranked.append(path_id) + if ranked and (_clean_text(focused_answer_type) == "time" or bool(question_analysis.get("is_temporal", False))): + if preferred_type in {"speaker_event_time", "speaker_event_source_turn"}: + break + return ranked + + +def _repair_selected_paths_for_focus( + selected_path_ids: Sequence[str], + *, + runtime_paths: Mapping[str, Dict[str, Any]], + selected_event_ids: Sequence[str], + path_scores: Mapping[str, Any], + event_scores: Mapping[str, Any], + temporal_scores: Mapping[str, Any], + question_analysis: Dict[str, Any], + answer_type_scores: Dict[str, Any], + focused_answer_type: str, + limit: int, +) -> tuple[List[str], bool, str]: + normalized_selected = [_clean_text(path_id) for path_id in selected_path_ids if _clean_text(path_id)] + if not normalized_selected: + return [], False, "" + incompatible = [ + path_id + for path_id in normalized_selected + if not _path_type_is_focus_compatible( + _clean_text(runtime_paths.get(path_id, {}).get("type", "")), + focused_answer_type=focused_answer_type, + question_analysis=question_analysis, + ) + ] + if not incompatible: + return normalized_selected, False, "" + compatible_ranked = _rank_focus_compatible_path_ids( + runtime_paths=runtime_paths, + selected_event_ids=selected_event_ids, + path_scores=path_scores, + event_scores=event_scores, + temporal_scores=temporal_scores, + question_analysis=question_analysis, + answer_type_scores=answer_type_scores, + focused_answer_type=focused_answer_type, + ) + if not compatible_ranked: + return normalized_selected, False, "" + repaired = _dedupe( + [ + *[path_id for path_id in normalized_selected if path_id not in incompatible], + *compatible_ranked, + ] + )[: max(1, limit)] + if repaired == normalized_selected: + return normalized_selected, False, "" + return repaired, True, "replaced_focus_incompatible_model_paths" + + +def _runtime_node_by_id(runtime_graph: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]: + return { + _clean_text(node.get("id", "")): dict(node) + for node in list(runtime_graph.get("nodes", []) or []) + if _clean_text(node.get("id", "")) + } + + +def _runtime_event_subject_signature(runtime_nodes: Mapping[str, Dict[str, Any]], event_id: str) -> str: + node = dict(runtime_nodes.get(_clean_text(event_id), {}) or {}) + metadata = dict(node.get("metadata", {}) or {}) + return _clean_text(node.get("subject_signature", "")) or _clean_text(metadata.get("subject_signature", "")) + + +def _path_utility_candidate_text( + path: Mapping[str, Any], + *, + runtime_nodes: Mapping[str, Dict[str, Any]], + grouped_hits: Mapping[str, Sequence[MemoryHit]], +) -> str: + event_id = _clean_text(path.get("event_id", "")) + support_node_id = _path_support_node_id(dict(path)) + node_texts = [ + _clean_text(dict(runtime_nodes.get(event_id, {}) or {}).get("text", "")), + _clean_text(dict(runtime_nodes.get(support_node_id, {}) or {}).get("text", "")), + ] + path_type = _clean_text(path.get("type", "")) + support_hit = _support_hit_for_path(path_type, grouped_hits.get(event_id, [])) + event_hit = _representative_event_hit(grouped_hits.get(event_id, [])) + hit_texts = [ + _clean_text(support_hit.value if support_hit is not None else ""), + _clean_text(event_hit.value if event_hit is not None else ""), + _runtime_source_turn_text(support_hit or event_hit, speaker=""), + ] + return " ".join(_dedupe([*node_texts, *hit_texts], max_items=6)) + + +def _path_utility_gate( + candidate_path_ids: Sequence[str], + *, + query: str, + runtime_graph: Mapping[str, Any], + runtime_paths: Mapping[str, Dict[str, Any]], + grouped_hits: Mapping[str, Sequence[MemoryHit]], + selected_path_ids: Sequence[str], + selected_event_ids_from_model: Sequence[str], + path_scores: Mapping[str, Any], + path_tunnel_support_scores: Mapping[str, Any], + question_analysis: Dict[str, Any], + focused_answer_type: str, + score_threshold: float, + limit: int, +) -> Dict[str, Any]: + runtime_nodes = _runtime_node_by_id(runtime_graph) + query_tokens = set(_path_utility_tokens(query)) + anchor_event_ids = _dedupe( + [ + *[ + _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + for path_id in selected_path_ids + if _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + ], + *[_clean_text(event_id) for event_id in selected_event_ids_from_model if _clean_text(event_id)], + ] + ) + anchor_subject_signatures = { + signature + for signature in ( + _runtime_event_subject_signature(runtime_nodes, event_id) + for event_id in anchor_event_ids + ) + if signature + } + direct_path_ids: List[str] = [] + contrast_path_ids: List[str] = [] + latent_path_ids: List[str] = [] + noise_path_ids: List[str] = [] + utility_scores: Dict[str, float] = {} + utility_roles: Dict[str, str] = {} + utility_reasons: Dict[str, str] = {} + utility_overlap_tokens: Dict[str, List[str]] = {} + for path_id in _dedupe(candidate_path_ids): + path = dict(runtime_paths.get(path_id, {}) or {}) + if not path: + continue + path_type = _clean_text(path.get("type", "")) + event_id = _clean_text(path.get("event_id", "")) + candidate_text = _path_utility_candidate_text( + path, + runtime_nodes=runtime_nodes, + grouped_hits=grouped_hits, + ) + candidate_tokens = set(_path_utility_tokens(candidate_text)) + overlap_tokens = sorted(query_tokens & candidate_tokens) + overlap_ratio = float(len(overlap_tokens)) / float(max(1, min(len(query_tokens), len(candidate_tokens)))) + support_score = float(path_tunnel_support_scores.get(path_id, 0.0) or 0.0) + decision_score = float(path_scores.get(path_id, support_score) or 0.0) + path_score = max(support_score, decision_score) + subject_signature = _runtime_event_subject_signature(runtime_nodes, event_id) + same_subject_chain = bool(subject_signature and subject_signature in anchor_subject_signatures) + focus_compatible = _path_type_is_focus_compatible( + path_type, + focused_answer_type=focused_answer_type, + question_analysis=question_analysis, + ) + utility_score = path_score + (0.20 * overlap_ratio) + (0.04 if same_subject_chain else 0.0) + utility_scores[path_id] = round(float(utility_score), 6) + utility_overlap_tokens[path_id] = overlap_tokens[:12] + if not focus_compatible: + role = "drift_noise" + reason = "focus_incompatible" + elif overlap_ratio >= 0.18 and path_score >= score_threshold: + role = "direct_support" + reason = "query_overlap_and_tunnel_score" + elif same_subject_chain and path_score >= score_threshold: + role = "contrast_support" + reason = "same_subject_deep_chain" + elif path_score >= score_threshold: + role = "latent_context" + reason = "tunnel_score_without_current_turn_utility" + else: + role = "drift_noise" + reason = "below_utility_threshold" + utility_roles[path_id] = role + utility_reasons[path_id] = reason + if role == "direct_support": + direct_path_ids.append(path_id) + elif role == "contrast_support": + contrast_path_ids.append(path_id) + elif role == "latent_context": + latent_path_ids.append(path_id) + else: + noise_path_ids.append(path_id) + injected_path_ids = _dedupe([*direct_path_ids, *contrast_path_ids], max_items=max(0, int(limit))) + overflow_latent_path_ids = [ + path_id + for path_id in [*direct_path_ids, *contrast_path_ids] + if path_id not in set(injected_path_ids) + ] + latent_path_ids = _dedupe([*latent_path_ids, *overflow_latent_path_ids]) + return { + "enabled": True, + "candidate_path_ids": list(_dedupe(candidate_path_ids)), + "direct_support_path_ids": list(direct_path_ids), + "contrast_support_path_ids": list(contrast_path_ids), + "latent_context_path_ids": list(latent_path_ids), + "drift_noise_path_ids": list(noise_path_ids), + "injected_path_ids": list(injected_path_ids), + "roles": dict(utility_roles), + "reasons": dict(utility_reasons), + "scores": dict(utility_scores), + "overlap_tokens": dict(utility_overlap_tokens), + "anchor_event_ids": list(anchor_event_ids), + "anchor_subject_signatures": sorted(anchor_subject_signatures), + } + + +def _calibrated_path_score( + *, + path: Dict[str, Any], + base_score: float, + temporal_scores: Dict[str, Any], + question_analysis: Dict[str, Any], + answer_type_scores: Dict[str, Any], +) -> float: + normalized_path_type = _clean_text(path.get("type", "")) + support_node_id = _path_support_node_id(path) + temporal_score = float(temporal_scores.get(support_node_id, 0.0) or 0.0) + normalized_answer_scores = { + _clean_text(answer_type): float(value or 0.0) + for answer_type, value in dict(answer_type_scores or {}).items() + if _clean_text(answer_type) + } + dominant_answer_type = _dominant_answer_type(question_analysis, normalized_answer_scores) + calibrated = float(base_score) + if dominant_answer_type == "time" or bool(question_analysis.get("is_temporal", False)): + if normalized_path_type == "speaker_event_time": + calibrated += (0.30 * temporal_score) + (0.12 * normalized_answer_scores.get("time", 0.0)) + elif normalized_path_type == "speaker_event_source_turn": + calibrated -= 0.08 + (0.04 * normalized_answer_scores.get("time", 0.0)) + elif normalized_path_type == "speaker_event_status": + calibrated -= 0.12 + elif normalized_path_type == "speaker_event_profile": + calibrated -= 0.22 + (0.06 * normalized_answer_scores.get("time", 0.0)) + elif dominant_answer_type == "profile": + if normalized_path_type == "speaker_event_profile": + calibrated += 0.12 * normalized_answer_scores.get("profile", 0.0) + elif normalized_path_type != "speaker_event_source_turn": + calibrated -= 0.08 + else: + if normalized_path_type == "speaker_event_source_turn": + calibrated += 0.08 * max( + normalized_answer_scores.get("event_text", 0.0), + normalized_answer_scores.get("multi_evidence", 0.0), + ) + elif normalized_path_type == "speaker_event_profile": + calibrated -= 0.04 + return calibrated + + +def _group_metadata_value( + group_hits: Sequence[MemoryHit], + key: str, + *, + source_kinds: Sequence[str] = (), +) -> str: + source_kind_set = {_clean_text(item) for item in list(source_kinds) if _clean_text(item)} + candidates = [ + hit + for hit in group_hits + if not source_kind_set or _clean_text(hit.source_kind) in source_kind_set + ] + candidates.sort(key=lambda item: (float(item.score), int(item.turn_index)), reverse=True) + for hit in candidates: + metadata = dict(hit.metadata or {}) + value = _clean_text(metadata.get(key, "")) + if value: + return value + return "" + + +def _runtime_event_sequence_key(session_name: str, turn_index: int, event_id: str) -> tuple[Any, ...]: + normalized_session = _clean_text(session_name) + session_number_match = re.search(r"(\d+)$", normalized_session) + if session_number_match: + return (0, int(session_number_match.group(1)), int(turn_index), _clean_text(event_id)) + return (1, normalized_session, int(turn_index), _clean_text(event_id)) + + +def _runtime_source_turn_text(hit: MemoryHit | None, *, speaker: str) -> str: + if hit is None: + return "" + metadata = dict(hit.metadata or {}) + source_turn_text = _clean_text(metadata.get("source_turn_text", "")) + if source_turn_text: + return source_turn_text + raw_text = _clean_text(metadata.get("raw_text", "")) + if raw_text: + auxiliary_text = _clean_text(metadata.get("auxiliary_evidence_text", "")) + if auxiliary_text and auxiliary_text.lower() not in raw_text.lower(): + return f"{raw_text}\nAuxiliary evidence: {auxiliary_text}" + return raw_text + origin_query = _clean_text(metadata.get("origin_query", "")) + if origin_query: + return origin_query + text = _clean_text(hit.value) + if _clean_text(hit.source_kind) == "public_dialog_turn": + text = re.sub(r"^\[[^\]]+\]\s*", "", text) + if _clean_text(speaker): + text = re.sub(rf"^{re.escape(_clean_text(speaker))}\s*:\s*", "", text, flags=re.IGNORECASE) + text = re.sub(r"^[A-Za-z][A-Za-z0-9_' -]{0,40}:\s*", "", text) + return text + + +def _runtime_event_signature( + *, + group_hits: Sequence[MemoryHit], + representative: MemoryHit, + speaker: str, + semantic_slot: str, + source_turn_hit: MemoryHit | None, +) -> str: + existing = _group_metadata_value(group_hits, "event_signature") or _clean_text(dict(representative.metadata or {}).get("event_signature", "")) + if existing: + return existing + event_phrase = _group_metadata_value(group_hits, "event_phrase") or _clean_text(dict(representative.metadata or {}).get("event_phrase", "")) + source_turn_text = _runtime_source_turn_text(source_turn_hit, speaker=speaker) + base_text = event_phrase or _clean_text(representative.value) or source_turn_text + if not base_text: + return "" + return compute_public_event_signature( + base_text, + speaker=_clean_text(speaker), + semantic_slot=_clean_text(semantic_slot), + ) or _clean_text(base_text) + + +def _build_runtime_graph_from_hits(query: str, hits: Sequence[MemoryHit]) -> Dict[str, Any]: + nodes: List[Dict[str, Any]] = [] + edges: List[Dict[str, Any]] = [] + paths: List[Dict[str, Any]] = [] + node_ids = set() + grouped_hits: Dict[str, List[MemoryHit]] = {} + ordered_events: List[tuple[Any, ...]] = [] + event_typed_metadata_by_id: Dict[str, Dict[str, Any]] = {} + typed_tunnel_edges: List[Dict[str, Any]] = [] + + def add_node(node: Dict[str, Any]) -> None: + node_id = _clean_text(node.get("id", "")) + if not node_id or node_id in node_ids: + return + node_ids.add(node_id) + nodes.append(node) + + for hit in hits: + event_id = _runtime_event_key(hit) + grouped_hits.setdefault(event_id, []).append(hit) + + for event_id, group_hits in grouped_hits.items(): + representative = _representative_event_hit(group_hits) + if representative is None: + continue + metadata = dict(representative.metadata or {}) + event_time_hit = _support_hit_for_path("speaker_event_time", group_hits) + event_profile_hit = _support_hit_for_path("speaker_event_profile", group_hits) + source_turn_hit = _support_hit_for_path("speaker_event_source_turn", group_hits) + event_time_display_value = ( + _group_metadata_value(group_hits, "time_display_value", source_kinds=("public_dialog_time",)) + or _group_metadata_value(group_hits, "resolved_date") + or _group_metadata_value(group_hits, "time_display_value") + ) + event_time_value = ( + _group_metadata_value(group_hits, "resolved_time_value", source_kinds=("public_dialog_time",)) + or _group_metadata_value(group_hits, "resolved_time_value") + or _group_metadata_value(group_hits, "time_value") + ) + event_time_granularity = ( + _group_metadata_value(group_hits, "time_granularity", source_kinds=("public_dialog_time",)) + or _group_metadata_value(group_hits, "time_granularity") + ) + event_profile_type = ( + _group_metadata_value(group_hits, "profile_type", source_kinds=("public_dialog_profile",)) + or _group_metadata_value(group_hits, "semantic_slot", source_kinds=("public_dialog_profile",)) + or _group_metadata_value(group_hits, "profile_type") + or _group_metadata_value(group_hits, "semantic_slot") + ) + profile_hit = _support_hit_for_path("speaker_event_profile", group_hits) + event_profile_value = ( + _group_metadata_value(group_hits, "profile_value") + or _clean_text(profile_hit.value if profile_hit is not None else "") + or (_clean_text(representative.value) if event_profile_type else "") + ) + event_target_status = _group_metadata_value(group_hits, "target_status") + depth_layer = ( + _clean_text(metadata.get("depth_layer", "")) + or _group_metadata_value(group_hits, "depth_layer") + or _clean_text(metadata.get("memory_chain_depth_layer", "")) + or _group_metadata_value(group_hits, "memory_chain_depth_layer") + ) + subject_signature = ( + _clean_text(metadata.get("subject_signature", "")) + or _group_metadata_value(group_hits, "subject_signature") + or _clean_text(metadata.get("memory_chain_subject_signature", "")) + or _group_metadata_value(group_hits, "memory_chain_subject_signature") + ) + session_name = ( + _group_metadata_value(group_hits, "session_name") + or _group_metadata_value(group_hits, "session_key") + or _group_metadata_value(group_hits, "scope_id") + or "runtime_session" + ) + speaker = ( + _clean_text(metadata.get("speaker", "")) + or _group_metadata_value(group_hits, "speaker") + or _clean_text(metadata.get("subject_signature", "")) + or (_clean_text(representative.anchors[0]) if representative.anchors else "") + or "speaker" + ) + event_turn_index = int(getattr(representative, "turn_index", 0) or 0) + semantic_slot = ( + _group_metadata_value(group_hits, "semantic_slot") + or _clean_text(metadata.get("semantic_slot", "")) + or _clean_text(metadata.get("profile_type", "")) + or ("profile" if event_profile_type else _clean_text(representative.category)) + or "event" + ) + teacher_fields = { + "event_phrase": _clean_text(metadata.get("event_phrase", "")) or _clean_text(representative.value), + "semantic_slot": semantic_slot, + "target_status": event_target_status, + "time_expression_span": event_time_display_value, + "time_granularity": event_time_granularity, + "profile_type": event_profile_type, + } + base_event_signature = _runtime_event_signature( + group_hits=group_hits, + representative=representative, + speaker=speaker, + semantic_slot=event_profile_type or teacher_fields["semantic_slot"], + source_turn_hit=source_turn_hit, + ) + typed_tunnel_metadata = merge_typed_metadata([metadata, *[dict(hit.metadata or {}) for hit in group_hits]]) + event_typed_metadata_by_id[event_id] = typed_tunnel_metadata + typed_signature = typed_tunnel_signature_text(typed_tunnel_metadata) + event_text = _clean_text(metadata.get("event_phrase", "")) or representative.value + depth_prefix = " ".join( + item + for item in ( + f"subject {subject_signature.replace('_', ' ')}" if subject_signature else "", + f"depth layer {depth_layer.replace('_', ' ')}" if depth_layer else "", + ) + if item + ) + runtime_event_text = f"{depth_prefix} {event_text}".strip() if depth_prefix else event_text + event_signature = ( + compute_public_event_signature( + runtime_event_text, + speaker=_clean_text(speaker), + semantic_slot=_clean_text(event_profile_type or teacher_fields["semantic_slot"]), + ) + or base_event_signature + ) + if typed_signature and typed_signature not in event_signature: + event_signature = f"{event_signature} {typed_signature}".strip() + speaker_node_id = f"{event_id}:speaker:{_normalize(speaker).replace(' ', '_') or 'speaker'}" + add_node({"id": speaker_node_id, "type": "speaker", "text": speaker, "metadata": {"speaker": speaker}}) + add_node( + { + "id": event_id, + "type": "event", + "text": runtime_event_text, + "speaker": speaker, + "turn_index": event_turn_index, + "session_name": session_name, + "dia_id": _clean_text(metadata.get("dia_id", "")), + "event_signature": event_signature, + "slot_key": _clean_text(representative.slot_key), + "state_signature": _clean_text(metadata.get("state_signature", "")), + "memory_signature": _clean_text(metadata.get("memory_signature", "")), + "target_status": event_target_status, + "time_granularity": event_time_granularity, + "time_value": event_time_value, + "time_display_value": event_time_display_value, + "profile_type": event_profile_type, + "profile_value": event_profile_value, + "depth_layer": depth_layer, + "subject_signature": subject_signature, + "tmcra_node_tags": list(typed_tunnel_metadata.get("tmcra_node_tags", []) or []), + "tmcra_path_tags": list(typed_tunnel_metadata.get("tmcra_path_tags", []) or []), + "tmcra_tunnel_roles": list(typed_tunnel_metadata.get("tmcra_tunnel_roles", []) or []), + "tmcra_tunnel_group_key": _clean_text(typed_tunnel_metadata.get("tmcra_tunnel_group_key", "")), + "teacher_fields": teacher_fields, + "metadata": { + "speaker": speaker, + "session_name": session_name, + "dia_id": _clean_text(metadata.get("dia_id", "")), + "slot_key": _clean_text(representative.slot_key), + "state_signature": _clean_text(metadata.get("state_signature", "")), + "memory_signature": _clean_text(metadata.get("memory_signature", "")), + "target_status": event_target_status, + "time_granularity": event_time_granularity, + "time_value": event_time_value, + "time_display_value": event_time_display_value, + "profile_type": event_profile_type, + "profile_value": event_profile_value, + "event_signature": event_signature, + "depth_layer": depth_layer, + "subject_signature": subject_signature, + **typed_tunnel_metadata, + }, + } + ) + edges.append({"id": f"{speaker_node_id}->{event_id}:speaker_of", "source": speaker_node_id, "target": event_id, "type": "speaker_of"}) + + time_node_ids: List[str] = [] + profile_node_ids: List[str] = [] + status_node_ids: List[str] = [] + source_turn_node_ids: List[str] = [] + + if event_time_display_value or event_time_value: + time_node_id = f"{event_id}:time" + time_hit_metadata = dict((event_time_hit or representative).metadata or {}) + add_node( + { + "id": time_node_id, + "type": "time", + "text": event_time_display_value or event_time_value, + "turn_index": int(getattr(event_time_hit, "turn_index", 0) or 0), + "time_display_value": event_time_display_value, + "time_value": event_time_value, + "time_granularity": event_time_granularity, + "metadata": { + "time_display_value": event_time_display_value, + "time_value": event_time_value, + "time_granularity": event_time_granularity, + "resolved_date": _clean_text(time_hit_metadata.get("resolved_date", "")), + }, + } + ) + edges.append({"id": f"{event_id}->{time_node_id}:time_of", "source": event_id, "target": time_node_id, "type": "time_of"}) + time_node_ids.append(time_node_id) + if event_profile_value: + profile_node_id = f"{event_id}:profile:{_normalize(event_profile_type).replace(' ', '_') or 'profile'}" + add_node( + { + "id": profile_node_id, + "type": "profile", + "text": event_profile_value, + "turn_index": int(getattr(event_profile_hit or representative, "turn_index", 0) or 0), + "profile_type": event_profile_type, + "profile_value": event_profile_value, + "metadata": { + "profile_type": event_profile_type, + "profile_value": event_profile_value, + }, + } + ) + edges.append({"id": f"{event_id}->{profile_node_id}:profile_of", "source": event_id, "target": profile_node_id, "type": "profile_of"}) + profile_node_ids.append(profile_node_id) + source_turn_text = _runtime_source_turn_text(source_turn_hit or representative, speaker=speaker) + if source_turn_text: + source_turn_node_id = f"{event_id}:source_turn" + add_node( + { + "id": source_turn_node_id, + "type": "source_turn", + "text": source_turn_text, + "turn_index": int(getattr(source_turn_hit or representative, "turn_index", 0) or 0), + "metadata": { + "speaker": speaker, + "dia_id": _clean_text(dict((source_turn_hit or representative).metadata or {}).get("dia_id", "")), + }, + } + ) + edges.append({"id": f"{event_id}->{source_turn_node_id}:supported_by_turn", "source": event_id, "target": source_turn_node_id, "type": "supported_by_turn"}) + source_turn_node_ids.append(source_turn_node_id) + if event_target_status: + status_node_id = f"{event_id}:status" + add_node({"id": status_node_id, "type": "status", "text": event_target_status, "metadata": {"target_status": event_target_status}}) + edges.append({"id": f"{event_id}->{status_node_id}:status_of", "source": event_id, "target": status_node_id, "type": "status_of"}) + status_node_ids.append(status_node_id) + event_paths = build_default_path_templates( + event_id=event_id, + speaker_node_id=speaker_node_id, + time_node_ids=time_node_ids, + profile_node_ids=profile_node_ids, + status_node_ids=status_node_ids, + source_turn_node_ids=source_turn_node_ids, + ) + for path in event_paths: + path_metadata = dict(path.get("metadata", {}) or {}) + path_metadata["tmcra_path_tags"] = list(typed_tunnel_metadata.get("tmcra_path_tags", []) or []) + path_metadata["tmcra_tunnel_group_key"] = _clean_text(typed_tunnel_metadata.get("tmcra_tunnel_group_key", "")) + path["metadata"] = path_metadata + path["tmcra_path_tags"] = path_metadata["tmcra_path_tags"] + paths.extend(event_paths) + ordered_events.append(_runtime_event_sequence_key(session_name, event_turn_index, event_id)) + ordered_event_ids = [event_id for _, _, _, event_id in sorted(ordered_events)] + for previous_event_id, next_event_id in zip(ordered_event_ids, ordered_event_ids[1:]): + typed_edge_tags = typed_edge_tags_between( + event_typed_metadata_by_id.get(previous_event_id, {}), + event_typed_metadata_by_id.get(next_event_id, {}), + ) + edges.append( + { + "id": f"{previous_event_id}->{next_event_id}:same_session_next", + "source": previous_event_id, + "target": next_event_id, + "type": "same_session_next", + "metadata": { + "tmcra_edge_tags": typed_edge_tags, + "typed_tunnel_edge": bool(typed_edge_tags), + }, + } + ) + for index, source_event_id in enumerate(ordered_event_ids): + for target_event_id in ordered_event_ids[index + 1 : index + 12]: + typed_edge_tags = typed_edge_tags_between( + event_typed_metadata_by_id.get(source_event_id, {}), + event_typed_metadata_by_id.get(target_event_id, {}), + ) + if not typed_edge_tags: + continue + typed_tunnel_edges.append( + { + "id": f"{source_event_id}->{target_event_id}:typed_tunnel", + "source": source_event_id, + "target": target_event_id, + "type": "typed_tunnel_candidate", + "metadata": { + "tmcra_edge_tags": typed_edge_tags, + "typed_tunnel_edge": True, + }, + } + ) + if len(typed_tunnel_edges) >= 64: + break + if len(typed_tunnel_edges) >= 64: + break + return { + "conversation_id": "runtime", + "query": query, + "nodes": nodes, + "edges": edges, + "typed_tunnel_edges": typed_tunnel_edges, + "paths": paths, + "grouped_hits": grouped_hits, + } + + +def _public_graph_hits(graph: SessionMemoryGraphV2) -> List[MemoryHit]: + public_hits: List[MemoryHit] = [] + for record in graph.records_by_id.values(): + if record.state != "active": + continue + if not _clean_text(record.source_kind).startswith("public_dialog"): + continue + metadata = dict(record.metadata or {}) + public_hits.append( + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.01), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + ) + public_hits.sort(key=lambda item: (int(item.turn_index), float(item.score)), reverse=True) + return public_hits + + +_AUDIT_ANCHOR_QUERY_RE = re.compile( + r"(?i)\b(?:remember|recall|earlier|previous|previously|old|before|mentioned|said|quote|verbatim|original|" + r"that time|last time|bring back|return to|go back to|turn\s*\d+)\b|" + r"(?:\u7b2c\s*\d+\s*[\u8f6e\u6b21]|\d+\s*\u8f6e|\u4e4b\u524d|\u4ee5\u524d|\u521a\u521a|" + r"\u539f\u8bdd|\u90a3\u6b21|\u65e7\u8bdd\u9898|\u56de\u5230|\u63d0\u8d77|\u8bb0\u5f97)" +) + +_AUDIT_TURN_ANCHOR_RE = re.compile( + r"(?i)\bturn\s*#?\s*(\d{1,6})\b|" + r"(?:\u7b2c\s*(\d{1,6})\s*[\u8f6e\u6b21]|\b(\d{1,6})\s*\u8f6e\b)" +) + +_AUDIT_ANCHOR_STOPWORDS = set(_HYBRID_SYMBOLIC_STOPWORDS) | { + "about", + "again", + "back", + "bring", + "can", + "could", + "did", + "discuss", + "discussed", + "earlier", + "from", + "just", + "keep", + "mentioned", + "old", + "one", + "previous", + "previously", + "really", + "recall", + "remember", + "return", + "said", + "still", + "that", + "the", + "thing", + "think", + "this", + "turn", + "what", + "when", + "where", + "with", + "you", +} + +_AUDIT_ANCHOR_GENERIC_MATCH_TOKENS = { + "body", + "coherence", + "continuity", + "fiction", + "fray", + "memories", + "memory", + "narrative", + "ourselves", + "physical", + "really", + "stories", + "story", + "tell", + "trust", +} + +_AUDIT_ANCHOR_PHRASE_TOKEN_SETS = [ + {"spine", "book"}, + {"body", "spine"}, + {"stories", "pages"}, + {"story", "pages"}, + {"islands", "self"}, + {"archipelago", "self"}, + {"loom", "body"}, + {"braid", "body"}, +] + + +def _audit_anchor_query(query: str) -> bool: + return bool(_AUDIT_ANCHOR_QUERY_RE.search(_clean_text(query))) + + +def _audit_anchor_turn_numbers(query: str) -> List[int]: + numbers: List[int] = [] + for match in _AUDIT_TURN_ANCHOR_RE.finditer(_clean_text(query)): + for group in match.groups(): + if not group: + continue + try: + value = int(group) + except Exception: + continue + if value > 0 and value not in numbers: + numbers.append(value) + return numbers + + +def _hit_event_id(hit: MemoryHit) -> str: + metadata = dict(hit.metadata or {}) + event_id = _clean_text(metadata.get("event_id", "")) + if event_id: + return event_id + dia_id = _clean_text(metadata.get("dia_id", "")) + if dia_id: + return f"event::{dia_id}" + if int(hit.turn_index or 0) > 0: + return f"event::realchat:{int(hit.turn_index)}" + return "" + + +def _audit_anchor_hit_text(hit: MemoryHit) -> str: + metadata = dict(hit.metadata or {}) + values = [ + hit.value, + hit.category, + hit.source_kind, + hit.slot_key, + metadata.get("event_text", ""), + metadata.get("source_span", ""), + metadata.get("raw_text", ""), + ] + return " ".join(_clean_text(value) for value in values if _clean_text(value)) + + +def _audit_anchor_content_tokens(text: str) -> set[str]: + tokens = set() + for token in _tokenize(text): + norm = _normalize(token) + if not norm or norm in _AUDIT_ANCHOR_STOPWORDS: + continue + if len(norm) < 3 and not any("\u4e00" <= ch <= "\u9fff" for ch in norm): + continue + tokens.add(norm) + return tokens + + +def _audit_anchor_phrase_bonus(query_tokens: set[str], hit_tokens: set[str]) -> tuple[float, List[str]]: + matched_phrases: List[str] = [] + bonus = 0.0 + for phrase_tokens in _AUDIT_ANCHOR_PHRASE_TOKEN_SETS: + if phrase_tokens <= query_tokens and phrase_tokens <= hit_tokens: + matched_phrases.append("+".join(sorted(phrase_tokens))) + bonus += 5.0 if phrase_tokens == {"spine", "book"} else 3.0 + return bonus, matched_phrases + + +def _copy_hit_with_audit_anchor_boost(hit: MemoryHit, *, score: float, reason: str, matched_tokens: Sequence[str]) -> MemoryHit: + metadata = dict(hit.metadata or {}) + metadata.update( + { + "audit_anchor_protected": True, + "audit_anchor_reason": reason, + "audit_anchor_original_score": round(float(hit.score), 6), + "audit_anchor_score": round(float(score), 6), + "audit_anchor_matched_tokens": list(matched_tokens)[:20], + } + ) + return MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=max(float(hit.score), round(1.0 + float(score), 6)), + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=metadata, + ) + + +def _audit_anchor_protected_hits( + *, + query: str, + final_hits: Sequence[MemoryHit], + candidate_hits: Sequence[MemoryHit], + metadata: Mapping[str, Any], + top_k: int, +) -> Dict[str, Any]: + if not _audit_anchor_query(query): + return {"enabled": False, "hits": list(final_hits), "promoted_hits": [], "metadata": {"audit_anchor_enabled": False}} + query_tokens = _audit_anchor_content_tokens(query) + if not query_tokens: + return {"enabled": True, "hits": list(final_hits), "promoted_hits": [], "metadata": {"audit_anchor_enabled": True, "audit_anchor_reason": "no_content_tokens"}} + + explicit_turns = _audit_anchor_turn_numbers(query) + explicit_turn_window = set() + for number in explicit_turns: + explicit_turn_window.update({number - 1, number, number + 1}) + + symbolic_ids = list(dict.fromkeys(str(item) for item in dict(metadata or {}).get("symbolic_recall_event_ids", []) or [])) + learned_ids = list(dict.fromkeys(str(item) for item in dict(metadata or {}).get("learned_recall_event_ids", []) or [])) + selected_ids = set(str(item) for item in dict(metadata or {}).get("selected_event_ids", []) or []) + symbolic_rank = {event_id: index for index, event_id in enumerate(symbolic_ids, start=1)} + learned_rank = {event_id: index for index, event_id in enumerate(learned_ids, start=1)} + + pool: Dict[str, MemoryHit] = {} + for hit in list(final_hits) + list(candidate_hits): + key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" + if key and key not in pool: + pool[key] = hit + + scored: List[tuple[float, str, List[str], MemoryHit]] = [] + for hit in pool.values(): + event_id = _hit_event_id(hit) + hit_tokens = _audit_anchor_content_tokens(_audit_anchor_hit_text(hit)) + matched = sorted(query_tokens & hit_tokens) + distinctive_matched = [token for token in matched if token not in _AUDIT_ANCHOR_GENERIC_MATCH_TOKENS] + phrase_bonus, matched_phrases = _audit_anchor_phrase_bonus(query_tokens, hit_tokens) + if not matched and int(hit.turn_index or 0) not in explicit_turn_window: + continue + score = 0.0 + reason_parts: List[str] = [] + if int(hit.turn_index or 0) in explicit_turn_window: + score += 8.0 + reason_parts.append("explicit_turn_anchor") + if matched: + weighted_overlap = 0.0 + for token in matched: + weighted_overlap += 1.8 if token not in _AUDIT_ANCHOR_GENERIC_MATCH_TOKENS else 0.35 + score += min(8.0, weighted_overlap) + score += len(matched) / max(1.0, float(min(len(query_tokens), len(hit_tokens)))) + reason_parts.append("lexical_anchor_overlap") + if phrase_bonus > 0: + score += phrase_bonus + reason_parts.append("distinctive_phrase_anchor") + if event_id in symbolic_rank: + score += max(0.0, 2.0 - (symbolic_rank[event_id] - 1) * 0.08) + reason_parts.append("symbolic_recall_anchor") + if event_id in learned_rank: + score += max(0.0, 1.0 - (learned_rank[event_id] - 1) * 0.02) + reason_parts.append("learned_recall_anchor") + if event_id in selected_ids: + score -= 0.25 + # For non-numeric old-topic probes, require a real content overlap so generic + # "earlier" phrasing does not promote unrelated old memories. + if not explicit_turns and len(distinctive_matched) < 1 and phrase_bonus <= 0: + continue + if score > 0: + scored.append((score, event_id, sorted(set(matched + matched_phrases)), hit)) + + scored.sort(key=lambda item: (item[0], -abs(int(item[3].turn_index or 0))), reverse=True) + max_promoted = 2 if not explicit_turns else 3 + promoted: List[MemoryHit] = [] + promoted_event_ids: set[str] = set() + for score, event_id, matched, hit in scored: + if event_id in promoted_event_ids: + continue + if any((hit.memory_id and hit.memory_id == current.memory_id) for current in final_hits[: max(1, top_k)]): + continue + reason = "explicit_turn_anchor" if int(hit.turn_index or 0) in explicit_turn_window else "old_topic_anchor" + promoted.append(_copy_hit_with_audit_anchor_boost(hit, score=score, reason=reason, matched_tokens=matched)) + promoted_event_ids.add(event_id) + if len(promoted) >= max_promoted: + break + + if not promoted: + return { + "enabled": True, + "hits": list(final_hits), + "promoted_hits": [], + "metadata": { + "audit_anchor_enabled": True, + "audit_anchor_turn_numbers": explicit_turns, + "audit_anchor_query_tokens": sorted(query_tokens)[:50], + "audit_anchor_promoted_event_ids": [], + }, + } + + merged: List[MemoryHit] = [] + seen_keys: set[str] = set() + for hit in promoted + list(final_hits): + key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" + if key in seen_keys: + continue + seen_keys.add(key) + merged.append(hit) + if len(merged) >= max(1, int(top_k)): + break + + return { + "enabled": True, + "hits": merged, + "promoted_hits": promoted, + "metadata": { + "audit_anchor_enabled": True, + "audit_anchor_turn_numbers": explicit_turns, + "audit_anchor_query_tokens": sorted(query_tokens)[:50], + "audit_anchor_promoted_event_ids": [_hit_event_id(hit) for hit in promoted], + "audit_anchor_promoted_turns": [int(hit.turn_index) for hit in promoted], + "audit_anchor_promoted_hit_count": len(promoted), + }, + } + + +def _learnable_graph_hits(graph: SessionMemoryGraphV2) -> List[MemoryHit]: + learnable_hits: List[MemoryHit] = [] + for record in graph.records_by_id.values(): + state = _normalize(record.state) + metadata = dict(record.metadata or {}) + is_source_grounded_evidence = ( + state == "evidence" + and ( + _clean_text(record.source_kind).startswith("public_dialog") + or _clean_text(metadata.get("content_variant", "")) in {"source_turn", "llm_semantic_write"} + or _clean_text(metadata.get("write_path", "")) == "llm_semantic_writer_gate" + ) + ) + if state not in {"active", "parallel_active"} and not is_source_grounded_evidence: + continue + if _clean_text(record.slot_key).startswith("noise."): + continue + learnable_hits.append( + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.01), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + ) + learnable_hits.sort( + key=lambda item: ( + _is_public_dialog_hit(item), + int(item.turn_index), + float(item.score), + ), + reverse=True, + ) + return learnable_hits + + +def _parse_structured_records( + payload: Dict[str, Any] | None, + *, + turn_index: int, + profile: TMCRAProfile | None = None, +) -> List[SessionMemoryRecordV2]: + profile = profile or TMCRAProfile() + records: List[SessionMemoryRecordV2] = [] + structured_rows: List[tuple[str, Mapping[str, Any]]] = [] + for raw in (payload or {}).get("replacement_memory_records", []) or []: + if isinstance(raw, Mapping): + structured_rows.append(("formal", raw)) + for raw in (payload or {}).get("suspect_memory_records", []) or []: + if isinstance(raw, Mapping): + structured_rows.append(("suspect", raw)) + for index, (buffer_state, raw) in enumerate(structured_rows): + if not isinstance(raw, dict): + continue + category = _clean_text(raw.get("category", "memory")) or "memory" + value = _clean_text(raw.get("value", "")) + if not value: + continue + anchors = _dedupe(raw.get("anchors", []) or [], max_items=8) + metadata = dict(raw.get("metadata", {}) or {}) + slot_key = profile.stable_slot_key( + category=category, + value=value, + anchors=anchors, + slot_key=_clean_text(raw.get("slot_key", "")) or _clean_text(raw.get("slot", "")), + relation=_clean_text(raw.get("relation", "")), + metadata=metadata, + ) + metadata = { + **metadata, + "memory_role": _clean_text(metadata.get("memory_role", "")) or "user", + "authority": _clean_text(metadata.get("authority", "")) or "source", + "canonical_slot_key": _clean_text(metadata.get("canonical_slot_key", "")) or slot_key, + "writeback_class": _clean_text(metadata.get("writeback_class", "")), + "origin_query": _clean_text(metadata.get("origin_query", "")), + "origin_answer_id": _clean_text(metadata.get("origin_answer_id", "")), + "support_memory_ids": _dedupe(metadata.get("support_memory_ids", []) or []), + "support_fact_refs": _dedupe(metadata.get("support_fact_refs", []) or []), + "support_path_refs": _dedupe(metadata.get("support_path_refs", []) or []), + "promotion_state": _clean_text(metadata.get("promotion_state", "")) or "none", + "memory_buffer_state": buffer_state, + } + memory_id = f"{slot_key}:{turn_index}:{index}" + default_source_kind = "suspect_memory" if buffer_state == "suspect" else "replacement_memory" + default_state = "suspect" if buffer_state == "suspect" else ("active" if bool(raw.get("active", True)) else "historical") + records.append( + SessionMemoryRecordV2( + memory_id=memory_id, + category=category, + slot_key=slot_key, + value=value, + relation=_clean_text(raw.get("relation", "")) or f"{category}_memory", + anchor_concepts=anchors, + evidence_anchors=anchors, + salience=float(raw.get("salience", 0.88 if category in {"goal", "constraint"} else 0.74) or 0.74), + confidence=float(raw.get("confidence", 0.82) or 0.82), + source_kind=_clean_text(raw.get("source_kind", "")) or default_source_kind, + turn_index=int(raw.get("turn_index", turn_index) or turn_index), + state=_clean_text(raw.get("state", "")) or default_state, + metadata=metadata, + ) + ) + return records + + +_WRITE_MARKERS = ( + "goal update", + "goal seed", + "constraint update", + "constraint overwrite", + "constraint seed", + "preference update", + "preference overwrite", + "preference seed", + "terminology:", + "term seed", + "term overwrite", + "stage update", + "stage overwrite", + "path fact", + "fact:", + "memory update", +) + + +_OVERWRITE_MARKERS = ( + "overwrite", + "replace", + "supersede", + "覆盖", + "替换", + "改成", + "更新为", +) + + +_TOPIC_BUCKET_CUE_GROUPS: tuple[tuple[tuple[str, ...], tuple[str, ...]], ...] = ( + ( + ( + "食物", + "饮品", + "甜点", + "早餐", + "餐饮", + "口味", + "酸辣", + "热食", + "过敏", + "芒果", + "配料", + "推荐", + "吃", + "food", + "breakfast", + "allergy", + "mango", + ), + ("餐饮偏好安全", "食物偏好", "过敏约束", "口味偏好"), + ), + ( + ( + "界面", + "首页", + "首屏", + "配色", + "营销", + "横幅", + "电商", + "网站", + "移动端", + "视觉", + "工具型", + "ui", + "homepage", + "ecommerce", + "mobile", + "design", + ), + ("界面产品设计", "电商页面", "视觉布局", "移动端体验"), + ), + ( + ( + "api", + "writer", + "延迟", + "算法", + "耗时", + "评估", + "指标", + "调用", + "服务", + "key", + "runtime", + "latency", + "metric", + ), + ("api评估运行", "算法服务", "调用指标", "writer延迟"), + ), +) + + +_TOPIC_BUCKET_STOPWORDS = { + "不要", + "不用", + "复述", + "原话", + "只说", + "关键", + "最关键", + "以后", + "需要", + "必须", + "可以", + "时候", + "这个", + "那个", + "现在", + "做", + "说", + "的", + "了", + "和", + "与", + "to", + "the", + "and", + "for", +} + + +def _topic_bucket_keywords(text: str, *, max_items: int = 18) -> List[str]: + normalized = _normalize(text) + tokens: List[str] = [] + for cues, anchors in _TOPIC_BUCKET_CUE_GROUPS: + if any(cue and _normalize(cue) in normalized for cue in cues): + tokens.extend(anchors) + tokens.extend(cue for cue in cues if len(cue) >= 2) + tokens.extend(re.findall(r"[\u4e00-\u9fff]{2,8}", str(text or ""))) + tokens.extend(_tokenize(text)) + cleaned = [] + for token in tokens: + value = _clean_text(token) + if not value: + continue + normalized_value = _normalize(value) + if normalized_value in _TOPIC_BUCKET_STOPWORDS: + continue + if len(value) < 2: + continue + cleaned.append(value) + return _dedupe(cleaned, max_items=max_items) + + +def _topic_bucket_id_from_keywords(keywords: Sequence[str]) -> str: + basis = "|".join(_normalize(item) for item in keywords[:8] if _clean_text(item)) + if not basis: + basis = "general" + return "topic-" + str(uuid.uuid5(uuid.NAMESPACE_URL, f"tmcra-topic-bucket:{basis}"))[:12] + + +def _topic_bucket_label_from_keywords(keywords: Sequence[str]) -> str: + visible = [ + _clean_text(item) + for item in keywords + if _clean_text(item) and not _normalize(item).startswith("topic:") + ] + return " / ".join(visible[:3]) if visible else "动态话题" + + +def _topic_bucket_overlap_score(left_keywords: Sequence[str], right_keywords: Sequence[str]) -> float: + left = {_normalize(item) for item in left_keywords if _clean_text(item)} + right = {_normalize(item) for item in right_keywords if _clean_text(item)} + if not left or not right: + return 0.0 + overlap = left & right + if not overlap: + return 0.0 + return len(overlap) / max(1, min(len(left), len(right))) + + +def _explicit_topic_bucket_from_payload(answer_payload: Dict[str, Any] | None) -> Dict[str, Any]: + metadata = dict((answer_payload or {}).get("metadata", {}) or {}) + explicit = metadata.get("topic_bucket") + if not isinstance(explicit, Mapping): + return {} + keywords = _dedupe(explicit.get("topic_keywords", explicit.get("keywords", [])) or [], max_items=24) + label = _clean_text(explicit.get("topic_label", explicit.get("label", ""))) + if not keywords and label: + keywords = _topic_bucket_keywords(label, max_items=12) + bucket_id = _clean_text(explicit.get("topic_bucket_id", explicit.get("id", ""))) + if not bucket_id: + bucket_id = _topic_bucket_id_from_keywords(keywords or [label]) + return { + "topic_bucket_id": bucket_id, + "topic_label": label or _topic_bucket_label_from_keywords(keywords), + "topic_keywords": keywords, + "topic_confidence": float(explicit.get("confidence", explicit.get("topic_confidence", 0.92)) or 0.92), + "topic_assignment_source": "explicit_payload", + } + + +def _coerce_topic_bucket(metadata: Mapping[str, Any]) -> Dict[str, Any]: + bucket_id = _clean_text(metadata.get("topic_bucket_id", "")) + if not bucket_id: + nested = metadata.get("topic_bucket") + if isinstance(nested, Mapping): + bucket_id = _clean_text(nested.get("topic_bucket_id", nested.get("id", ""))) + if not bucket_id: + return {} + keywords = _dedupe(metadata.get("topic_keywords", []) or [], max_items=32) + nested = metadata.get("topic_bucket") + if isinstance(nested, Mapping): + keywords = _dedupe([*keywords, *(nested.get("topic_keywords", nested.get("keywords", [])) or [])], max_items=32) + label = _clean_text(metadata.get("topic_label", "")) + if not label and isinstance(nested, Mapping): + label = _clean_text(nested.get("topic_label", nested.get("label", ""))) + return { + "topic_bucket_id": bucket_id, + "topic_label": label or _topic_bucket_label_from_keywords(keywords), + "topic_keywords": keywords, + } + + +def _collect_topic_buckets(graph: SessionMemoryGraphV2) -> Dict[str, Dict[str, Any]]: + buckets: Dict[str, Dict[str, Any]] = {} + + def merge(bucket: Mapping[str, Any], *, record_id: str = "", turn_index: int = 0) -> None: + bucket_id = _clean_text(bucket.get("topic_bucket_id", bucket.get("id", ""))) + if not bucket_id: + return + current = buckets.setdefault( + bucket_id, + { + "topic_bucket_id": bucket_id, + "topic_label": _clean_text(bucket.get("topic_label", bucket.get("label", ""))), + "topic_keywords": [], + "record_ids": [], + "last_turn_index": 0, + }, + ) + current["topic_label"] = current.get("topic_label") or _clean_text(bucket.get("topic_label", bucket.get("label", ""))) + current["topic_keywords"] = _dedupe( + [*list(current.get("topic_keywords", []) or []), *(bucket.get("topic_keywords", bucket.get("keywords", [])) or [])], + max_items=32, + ) + if record_id: + current["record_ids"] = _dedupe([*list(current.get("record_ids", []) or []), record_id], max_items=200) + current["last_turn_index"] = max(int(current.get("last_turn_index", 0) or 0), int(turn_index or 0)) + + for record in getattr(graph, "records_by_id", {}).values(): + metadata = dict(record.metadata or {}) + bucket = _coerce_topic_bucket(metadata) + if bucket: + merge(bucket, record_id=record.memory_id, turn_index=int(record.turn_index)) + for turn in list(getattr(graph, "turn_log", []) or []): + metadata = dict(turn.get("metadata", {}) if isinstance(turn, Mapping) else getattr(turn, "metadata", {}) or {}) + bucket = _coerce_topic_bucket(metadata) + if bucket: + merge(bucket, turn_index=int(turn.get("turn_index", 0) if isinstance(turn, Mapping) else getattr(turn, "turn_index", 0) or 0)) + for bucket in buckets.values(): + if not bucket.get("topic_label"): + bucket["topic_label"] = _topic_bucket_label_from_keywords(bucket.get("topic_keywords", []) or []) + return buckets + + +def _assign_topic_bucket_for_text( + graph: SessionMemoryGraphV2, + text: str, + *, + answer_payload: Dict[str, Any] | None = None, + turn_index: int = 0, + create: bool = True, +) -> Dict[str, Any]: + explicit = _explicit_topic_bucket_from_payload(answer_payload) + if explicit: + explicit["turn_index"] = int(turn_index or 0) + return explicit + keywords = _topic_bucket_keywords(text, max_items=24) + if not keywords: + keywords = _dedupe([_clean_text(text)[:40]], max_items=1) + buckets = _collect_topic_buckets(graph) + best_bucket: Dict[str, Any] | None = None + best_score = 0.0 + for bucket in buckets.values(): + score = _topic_bucket_overlap_score(keywords, bucket.get("topic_keywords", []) or []) + if score > best_score: + best_bucket = bucket + best_score = score + if best_bucket and best_score >= 0.22: + merged_keywords = _dedupe([*list(best_bucket.get("topic_keywords", []) or []), *keywords], max_items=32) + return { + "topic_bucket_id": best_bucket["topic_bucket_id"], + "topic_label": best_bucket.get("topic_label") or _topic_bucket_label_from_keywords(merged_keywords), + "topic_keywords": merged_keywords, + "topic_confidence": round(min(0.97, 0.68 + best_score * 0.24), 6), + "topic_assignment_source": "reused_by_overlap", + "topic_match_score": round(best_score, 6), + "turn_index": int(turn_index or 0), + } + bucket_id = _topic_bucket_id_from_keywords(keywords) + return { + "topic_bucket_id": bucket_id, + "topic_label": _topic_bucket_label_from_keywords(keywords), + "topic_keywords": keywords, + "topic_confidence": 0.72 if create else 0.58, + "topic_assignment_source": "created_by_dialog" if create else "query_probe", + "topic_match_score": round(best_score, 6), + "turn_index": int(turn_index or 0), + } + + +def _apply_topic_bucket_to_records(records: List[SessionMemoryRecordV2], topic_bucket: Mapping[str, Any]) -> None: + if not records or not topic_bucket: + return + bucket_id = _clean_text(topic_bucket.get("topic_bucket_id", "")) + if not bucket_id: + return + label = _clean_text(topic_bucket.get("topic_label", "")) or "动态话题" + keywords = _dedupe(topic_bucket.get("topic_keywords", []) or [], max_items=32) + for record in records: + metadata = dict(record.metadata or {}) + metadata.update( + { + "topic_bucket_id": bucket_id, + "topic_label": label, + "topic_keywords": keywords, + "topic_confidence": float(topic_bucket.get("topic_confidence", 0.72) or 0.72), + "topic_assignment_source": _clean_text(topic_bucket.get("topic_assignment_source", "")) or "created_by_dialog", + } + ) + record.metadata = metadata + record.anchor_concepts = _dedupe([*list(record.anchor_concepts or []), f"topic:{label}", *keywords[:8]], max_items=24) + evidence_anchors = list(metadata.get("evidence_anchors", []) or []) + metadata["evidence_anchors"] = _dedupe([*evidence_anchors, f"topic:{label}", *keywords[:8]], max_items=24) + + +def _last_topic_turn(graph: SessionMemoryGraphV2) -> Dict[str, Any]: + for turn in reversed(list(getattr(graph, "turn_log", []) or [])): + metadata = dict(turn.get("metadata", {}) if isinstance(turn, Mapping) else getattr(turn, "metadata", {}) or {}) + bucket = _coerce_topic_bucket(metadata) + if bucket: + bucket["turn_index"] = int(turn.get("turn_index", 0) if isinstance(turn, Mapping) else getattr(turn, "turn_index", 0) or 0) + return bucket + return {} + + +def _add_topic_bridge_edges( + graph: SessionMemoryGraphV2, + *, + previous_topic: Mapping[str, Any], + current_topic: Mapping[str, Any], + current_record_ids: Sequence[str], + turn_index: int, + evidence: str, +) -> Dict[str, Any]: + previous_id = _clean_text(previous_topic.get("topic_bucket_id", "")) + current_id = _clean_text(current_topic.get("topic_bucket_id", "")) + if not previous_id or not current_id or previous_id == current_id or not current_record_ids: + return {"topic_bridge_edge_count": 0} + buckets = _collect_topic_buckets(graph) + previous_record_ids = list((buckets.get(previous_id, {}) or {}).get("record_ids", []) or [])[-4:] + if not previous_record_ids: + return {"topic_bridge_edge_count": 0} + edge_count = 0 + evidence_text = _clean_text(evidence)[:240] + for source_id in previous_record_ids: + for target_id in list(current_record_ids)[:4]: + if source_id == target_id: + continue + edge = SessionMemoryEdgeV2( + edge_id=f"{source_id}->{target_id}:topic_bridge:{previous_id}->{current_id}", + source_memory_id=source_id, + target_memory_id=target_id, + edge_type="topic_bridge", + score=0.54, + model_score=0.0, + evidence_turn=int(turn_index or 0), + evidence=evidence_text, + metadata={ + "from_topic_bucket_id": previous_id, + "to_topic_bucket_id": current_id, + "from_topic_label": _clean_text(previous_topic.get("topic_label", "")), + "to_topic_label": _clean_text(current_topic.get("topic_label", "")), + "bridge_reason": "adjacent_dialog_topic_transition", + }, + ) + graph._upsert_memory_edge(edge) + edge_count += 1 + return { + "topic_bridge_edge_count": edge_count, + "topic_bridge_from": previous_id, + "topic_bridge_to": current_id, + } + + +def _add_dialogue_tunnel_edges( + graph: SessionMemoryGraphV2, + *, + current_topic: Mapping[str, Any], + current_record_ids: Sequence[str], + turn_index: int, + evidence: str, +) -> Dict[str, Any]: + current_id = _clean_text(current_topic.get("topic_bucket_id", "")) + if not current_id or not current_record_ids: + return {"dialogue_tunnel_edge_count": 0} + buckets = _collect_topic_buckets(graph) + edge_count = 0 + evidence_text = _clean_text(evidence)[:240] + source_ids: List[tuple[str, str, str]] = [] + for bucket_id, bucket in sorted( + buckets.items(), + key=lambda item: int(item[1].get("last_turn_index", 0) or 0), + reverse=True, + ): + if bucket_id == current_id: + continue + for source_id in list(bucket.get("record_ids", []) or [])[-2:]: + if source_id not in current_record_ids: + source_ids.append((bucket_id, _clean_text(bucket.get("topic_label", "")), source_id)) + if len(source_ids) >= 6: + break + for source_bucket_id, source_label, source_id in source_ids[:6]: + for target_id in list(current_record_ids)[:2]: + if source_id == target_id: + continue + edge = SessionMemoryEdgeV2( + edge_id=f"{source_id}->{target_id}:dialogue_tunnel:{source_bucket_id}->{current_id}", + source_memory_id=source_id, + target_memory_id=target_id, + edge_type="dialogue_tunnel", + score=0.24, + model_score=0.0, + evidence_turn=int(turn_index or 0), + evidence=evidence_text, + metadata={ + "from_topic_bucket_id": source_bucket_id, + "to_topic_bucket_id": current_id, + "from_topic_label": source_label, + "to_topic_label": _clean_text(current_topic.get("topic_label", "")), + "bridge_reason": "high_resistance_dialogue_level_tunnel", + }, + ) + graph._upsert_memory_edge(edge) + edge_count += 1 + return {"dialogue_tunnel_edge_count": edge_count} + + +def _topic_adjacent_bucket_ids(graph: SessionMemoryGraphV2, bucket_id: str) -> set[str]: + adjacent: set[str] = set() + if not bucket_id: + return adjacent + for edge in getattr(graph, "memory_edges", {}).values(): + if _normalize(edge.edge_type) != "topic_bridge": + continue + metadata = dict(edge.metadata or {}) + left = _clean_text(metadata.get("from_topic_bucket_id", "")) + right = _clean_text(metadata.get("to_topic_bucket_id", "")) + if left == bucket_id and right: + adjacent.add(right) + if right == bucket_id and left: + adjacent.add(left) + return adjacent + + +def _dialogue_tunnel_bucket_ids(graph: SessionMemoryGraphV2, bucket_id: str) -> set[str]: + adjacent: set[str] = set() + if not bucket_id: + return adjacent + for edge in getattr(graph, "memory_edges", {}).values(): + if _normalize(edge.edge_type) != "dialogue_tunnel": + continue + metadata = dict(edge.metadata or {}) + left = _clean_text(metadata.get("from_topic_bucket_id", "")) + right = _clean_text(metadata.get("to_topic_bucket_id", "")) + if left == bucket_id and right: + adjacent.add(right) + if right == bucket_id and left: + adjacent.add(left) + return adjacent + + +def _topic_bridge_requested(query: str) -> bool: + text = _normalize(query) + if not text: + return False + bridge_markers = ( + "关联", + "联系", + "链条", + "脉络", + "延展", + "深入", + "对比", + "整合", + "整体", + "上下文", + "刚才", + "之前", + "上面", + "跨话题", + "隧穿", + "related", + "connect", + "compare", + "context", + "chain", + ) + return any(marker in text for marker in bridge_markers) + + +def _dialogue_tunnel_requested(query: str) -> bool: + text = _normalize(query) + if not text: + return False + markers = ( + "跨话题", + "跨对话", + "不同话题", + "不同对话", + "历史对话", + "长期脉络", + "整体脉络", + "所有相关", + "全局", + "全局记忆", + "远一点", + "更深", + "深层关联", + "对话级隧穿", + "记忆隧穿", + "cross topic", + "cross-topic", + "cross dialogue", + "cross-session", + "global context", + "long range", + ) + return any(marker in text for marker in markers) + + +def _topic_bucket_record_to_hit( + record: SessionMemoryRecordV2, + *, + query_topic: Mapping[str, Any], + rank: int, + rescue_kind: str = "topic_bucket", +) -> MemoryHit: + metadata = dict(record.metadata or {}) + dialogue_tunnel = _normalize(rescue_kind) == "dialogue_tunnel" + metadata.update( + { + "topic_bucket_rescue": not dialogue_tunnel, + "dialogue_tunnel_rescue": dialogue_tunnel, + "topic_bucket_rescue_rank": int(rank), + "topic_bucket_query_id": _clean_text(query_topic.get("topic_bucket_id", "")), + "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")), + "topic_bucket_same": not dialogue_tunnel, + "topic_bucket_bridge": False, + "topic_bucket_bridge_allowed": False, + "topic_bucket_dialogue_tunnel_allowed": dialogue_tunnel, + "topic_bucket_overlap": 0.0 if dialogue_tunnel else 1.0, + } + ) + category = _normalize(record.category) + value_text = _normalize(record.value) + hardish = ( + category == "constraint" + or _normalize(metadata.get("memory_type", "")) == "hard_constraint" + or _normalize(metadata.get("durability", "")) == "hard" + or _normalize(metadata.get("conflict_policy", "")) == "must_preserve" + or any(marker in value_text for marker in ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy")) + ) + base_score = max(float(record.confidence), float(record.salience), 0.62) + if dialogue_tunnel: + if hardish: + base_score += 0.42 + metadata.setdefault("memory_type", "hard_constraint") + metadata.setdefault("durability", "hard") + metadata.setdefault("conflict_policy", "must_preserve") + elif category == "preference": + base_score += 0.22 + metadata.setdefault("memory_type", "durable_preference") + metadata.setdefault("durability", "long_term") + else: + base_score += 0.12 + metadata["dialogue_tunnel_resistance"] = "high" + elif hardish: + base_score += 1.35 + metadata.setdefault("memory_type", "hard_constraint") + metadata.setdefault("durability", "hard") + metadata.setdefault("conflict_policy", "must_preserve") + elif category == "preference": + base_score += 0.72 + metadata.setdefault("memory_type", "durable_preference") + metadata.setdefault("durability", "long_term") + else: + base_score += 0.38 + return MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=base_score, + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + + +def _profile_query_rescue_hits( + graph: SessionMemoryGraphV2, + query: str, + *, + top_k: int, +) -> List[MemoryHit]: + query_raw_tokens = set(_path_utility_tokens(query)) + query_tokens = _profile_query_expanded_tokens(query) + intent = infer_profile_query_intent(query) + if not bool(intent.get("enabled")): + return [] + rescued: List[tuple[float, MemoryHit]] = [] + for record in getattr(graph, "records_by_id", {}).values(): + metadata = dict(record.metadata or {}) + if record.state != "active": + continue + if not is_profile_layer_record( + category=record.category, + source_kind=record.source_kind, + semantic_slot=metadata.get("semantic_slot", ""), + metadata=metadata, + ): + continue + delta, reason = profile_query_score_delta( + query=query, + query_tokens=query_tokens, + category=record.category, + source_kind=record.source_kind, + semantic_slot=metadata.get("semantic_slot", ""), + value=record.value, + anchors=record.anchor_concepts, + metadata=metadata, + ) + if delta <= 0: + continue + match_score, overlap_tokens, raw_overlap_tokens = _profile_hit_match_score( + query_raw_tokens, + query_tokens, + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.01), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ), + ) + if match_score <= 0.0: + continue + if match_score < 0.34 and not raw_overlap_tokens: + continue + metadata.update( + { + "profile_query_rescue": True, + "profile_query_rescue_reason": reason or "profile_route", + "profile_query_match_score": round(match_score, 6), + "profile_query_overlap_tokens": list(overlap_tokens), + "profile_query_raw_overlap_tokens": list(raw_overlap_tokens), + "topic_bucket_profile_route_preserved": True, + "match_reason": ",".join(_dedupe([metadata.get("match_reason", ""), reason or "profile_route"], max_items=4)), + } + ) + hit = MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.62) + float(delta) + float(match_score), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=metadata, + ) + rescued.append((match_score, hit)) + rescued.sort(key=lambda item: (float(item[0]), float(item[1].score), int(item[1].turn_index)), reverse=True) + return [ + hit + for _, hit in rescued[: max(1, min(24, int(top_k or 1) * 3))] + ] + + +def _memory_hit_from_record(record: SessionMemoryRecordV2, *, score: float | None = None, metadata: Mapping[str, Any] | None = None) -> MemoryHit: + record_metadata = {**dict(record.metadata or {}), **dict(metadata or {})} + return MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=max(float(record.confidence), float(record.salience), 0.01) if score is None else float(score), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=int(record.turn_index), + metadata=record_metadata, + ) + + +_FACET_NUMERIC_QUERY_TOKENS = { + "amount", + "count", + "counts", + "duration", + "durations", + "many", + "much", + "number", + "quantity", + "sum", + "total", + "totals", + "weeks", + "week", + "hours", + "hour", + "dollars", + "dollar", + "tenants", + "tickets", +} +_FACET_TEMPORAL_QUERY_TOKENS = { + "after", + "before", + "date", + "deadline", + "end", + "finish", + "finished", + "start", + "started", + "time", + "when", +} + + +def _facet_query_pack_hits( + graph: SessionMemoryGraphV2, + query: str, + final_hits: Sequence[MemoryHit], + *, + top_k: int, +) -> Dict[str, Any]: + query_tokens = set(_path_utility_tokens(query)) + if not query_tokens: + return {"hits": list(final_hits), "metadata": {"facet_query_pack_enabled": False, "facet_query_pack_reason": "empty_query_tokens"}} + numeric_query = bool(query_tokens & _FACET_NUMERIC_QUERY_TOKENS) + temporal_query = bool(query_tokens & _FACET_TEMPORAL_QUERY_TOKENS) + if not numeric_query and not temporal_query and not any("facet" in _normalize(token) for token in query_tokens): + return {"hits": list(final_hits), "metadata": {"facet_query_pack_enabled": False, "facet_query_pack_reason": "no_facet_intent"}} + + candidate_rows: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = [] + for record in getattr(graph, "records_by_id", {}).values(): + metadata = dict(record.metadata or {}) + if _normalize(metadata.get("content_variant", "")) != "event_facet_write": + continue + if record.state not in {"active", "parallel_active", "evidence"}: + continue + facet_type = _normalize(metadata.get("facet_type", "")) + parent_slot_key = _clean_text(metadata.get("facet_parent_slot_key", "")) + parent = next( + ( + candidate + for candidate in getattr(graph, "records_by_id", {}).values() + if _clean_text(candidate.slot_key).lower() == parent_slot_key.lower() + ), + None, + ) + parent_text = " ".join( + [ + _clean_text(parent.value if parent else ""), + " ".join(parent.anchor_concepts if parent else []), + _clean_text(dict(parent.metadata or {}).get("source_span", "") if parent else ""), + ] + ) + facet_text = " ".join( + [ + record.value, + " ".join(record.anchor_concepts or []), + _clean_text(metadata.get("facet_type", "")), + _clean_text(metadata.get("facet_role", "")), + _clean_text(metadata.get("facet_source_span", "")), + parent_text, + ] + ) + facet_tokens = set(_path_utility_tokens(facet_text)) + parent_tokens = set(_path_utility_tokens(parent_text)) + overlap_tokens = query_tokens & facet_tokens + parent_overlap_tokens = query_tokens & parent_tokens + unit_overlap = bool(query_tokens & set(_path_utility_tokens(record.value))) + score = 0.0 + if overlap_tokens: + score += min(0.72, len(overlap_tokens) / max(1.0, len(query_tokens)) * 1.2) + if parent_overlap_tokens: + score += min(0.72, len(parent_overlap_tokens) / max(1.0, len(query_tokens)) * 1.35) + if numeric_query and facet_type == "numeric": + score += 0.42 + if unit_overlap: + score += 0.42 + if temporal_query and facet_type == "temporal": + score += 0.32 + if facet_type == "entity" and parent_overlap_tokens: + score += 0.26 + if parent is not None and _normalize(dict(parent.metadata or {}).get("content_variant", "")) == "llm_semantic_write": + score += 0.08 + if score < 0.58: + continue + candidate_rows.append( + ( + round(min(2.75, 1.18 + score), 6), + record, + parent, + { + "facet_query_pack_overlap_tokens": sorted(overlap_tokens)[:12], + "facet_query_pack_parent_overlap_tokens": sorted(parent_overlap_tokens)[:12], + "facet_query_pack_unit_overlap": bool(unit_overlap), + "facet_query_pack_score": round(score, 6), + }, + ) + ) + + if not candidate_rows: + return { + "hits": list(final_hits), + "metadata": { + "facet_query_pack_enabled": True, + "facet_query_pack_inserted_hit_count": 0, + "facet_query_pack_candidate_count": 0, + }, + } + + candidate_rows.sort(key=lambda item: (float(item[0]), int(item[1].turn_index)), reverse=True) + selected = candidate_rows[: max(4, min(18, int(top_k or 1) * 2))] + packed_hits: List[MemoryHit] = [] + for score, facet_record, parent, extra_metadata in selected: + facet_metadata = { + **extra_metadata, + "facet_query_pack": True, + "evidence_snippet_role": "facet_query_attribute", + } + packed_hits.append(_memory_hit_from_record(facet_record, score=score, metadata=facet_metadata)) + if parent is not None: + packed_hits.append( + _memory_hit_from_record( + parent, + score=max(1.12, score - 0.04), + metadata={ + **extra_metadata, + "facet_query_pack": True, + "evidence_snippet_role": "facet_parent_event", + "facet_query_pack_child_id": facet_record.memory_id, + }, + ) + ) + + merged: List[MemoryHit] = [] + seen_ids: set[str] = set() + for hit in [*packed_hits, *list(final_hits)]: + if hit.memory_id and hit.memory_id in seen_ids: + continue + if hit.memory_id: + seen_ids.add(hit.memory_id) + merged.append(hit) + return { + "hits": merged, + "metadata": { + "facet_query_pack_enabled": True, + "facet_query_pack_candidate_count": len(candidate_rows), + "facet_query_pack_inserted_hit_count": len(packed_hits), + "facet_query_pack_numeric_query": bool(numeric_query), + "facet_query_pack_temporal_query": bool(temporal_query), + }, + } + + +_UNIT_COVERAGE_COUNT_TOKENS = { + "amount", + "amounts", + "count", + "counts", + "cost", + "costs", + "dollar", + "dollars", + "how", + "many", + "much", + "number", + "minimum", + "maximum", + "paid", + "percent", + "percentage", + "price", + "prices", + "sale", + "sales", + "sell", + "sold", + "total", + "totals", + "sum", + "value", + "valued", + "values", + "worth", + "items", + "projects", + "events", + "things", +} +_UNIT_COVERAGE_TEMPORAL_TOKENS = { + "ago", + "after", + "before", + "between", + "consecutive", + "date", + "days", + "day", + "first", + "last", + "months", + "month", + "order", + "passed", + "since", + "weeks", + "week", +} +_MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS = { + "after", + "before", + "between", + "consecutive", + "earlier", + "first", + "later", + "last", + "order", + "since", +} +_UNIT_COVERAGE_QUERY_DROP_TOKENS = { + "date", + "fri", + "friday", + "mon", + "monday", + "question", + "sat", + "saturday", + "sun", + "sunday", + "thu", + "thursday", + "tue", + "tuesday", + "wed", + "wednesday", +} +_MULTI_UNIT_CHAIN_DISABLED_MODES = {"", "off", "disabled", "none", "false", "0"} +_MULTI_UNIT_CHAIN_COUNT_TOKENS = { + *_UNIT_COVERAGE_COUNT_TOKENS, + "which", + "each", + "all", + "both", +} +_MULTI_UNIT_CHAIN_FOCUS_DROP_TOKENS = { + *_MULTI_UNIT_CHAIN_COUNT_TOKENS, + *_UNIT_COVERAGE_TEMPORAL_TOKENS, + "i", + "me", + "my", + "mine", + "am", + "is", + "are", + "was", + "were", + "need", + "needs", + "needed", + "currently", + "did", + "does", + "fri", + "friday", + "mon", + "monday", + "or", + "question", + "sat", + "saturday", + "sun", + "sunday", + "thu", + "thursday", + "tue", + "tuesday", + "wed", + "wednesday", + "what", + "which", + "who", +} +_MULTI_UNIT_CHAIN_FACET_TYPES = {"action", "entity", "numeric", "role", "evidence_role", "state", "temporal"} +_MULTI_UNIT_CHAIN_UNIT_KINDS = { + "action_unit", + "target_entity", + "numeric_quantity", + "leadership_role", + "participation_role", + "state_status", + "temporal_anchor", + "evidence_role", + "profile_shadow_unit", +} +_MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS = { + "amount", + "appraisal", + "appraised", + "cost", + "costs", + "dollar", + "dollars", + "minimum", + "paid", + "price", + "prices", + "sale", + "sell", + "sold", + "total", + "value", + "valued", + "values", + "worth", +} + + +def _multi_unit_chain_numeric_signal( + unit_kind: str, + facet_type: str, + text: str, + tokens: set[str], +) -> float: + compact_text = _clean_text(text) + date_like_only = bool( + re.fullmatch( + r"(?:\d{4}[/-]\d{1,2}(?:[/-]\d{1,2})?|\d{1,2}[/-]\d{1,2}(?:[/-]\d{2,4})?|\d{4}/\d{2})", + compact_text, + ) + ) + if date_like_only: + return 0.0 + signal = 0.0 + if unit_kind == "numeric_quantity" or facet_type == "numeric": + signal += 0.55 + if tokens & _MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS: + signal += 0.28 + if re.search(r"(?:[$€£¥]\s*\d|\b\d+(?:,\d{3})*(?:\.\d+)?\s*(?:dollars?|usd|bucks?|yuan|rmb)\b)", text, flags=re.IGNORECASE): + signal += 0.72 + elif re.search(r"\b\d+(?:,\d{3})*(?:\.\d+)?\s*(?:comments?|items?|pieces?|kits?|projects?|doctors?|weddings?|hours?|weeks?|days?|months?|years?|miles?)\b", text, flags=re.IGNORECASE): + signal += 0.42 + return min(1.1, signal) + + +def _multi_unit_chain_date_like_numeric(text: str) -> bool: + compact_text = _clean_text(text) + return bool( + re.fullmatch( + r"(?:\d{4}[/-]\d{1,2}(?:[/-]\d{1,2})?|\d{1,2}[/-]\d{1,2}(?:[/-]\d{2,4})?|\d{4}/\d{2})", + compact_text, + ) + ) + + +def _multi_unit_chain_temporal_anchor_signal(text: str, tokens: set[str]) -> float: + signal = 0.0 + if tokens & _UNIT_COVERAGE_TEMPORAL_TOKENS: + signal += 0.22 + if re.search( + r"\b(?:about|around|roughly|few|several|couple|last|previous|earlier)\s+" + r"(?:a\s+)?(?:day|days|week|weeks|month|months|year|years)\s+ago\b|" + r"\b(?:yesterday|today|tomorrow|last\s+week|last\s+month|a\s+few\s+months\s+ago)\b", + text, + flags=re.IGNORECASE, + ): + signal += 0.72 + return min(1.0, signal) + + +_PROFILE_SHADOW_EVENTLIKE_SLOT_HINTS = { + "action", + "appointment", + "constraint", + "deadline", + "exchange", + "obligation", + "pickup", + "plan", + "preference", + "return", + "status", + "task", +} +_PROFILE_SHADOW_EVENTLIKE_TEXT_HINTS = { + "bought", + "buy", + "completed", + "did", + "exchange", + "exchanged", + "finish", + "finished", + "got", + "have to", + "need", + "needed", + "needs", + "paid", + "pick", + "picked", + "return", + "returned", + "should", + "still", + "took", + "went", +} + + +def _profile_shadow_eventlike_record(record: SessionMemoryRecordV2, metadata: Mapping[str, Any]) -> bool: + if _normalize(metadata.get("content_variant", "")) != "profile_shadow_from_writer": + return False + if record.state not in {"active", "parallel_active", "evidence"}: + return False + slot_text = " ".join( + [ + _clean_text(record.category), + _clean_text(record.relation), + _clean_text(record.slot_key), + _clean_text(metadata.get("semantic_slot", "")), + _clean_text(metadata.get("profile_type", "")), + ] + ).lower() + value_text = " ".join( + [ + _clean_text(record.value), + _clean_text(metadata.get("source_span", "")), + _clean_text(metadata.get("raw_text", "")), + ] + ).lower() + return bool( + any(hint in slot_text for hint in _PROFILE_SHADOW_EVENTLIKE_SLOT_HINTS) + or any(hint in value_text for hint in _PROFILE_SHADOW_EVENTLIKE_TEXT_HINTS) + ) + + +def _profile_shadow_unit_text(record: SessionMemoryRecordV2, metadata: Mapping[str, Any]) -> str: + return " ".join( + [ + _clean_text(record.value), + _clean_text(record.category), + _clean_text(record.relation), + _clean_text(record.slot_key), + _clean_text(metadata.get("semantic_slot", "")), + _clean_text(metadata.get("profile_type", "")), + _clean_text(metadata.get("profile_domain", "")), + _clean_text(metadata.get("source_span", "")), + _clean_text(metadata.get("raw_text", "")), + ] + ) + + +def _unit_coverage_pack_hits( + graph: SessionMemoryGraphV2, + query: str, + final_hits: Sequence[MemoryHit], + *, + top_k: int, +) -> Dict[str, Any]: + raw_query_tokens = set(_path_utility_tokens(query)) + query_tokens = {token for token in raw_query_tokens if token not in _UNIT_COVERAGE_QUERY_DROP_TOKENS} + if not query_tokens: + return {"hits": list(final_hits), "metadata": {"unit_coverage_pack_enabled": False, "unit_coverage_reason": "empty_query"}} + count_intent = bool(query_tokens & _UNIT_COVERAGE_COUNT_TOKENS) + percentage_intent = bool(query_tokens & {"percent", "percentage"}) + temporal_intent = bool(raw_query_tokens & _MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS) + direct_unit_intent = not count_intent and not temporal_intent and len(query_tokens) >= 2 + if not count_intent and not temporal_intent and not direct_unit_intent: + return { + "hits": list(final_hits), + "metadata": {"unit_coverage_pack_enabled": False, "unit_coverage_reason": "no_unit_intent"}, + } + + records = list(getattr(graph, "records_by_id", {}).values()) + parent_by_slot = {_clean_text(record.slot_key).lower(): record for record in records} + candidates: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = [] + for record in records: + metadata = dict(record.metadata or {}) + if record.state not in {"active", "parallel_active", "evidence"}: + continue + profile_shadow_unit = _profile_shadow_eventlike_record(record, metadata) + if profile_shadow_unit: + unit_kind = "profile_shadow_unit" + facet_type = "action" + parent_slot = _clean_text(record.slot_key).lower() + parent = None + parent_text = "" + unit_text = _profile_shadow_unit_text(record, metadata) + else: + if _normalize(metadata.get("content_variant", "")) != "event_facet_write": + continue + if _normalize(metadata.get("facet_layer_version", "")) not in {"event_unit_v1", "event_facet_v1"}: + continue + unit_kind = _normalize(metadata.get("unit_kind", "")) + facet_type = _normalize(metadata.get("facet_type", "")) + parent_slot = _clean_text(metadata.get("facet_parent_slot_key", "")).lower() + parent = parent_by_slot.get(parent_slot) + parent_text = " ".join( + [ + _clean_text(parent.value if parent else ""), + _clean_text(dict(parent.metadata or {}).get("source_span", "") if parent else ""), + " ".join(parent.anchor_concepts if parent else []), + ] + ) + unit_text = " ".join( + [ + _clean_text(record.value), + _clean_text(metadata.get("facet_role", "")), + _clean_text(metadata.get("unit_kind", "")), + _clean_text(metadata.get("action", "")), + _clean_text(metadata.get("target", "")), + _clean_text(metadata.get("quantity", "")), + _clean_text(metadata.get("unit", "")), + _clean_text(metadata.get("normalized_time", "")), + _clean_text(metadata.get("status", "")), + _clean_text(metadata.get("facet_source_span", "")), + parent_text, + ] + ) + unit_tokens = set(_path_utility_tokens(unit_text)) + semantic_event_unit = _normalize(record.memory_id).startswith("tmcra.event.") or _normalize(parent_slot).startswith("tmcra.event.") + overlap = query_tokens & unit_tokens + score = 0.0 + if overlap: + score += min(0.92, len(overlap) / max(1.0, len(query_tokens)) * 1.55) + if count_intent and facet_type in {"action", "entity", "numeric", "role", "evidence_role"}: + score += 0.42 + if temporal_intent and facet_type in {"temporal", "action", "role", "entity"}: + score += 0.38 + if direct_unit_intent and facet_type in {"action", "entity", "state", "numeric"}: + score += 0.30 + if direct_unit_intent and len(overlap) >= 2: + score += 0.36 + if unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role"}: + score += 0.16 + if unit_kind in {"numeric_quantity", "temporal_anchor", "state_status"}: + score += 0.12 + if profile_shadow_unit: + score += 0.28 + if count_intent and facet_type in {"action", "entity"}: + score += 0.24 + semantic_event_priority = bool(semantic_event_unit and (count_intent or percentage_intent)) + if semantic_event_priority: + score += 0.46 + if percentage_intent and ( + unit_kind == "numeric_quantity" + or facet_type == "numeric" + or re.search(r"\b\d+(?:\.\d+)?\s*%|\b\d+(?:\.\d+)?\s*percent\b", unit_text, flags=re.IGNORECASE) + ): + score += 0.72 + if parent is not None: + score += 0.08 + if score < 0.55: + continue + candidates.append( + ( + round(min(3.2, 1.28 + score), 6), + record, + parent, + { + "unit_coverage_overlap_tokens": sorted(overlap)[:14], + "unit_coverage_score": round(score, 6), + "unit_kind": unit_kind, + "facet_type": facet_type, + "unit_coverage_count_intent": bool(count_intent), + "unit_coverage_percentage_intent": bool(percentage_intent), + "unit_coverage_temporal_intent": bool(temporal_intent), + "unit_coverage_direct_unit_intent": bool(direct_unit_intent), + "unit_coverage_profile_shadow_unit": bool(profile_shadow_unit), + "unit_coverage_semantic_event_unit": bool(semantic_event_unit), + "unit_coverage_semantic_event_priority": bool(semantic_event_priority), + }, + ) + ) + + if not candidates: + return { + "hits": list(final_hits), + "metadata": { + "unit_coverage_pack_enabled": True, + "unit_coverage_candidate_count": 0, + "unit_coverage_inserted_hit_count": 0, + "unit_coverage_count_intent": bool(count_intent), + "unit_coverage_percentage_intent": bool(percentage_intent), + "unit_coverage_temporal_intent": bool(temporal_intent), + "unit_coverage_direct_unit_intent": bool(direct_unit_intent), + }, + } + candidates.sort( + key=lambda item: ( + 1 if bool(item[3].get("unit_coverage_semantic_event_priority", False)) else 0, + 1 + if bool(item[3].get("unit_coverage_percentage_intent")) + and ( + _normalize(item[3].get("unit_kind", "")) == "numeric_quantity" + or _normalize(item[3].get("facet_type", "")) == "numeric" + ) + else 0, + float(item[0]), + int(item[1].turn_index), + ), + reverse=True, + ) + selected: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = [] + seen_unit_values: set[str] = set() + try: + max_selected_units = max(1, int(os.getenv("TMCRA_UNIT_COVERAGE_PACK_MAX_UNITS", "6") or 6)) + except (TypeError, ValueError): + max_selected_units = 6 + for item in candidates: + _, record, parent, metadata = item + key = "|".join( + [ + _normalize(metadata.get("unit_kind", "")), + _normalize(record.value), + _normalize(dict(record.metadata or {}).get("facet_parent_slot_key", "")), + ] + ) + if key in seen_unit_values: + continue + seen_unit_values.add(key) + selected.append(item) + if len(selected) >= max_selected_units: + break + if count_intent: + try: + max_profile_shadow_units = max(0, int(os.getenv("TMCRA_UNIT_COVERAGE_PROFILE_SHADOW_MAX_UNITS", "2") or 2)) + except (TypeError, ValueError): + max_profile_shadow_units = 2 + selected_ids = {item[1].memory_id for item in selected} + added_profile_shadow_units = 0 + for item in candidates: + _, record, _parent, metadata = item + if added_profile_shadow_units >= max_profile_shadow_units: + break + if record.memory_id in selected_ids: + continue + if not bool(metadata.get("unit_coverage_profile_shadow_unit", False)): + continue + selected.append(item) + selected_ids.add(record.memory_id) + added_profile_shadow_units += 1 + packed_hits: List[MemoryHit] = [] + for score, unit_record, parent, extra in selected: + packed_hits.append( + _memory_hit_from_record( + unit_record, + score=score, + metadata={ + **extra, + "unit_coverage_pack": True, + "evidence_snippet_role": "unit_coverage_evidence_unit", + }, + ) + ) + if parent is not None: + unit_metadata = dict(unit_record.metadata or {}) + parent_hit = _memory_hit_from_record( + parent, + score=max(1.12, score - 0.06), + metadata={ + **extra, + "unit_coverage_pack": True, + "evidence_snippet_role": "unit_coverage_parent_event", + "unit_coverage_parent_memory_id": parent.memory_id, + "unit_coverage_child_id": unit_record.memory_id, + "unit_coverage_child_value": unit_record.value, + "unit_coverage_child_source_span": _clean_text( + unit_metadata.get("facet_source_span", "") or unit_metadata.get("source_span", "") + ), + }, + ) + parent_hit.memory_id = f"{parent.memory_id}#unit_parent:{unit_record.memory_id}" + packed_hits.append(parent_hit) + merged: List[MemoryHit] = [] + seen_ids: set[str] = set() + try: + insertion_index = max(0, min(len(final_hits), int(os.getenv("TMCRA_UNIT_COVERAGE_PACK_INSERT_AFTER", "2") or 2))) + except (TypeError, ValueError): + insertion_index = min(len(final_hits), 2) + ordered_hits = [*list(final_hits[:insertion_index]), *packed_hits, *list(final_hits[insertion_index:])] + for hit in ordered_hits: + if hit.memory_id and hit.memory_id in seen_ids: + continue + if hit.memory_id: + seen_ids.add(hit.memory_id) + merged.append(hit) + return { + "hits": merged, + "metadata": { + "unit_coverage_pack_enabled": True, + "unit_coverage_candidate_count": len(candidates), + "unit_coverage_selected_unit_count": len(selected), + "unit_coverage_inserted_hit_count": len(packed_hits), + "unit_coverage_count_intent": bool(count_intent), + "unit_coverage_percentage_intent": bool(percentage_intent), + "unit_coverage_temporal_intent": bool(temporal_intent), + "unit_coverage_direct_unit_intent": bool(direct_unit_intent), + }, + } + + +def _multi_unit_chain_focus_tokens(query: str) -> set[str]: + tokens = set(_path_utility_tokens(query)) + return { + normalized + for token in tokens + for normalized in [_multi_unit_chain_normalize_token(token)] + if normalized and normalized not in _MULTI_UNIT_CHAIN_FOCUS_DROP_TOKENS and len(normalized) > 2 + } + + +def _multi_unit_chain_normalize_token(token: str) -> str: + raw = _normalize(token) + if not raw: + return "" + aliases = { + "bought": "buy", + "buying": "buy", + "purchased": "buy", + "purchasing": "buy", + "worked": "work", + "working": "work", + "started": "start", + "starting": "start", + "finished": "finish", + "finishing": "finish", + "completed": "complete", + "completing": "complete", + "returned": "return", + "returning": "return", + "met": "meet", + "meeting": "meet", + "picked": "pick", + "picking": "pick", + "led": "lead", + "leading": "lead", + "kits": "kit", + "items": "item", + "projects": "project", + "events": "event", + "clothes": "clothing", + } + if raw in aliases: + return aliases[raw] + if len(raw) > 4 and raw.endswith("ies"): + return raw[:-3] + "y" + if len(raw) > 4 and raw.endswith("ing"): + stem = raw[:-3] + if len(stem) > 3 and stem[-1] == stem[-2]: + stem = stem[:-1] + return stem + if len(raw) > 3 and raw.endswith("ed"): + stem = raw[:-2] + if len(stem) > 3 and stem[-1] == stem[-2]: + stem = stem[:-1] + return stem + if len(raw) > 3 and raw.endswith("s") and not raw.endswith("ss"): + return raw[:-1] + return raw + + +def _multi_unit_chain_normalized_tokens(text: str) -> set[str]: + return { + normalized + for token in _path_utility_tokens(text) + for normalized in [_multi_unit_chain_normalize_token(token)] + if normalized + } + + +def _multi_unit_chain_hit_text(record: SessionMemoryRecordV2, parent: SessionMemoryRecordV2 | None) -> str: + metadata = dict(record.metadata or {}) + parent_metadata = dict(parent.metadata or {}) if parent is not None else {} + return " ".join( + [ + _clean_text(record.value), + _clean_text(record.category), + _clean_text(record.relation), + _clean_text(metadata.get("facet_type", "")), + _clean_text(metadata.get("unit_kind", "")), + _clean_text(metadata.get("facet_role", "")), + _clean_text(metadata.get("facet_value", "")), + _clean_text(metadata.get("facet_source_span", "")), + _clean_text(metadata.get("action", "")), + _clean_text(metadata.get("target", "")), + _clean_text(metadata.get("quantity", "")), + _clean_text(metadata.get("status", "")), + _clean_text(parent.value if parent else ""), + _clean_text(parent_metadata.get("source_span", "")), + _clean_text(parent_metadata.get("raw_text", "")), + ] + ) + + +def _multi_unit_chain_local_hit_text(record: SessionMemoryRecordV2, parent: SessionMemoryRecordV2 | None) -> str: + metadata = dict(record.metadata or {}) + parent_metadata = dict(parent.metadata or {}) if parent is not None else {} + return " ".join( + [ + _clean_text(record.value), + _clean_text(record.category), + _clean_text(record.relation), + _clean_text(metadata.get("facet_type", "")), + _clean_text(metadata.get("unit_kind", "")), + _clean_text(metadata.get("facet_role", "")), + _clean_text(metadata.get("facet_value", "")), + _clean_text(metadata.get("facet_source_span", "")), + _clean_text(metadata.get("action", "")), + _clean_text(metadata.get("target", "")), + _clean_text(metadata.get("quantity", "")), + _clean_text(metadata.get("status", "")), + _clean_text(parent.value if parent else ""), + _clean_text(parent_metadata.get("source_span", "")), + ] + ) + + +def _multi_unit_chain_slot_hits( + graph: SessionMemoryGraphV2, + query: str, + final_hits: Sequence[MemoryHit], + *, + top_k: int, +) -> Dict[str, Any]: + mode = _normalize(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE", "on")) + if mode in _MULTI_UNIT_CHAIN_DISABLED_MODES: + return { + "hits": list(final_hits), + "metadata": {"multi_unit_chain_slot_enabled": False, "multi_unit_chain_slot_reason": "disabled"}, + } + query_tokens = set(_path_utility_tokens(query)) + focus_tokens = _multi_unit_chain_focus_tokens(query) + temporal_comparison_intent = bool(query_tokens & _MULTI_UNIT_CHAIN_TEMPORAL_COMPARISON_TOKENS) and len(focus_tokens) >= 2 + aggregation_or_joiner_intent = bool( + re.search(r"\b(?:and|or|both|each|total|sum|minimum|maximum|amount|count|many|number)\b", str(query), flags=re.IGNORECASE) + ) + numeric_aggregation_intent = bool( + query_tokens + & ( + _MULTI_UNIT_CHAIN_NUMERIC_VALUE_TOKENS + | {"comment", "comments", "number", "percent", "percentage", "sum", "total"} + ) + ) + count_or_aggregation_intent = bool(query_tokens & _MULTI_UNIT_CHAIN_COUNT_TOKENS) and (aggregation_or_joiner_intent or len(focus_tokens) >= 3) + multi_intent = count_or_aggregation_intent or temporal_comparison_intent + if not query_tokens or not multi_intent: + return { + "hits": list(final_hits), + "metadata": { + "multi_unit_chain_slot_enabled": False, + "multi_unit_chain_slot_reason": "no_multi_intent", + }, + } + if not focus_tokens: + return { + "hits": list(final_hits), + "metadata": { + "multi_unit_chain_slot_enabled": True, + "multi_unit_chain_slot_formed": False, + "multi_unit_chain_slot_reason": "no_focus_tokens", + }, + } + + records = list(getattr(graph, "records_by_id", {}).values()) + parent_by_slot = {_clean_text(record.slot_key).lower(): record for record in records} + candidates: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = [] + for record in records: + metadata = dict(record.metadata or {}) + if record.state not in {"active", "parallel_active", "evidence"}: + continue + profile_shadow_unit = _profile_shadow_eventlike_record(record, metadata) + if profile_shadow_unit: + facet_type = "action" + unit_kind = "profile_shadow_unit" + parent_slot = _clean_text(record.slot_key).lower() + parent = None + text = _profile_shadow_unit_text(record, metadata) + local_text = text + else: + if _normalize(metadata.get("content_variant", "")) != "event_facet_write": + continue + if _normalize(metadata.get("facet_layer_version", "")) != "event_unit_v1": + continue + facet_type = _normalize(metadata.get("facet_type", "")) + unit_kind = _normalize(metadata.get("unit_kind", "")) + if facet_type not in _MULTI_UNIT_CHAIN_FACET_TYPES and unit_kind not in _MULTI_UNIT_CHAIN_UNIT_KINDS: + continue + parent_slot = _clean_text(metadata.get("facet_parent_slot_key", "")).lower() + parent = parent_by_slot.get(parent_slot) + text = _multi_unit_chain_hit_text(record, parent) + local_text = _multi_unit_chain_local_hit_text(record, parent) + semantic_event_unit = _normalize(record.memory_id).startswith("tmcra.event.") or _normalize(parent_slot).startswith("tmcra.event.") + unit_tokens = _multi_unit_chain_normalized_tokens(text) + local_tokens = _multi_unit_chain_normalized_tokens(local_text) + focus_overlap = focus_tokens & unit_tokens + if not focus_overlap: + continue + score = min(1.2, len(focus_overlap) / max(1.0, len(focus_tokens)) * 1.6) + local_numeric_signal = _multi_unit_chain_numeric_signal(unit_kind, facet_type, local_text, local_tokens) + if _multi_unit_chain_date_like_numeric(local_text): + numeric_signal = 0.0 + else: + numeric_signal = local_numeric_signal + if numeric_signal <= 0.0: + numeric_signal = _multi_unit_chain_numeric_signal(unit_kind, facet_type, text, unit_tokens) + temporal_anchor_signal = _multi_unit_chain_temporal_anchor_signal(text, unit_tokens) + local_temporal_anchor_signal = _multi_unit_chain_temporal_anchor_signal(local_text, local_tokens) + if facet_type in {"action", "entity", "role", "numeric"}: + score += 0.34 + if unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role", "numeric_quantity"}: + score += 0.24 + if profile_shadow_unit: + score += 0.30 + semantic_event_priority = bool(semantic_event_unit and (count_or_aggregation_intent or numeric_aggregation_intent)) + if semantic_event_priority: + score += 0.48 + if numeric_signal and numeric_aggregation_intent: + score += numeric_signal + if count_or_aggregation_intent and not numeric_aggregation_intent and facet_type in {"action", "entity", "role"}: + score += 0.36 + if count_or_aggregation_intent and not numeric_aggregation_intent and unit_kind in {"action_unit", "target_entity", "leadership_role", "participation_role"}: + score += 0.22 + if parent is not None: + score += 0.12 + if score < 0.72: + continue + candidates.append( + ( + round(min(4.0, 2.15 + score), 6), + record, + parent, + { + "multi_unit_chain_focus_overlap_tokens": sorted(focus_overlap)[:16], + "multi_unit_chain_score": round(score, 6), + "multi_unit_chain_numeric_signal": round(numeric_signal, 6), + "multi_unit_chain_temporal_anchor_signal": round(temporal_anchor_signal, 6), + "multi_unit_chain_local_temporal_anchor_signal": round(local_temporal_anchor_signal, 6), + "multi_unit_chain_temporal_comparison": bool(temporal_comparison_intent), + "multi_unit_chain_numeric_aggregation": bool(numeric_aggregation_intent), + "unit_kind": unit_kind, + "facet_type": facet_type, + "multi_unit_chain_parent_slot_key": parent_slot, + "multi_unit_chain_profile_shadow_unit": bool(profile_shadow_unit), + "multi_unit_chain_semantic_event_unit": bool(semantic_event_unit), + "multi_unit_chain_semantic_event_priority": bool(semantic_event_priority), + "multi_unit_chain_count_or_aggregation": bool(count_or_aggregation_intent), + }, + ) + ) + + if not candidates: + return { + "hits": list(final_hits), + "metadata": { + "multi_unit_chain_slot_enabled": True, + "multi_unit_chain_slot_formed": False, + "multi_unit_chain_slot_reason": "no_matching_units", + "multi_unit_chain_candidate_count": 0, + "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24], + }, + } + + if temporal_comparison_intent: + candidates.sort( + key=lambda item: ( + float(item[3].get("multi_unit_chain_local_temporal_anchor_signal", 0.0) or 0.0), + float(item[3].get("multi_unit_chain_temporal_anchor_signal", 0.0) or 0.0), + 1 if bool(item[3].get("multi_unit_chain_semantic_event_priority", False)) else 0, + 0 + if _normalize(item[3].get("unit_kind", "")) == "numeric_quantity" + or _normalize(item[3].get("facet_type", "")) == "numeric" + else 1, + float(item[0]), + int(item[1].turn_index), + ), + reverse=True, + ) + else: + candidates.sort( + key=lambda item: ( + float(item[3].get("multi_unit_chain_numeric_signal", 0.0) or 0.0) + if bool(item[3].get("multi_unit_chain_numeric_aggregation", False)) + else ( + 1.2 + if bool(item[3].get("multi_unit_chain_semantic_event_priority", False)) + else + 0.0 + if _normalize(item[3].get("unit_kind", "")) == "numeric_quantity" + or _normalize(item[3].get("facet_type", "")) == "numeric" + else 1.0 + ), + float(item[0]), + int(item[1].turn_index), + ), + reverse=True, + ) + selected: List[tuple[float, SessionMemoryRecordV2, SessionMemoryRecordV2 | None, Dict[str, Any]]] = [] + seen_parent_slots: set[str] = set() + seen_values: set[str] = set() + max_units = max(2, min(8, int(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MAX_UNITS", "6") or 6))) + for item in candidates: + _, record, parent, metadata = item + parent_slot = _clean_text(metadata.get("multi_unit_chain_parent_slot_key", "")) + value_key = "|".join( + [ + _normalize(record.value), + _normalize(metadata.get("unit_kind", "")), + parent_slot, + ] + ) + if value_key in seen_values: + continue + if parent_slot and parent_slot in seen_parent_slots: + # Keep the chain broad: one strongest unit per parent event first. + continue + seen_values.add(value_key) + if parent_slot: + seen_parent_slots.add(parent_slot) + selected.append(item) + if len(selected) >= max_units: + break + if count_or_aggregation_intent: + try: + max_profile_shadow_units = max(0, int(os.getenv("TMCRA_MULTI_UNIT_CHAIN_PROFILE_SHADOW_MAX_UNITS", "2") or 2)) + except (TypeError, ValueError): + max_profile_shadow_units = 2 + selected_ids = {item[1].memory_id for item in selected} + added_profile_shadow_units = 0 + for item in candidates: + _, record, _parent, metadata = item + if added_profile_shadow_units >= max_profile_shadow_units: + break + if record.memory_id in selected_ids: + continue + if not bool(metadata.get("multi_unit_chain_profile_shadow_unit", False)): + continue + selected.append(item) + selected_ids.add(record.memory_id) + parent_slot = _clean_text(metadata.get("multi_unit_chain_parent_slot_key", "")) + if parent_slot: + seen_parent_slots.add(parent_slot) + added_profile_shadow_units += 1 + + min_parents = max(2, min(4, int(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MIN_PARENTS", "2") or 2))) + if len(seen_parent_slots) < min_parents: + return { + "hits": list(final_hits), + "metadata": { + "multi_unit_chain_slot_enabled": True, + "multi_unit_chain_slot_formed": False, + "multi_unit_chain_slot_reason": "insufficient_parent_coverage", + "multi_unit_chain_candidate_count": len(candidates), + "multi_unit_chain_selected_unit_count": len(selected), + "multi_unit_chain_parent_count": len(seen_parent_slots), + "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24], + }, + } + + chain_memory_ids: List[str] = [] + packed_hits: List[MemoryHit] = [] + for score, unit_record, parent, extra in selected: + chain_memory_ids.append(unit_record.memory_id) + packed_hits.append( + _memory_hit_from_record( + unit_record, + score=score, + metadata={ + **extra, + "multi_unit_chain_slot": True, + "multi_unit_chain_bundle": True, + "evidence_snippet_role": "multi_unit_chain_evidence_unit", + }, + ) + ) + if parent is not None: + unit_metadata = dict(unit_record.metadata or {}) + chain_memory_ids.append(parent.memory_id) + parent_hit = _memory_hit_from_record( + parent, + score=max(1.9, score - 0.08), + metadata={ + **extra, + "multi_unit_chain_slot": True, + "multi_unit_chain_bundle": True, + "evidence_snippet_role": "multi_unit_chain_parent_event", + "multi_unit_chain_child_id": unit_record.memory_id, + "multi_unit_chain_child_value": unit_record.value, + "multi_unit_chain_child_source_span": _clean_text( + unit_metadata.get("facet_source_span", "") or unit_metadata.get("source_span", "") + ), + }, + ) + parent_hit.memory_id = f"{parent.memory_id}#multi_parent:{unit_record.memory_id}" + packed_hits.append(parent_hit) + + seen_ids: set[str] = set() + merged: List[MemoryHit] = [] + insertion_index = max(0, min(len(final_hits), int(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_INSERT_AFTER", "1") or 1))) + ordered = [*list(final_hits[:insertion_index]), *packed_hits, *list(final_hits[insertion_index:])] + for hit in ordered: + if hit.memory_id and hit.memory_id in seen_ids: + continue + if hit.memory_id: + seen_ids.add(hit.memory_id) + merged.append(hit) + return { + "hits": merged, + "metadata": { + "multi_unit_chain_slot_enabled": True, + "multi_unit_chain_slot_formed": True, + "multi_unit_chain_slot_reason": "formed", + "multi_unit_chain_candidate_count": len(candidates), + "multi_unit_chain_selected_unit_count": len(selected), + "multi_unit_chain_parent_count": len(seen_parent_slots), + "multi_unit_chain_inserted_hit_count": len(packed_hits), + "multi_unit_chain_memory_ids": _dedupe(chain_memory_ids, max_items=32), + "multi_unit_chain_focus_tokens": sorted(focus_tokens)[:24], + }, + } + + +def _profile_support_source_hits( + graph: SessionMemoryGraphV2, + profile_hit: MemoryHit, + *, + grouped_hits: Mapping[str, Sequence[MemoryHit]], + query: str, + limit: int, +) -> List[MemoryHit]: + metadata = dict(profile_hit.metadata or {}) + support_ids = _dedupe( + [ + *list(metadata.get("profile_support_ids", []) or []), + *list(metadata.get("support_memory_ids", []) or []), + ], + max_items=24, + ) + if not support_ids: + return [] + query_raw_tokens = set(_path_utility_tokens(query)) + query_tokens = _profile_query_expanded_tokens(query) + profile_event_id = _clean_text(metadata.get("profile_first_hybrid_event_id", "")) or _runtime_event_key(profile_hit) + candidates: List[tuple[float, MemoryHit]] = [] + for support_id in support_ids: + record = getattr(graph, "records_by_id", {}).get(support_id) + if record is None or record.state not in {"active", "parallel_active", "evidence"}: + continue + support_hit = _memory_hit_from_record(record) + event_id = _runtime_event_key(support_hit) + if not event_id or event_id == profile_event_id or event_id not in grouped_hits: + continue + match_score, overlap_tokens, raw_overlap_tokens = _profile_hit_match_score(query_raw_tokens, query_tokens, support_hit) + source_score = max(0.80, float(profile_hit.score) - 0.02) + min(0.28, match_score * 0.12) + support_metadata = dict(support_hit.metadata or {}) + profile_metadata = dict(profile_hit.metadata or {}) + support_metadata.update( + { + "profile_first_hybrid_rescue": True, + "profile_first_source_support": True, + "profile_first_parent_memory_id": profile_hit.memory_id, + "profile_first_hybrid_event_id": event_id, + "profile_first_parent_summary": _clean_text(profile_metadata.get("profile_summary", "")) or _clean_text(profile_metadata.get("profile_value", "")) or _clean_text(profile_hit.value), + "profile_type": _clean_text(support_metadata.get("profile_type", "")) or _clean_text(profile_metadata.get("profile_type", "")), + "profile_domain": _clean_text(support_metadata.get("profile_domain", "")) or _clean_text(profile_metadata.get("profile_domain", "")), + "profile_domain_label": _clean_text(support_metadata.get("profile_domain_label", "")) or _clean_text(profile_metadata.get("profile_domain_label", "")), + "profile_value": _clean_text(support_metadata.get("profile_value", "")) or _clean_text(profile_metadata.get("profile_value", "")) or _clean_text(profile_hit.value), + "profile_summary": _clean_text(support_metadata.get("profile_summary", "")) or _clean_text(profile_metadata.get("profile_summary", "")) or _clean_text(profile_hit.value), + "profile_query_match_score": round(match_score, 6), + "profile_query_overlap_tokens": list(overlap_tokens), + "profile_query_raw_overlap_tokens": list(raw_overlap_tokens), + "profile_source_pack_role": "source_event_support", + } + ) + candidates.append( + ( + source_score, + MemoryHit( + memory_id=support_hit.memory_id, + category=support_hit.category, + value=support_hit.value, + relation=support_hit.relation, + anchors=list(support_hit.anchors), + score=round(source_score, 6), + source_kind=support_hit.source_kind, + slot_key=support_hit.slot_key, + state=support_hit.state, + turn_index=int(support_hit.turn_index), + metadata=support_metadata, + ), + ) + ) + candidates.sort( + key=lambda item: ( + bool((item[1].metadata or {}).get("profile_query_raw_overlap_tokens")), + float(item[0]), + int(item[1].turn_index), + ), + reverse=True, + ) + return [hit for _, hit in candidates[: max(0, int(limit))]] + + +def _profile_same_event_source_hit(profile_hit: MemoryHit, group_hits: Sequence[MemoryHit], *, event_id: str) -> MemoryHit | None: + source_hit = _support_hit_for_path("speaker_event_source_turn", group_hits) or _representative_event_hit(group_hits) + if source_hit is None or source_hit.memory_id == profile_hit.memory_id: + return None + metadata = dict(source_hit.metadata or {}) + profile_metadata = dict(profile_hit.metadata or {}) + metadata.update( + { + "profile_first_hybrid_rescue": True, + "profile_first_source_support": True, + "profile_first_same_event_support": True, + "profile_first_parent_memory_id": profile_hit.memory_id, + "profile_first_hybrid_event_id": event_id, + "profile_first_parent_summary": _clean_text(profile_metadata.get("profile_summary", "")) or _clean_text(profile_metadata.get("profile_value", "")) or _clean_text(profile_hit.value), + "profile_type": _clean_text(metadata.get("profile_type", "")) or _clean_text(profile_metadata.get("profile_type", "")), + "profile_domain": _clean_text(metadata.get("profile_domain", "")) or _clean_text(profile_metadata.get("profile_domain", "")), + "profile_domain_label": _clean_text(metadata.get("profile_domain_label", "")) or _clean_text(profile_metadata.get("profile_domain_label", "")), + "profile_value": _clean_text(metadata.get("profile_value", "")) or _clean_text(profile_metadata.get("profile_value", "")) or _clean_text(profile_hit.value), + "profile_summary": _clean_text(metadata.get("profile_summary", "")) or _clean_text(profile_metadata.get("profile_summary", "")) or _clean_text(profile_hit.value), + "profile_source_pack_role": "same_event_source_turn", + "evidence_snippet_role": "profile_source_support", + } + ) + return MemoryHit( + memory_id=source_hit.memory_id, + category=source_hit.category, + value=source_hit.value, + relation=source_hit.relation, + anchors=list(source_hit.anchors), + score=max(float(source_hit.score), float(profile_hit.score) - 0.01), + source_kind=source_hit.source_kind, + slot_key=source_hit.slot_key, + state=source_hit.state, + turn_index=int(source_hit.turn_index), + metadata=metadata, + ) + + +def _profile_first_hybrid_rescue( + graph: SessionMemoryGraphV2, + query: str, + *, + grouped_hits: Mapping[str, Sequence[MemoryHit]], + top_k: int, +) -> Dict[str, Any]: + intent = infer_profile_query_intent(query) + if not bool(intent.get("enabled")): + return {"hits": [], "event_ids": []} + raw_hits = _profile_query_rescue_hits(graph, query, top_k=max(4, int(top_k or 1) * 3)) + limit = max(2, min(4, int(top_k or 1))) + selected_hits: List[MemoryHit] = [] + selected_event_ids: List[str] = [] + seen_event_ids: set[str] = set() + for rank, hit in enumerate(raw_hits, start=1): + event_id = _runtime_event_key(hit) + if not event_id or event_id not in grouped_hits or event_id in seen_event_ids: + continue + metadata = dict(hit.metadata or {}) + metadata.update( + { + "profile_first_hybrid_rescue": True, + "profile_first_hybrid_rank": rank, + "profile_first_hybrid_event_id": event_id, + } + ) + hit.metadata = metadata + hit.score = round(max(float(hit.score), 0.88) + max(0.0, 0.08 - (0.01 * len(selected_hits))), 6) + selected_hits.append(hit) + selected_event_ids.append(event_id) + seen_event_ids.add(event_id) + same_event_source = _profile_same_event_source_hit(hit, grouped_hits.get(event_id, []), event_id=event_id) + if same_event_source is not None: + selected_hits.append(same_event_source) + if len(selected_hits) >= limit: + break + support_hits = _profile_support_source_hits( + graph, + hit, + grouped_hits=grouped_hits, + query=query, + limit=max(1, min(3, limit - len(selected_hits) + 1)), + ) + for support_hit in support_hits: + support_event_id = _clean_text((support_hit.metadata or {}).get("profile_first_hybrid_event_id", "")) or _runtime_event_key(support_hit) + if not support_event_id or support_event_id in seen_event_ids: + continue + selected_hits.append(support_hit) + selected_event_ids.append(support_event_id) + seen_event_ids.add(support_event_id) + if len(selected_hits) >= limit: + break + if len(selected_hits) >= limit: + break + return { + "hits": selected_hits, + "event_ids": selected_event_ids, + "memory_ids": [hit.memory_id for hit in selected_hits], + } + + +def _inject_profile_first_hits( + final_hits: Sequence[MemoryHit], + profile_first_hits: Sequence[MemoryHit], + *, + selected_event_ids: Sequence[str], + selected_path_ids: Sequence[str], +) -> List[MemoryHit]: + merged: List[MemoryHit] = [] + seen_memory_ids: set[str] = set() + for hit in list(profile_first_hits) + list(final_hits): + if not hit or not hit.memory_id or hit.memory_id in seen_memory_ids: + continue + metadata = dict(hit.metadata or {}) + if bool(metadata.get("profile_first_hybrid_rescue")): + event_id = _clean_text(metadata.get("profile_first_hybrid_event_id", "")) or _runtime_event_key(hit) + metadata.update( + { + "event_id": event_id, + "path_id": "", + "evidence_snippet_role": "selected_event_representative", + "hybrid_score_source": "profile_first_hybrid_rescue", + "selected_event_ids": list(selected_event_ids), + "selected_path_ids": list(selected_path_ids), + } + ) + hit = MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=float(hit.score), + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=metadata, + ) + merged.append(hit) + seen_memory_ids.add(hit.memory_id) + return merged + + +def _profile_focused_pack_hits( + graph: SessionMemoryGraphV2, + query: str, + final_hits: Sequence[MemoryHit], + *, + top_k: int, +) -> Dict[str, Any]: + intent = infer_profile_query_intent(query) + if not bool(intent.get("enabled")): + return { + "hits": list(final_hits), + "metadata": { + "profile_focused_pack_enabled": False, + "profile_focused_pack_reason": "profile_intent_not_requested", + }, + } + source_hits = _learnable_graph_hits(graph) + if not source_hits: + return { + "hits": list(final_hits), + "metadata": { + "profile_focused_pack_enabled": False, + "profile_focused_pack_reason": "no_learnable_hits", + }, + } + runtime_graph = _build_runtime_graph_from_hits(query, source_hits) + grouped_hits = dict(runtime_graph.get("grouped_hits", {}) or {}) + profile_first_payload = _profile_first_hybrid_rescue( + graph, + query, + grouped_hits=grouped_hits, + top_k=max(2, min(max(1, int(top_k or 1)), 4)), + ) + profile_first_hits = list(profile_first_payload.get("hits", []) or []) + profile_first_event_ids = list(profile_first_payload.get("event_ids", []) or []) + profile_first_memory_ids = list(profile_first_payload.get("memory_ids", []) or []) + if not profile_first_hits: + return { + "hits": list(final_hits), + "metadata": { + "profile_focused_pack_enabled": True, + "profile_focused_pack_reason": "no_profile_hits", + "profile_focused_pack_event_ids": [], + "profile_focused_pack_memory_ids": [], + }, + } + selected_event_ids = _dedupe([*profile_first_event_ids, *_event_ids_from_hits(final_hits)], max_items=max(1, int(top_k or 1) * 2)) + selected_path_ids = _dedupe( + [ + _clean_text((hit.metadata or {}).get("path_id", "")) + for hit in final_hits + if _clean_text((hit.metadata or {}).get("path_id", "")) + ], + max_items=max(1, int(top_k or 1)), + ) + merged_hits = _inject_profile_first_hits( + final_hits, + profile_first_hits, + selected_event_ids=selected_event_ids, + selected_path_ids=selected_path_ids, + ) + merged_hits = _coverage_preserving_final_hits(merged_hits, selected_event_ids=selected_event_ids, top_k=top_k) + return { + "hits": merged_hits, + "metadata": { + "profile_focused_pack_enabled": True, + "profile_focused_pack_reason": "profile_first_pack_injected", + "profile_focused_pack_event_ids": list(profile_first_event_ids), + "profile_focused_pack_memory_ids": list(profile_first_memory_ids), + "profile_focused_pack_hit_count": len(profile_first_hits), + "profile_first_hybrid_enabled": True, + "profile_first_event_ids": list(profile_first_event_ids), + "profile_first_memory_ids": list(profile_first_memory_ids), + }, + } + + +def _topic_bucket_rerank_hits( + graph: SessionMemoryGraphV2, + query: str, + hits: Sequence[MemoryHit], + *, + top_k: int, +) -> Dict[str, Any]: + if not hits: + profile_rescue_hits = _profile_query_rescue_hits(graph, query, top_k=top_k) + if profile_rescue_hits: + return { + "hits": profile_rescue_hits[: max(1, int(top_k or 1))], + "metadata": { + "topic_bucket_rerank_enabled": True, + "topic_bucket_rerank_reason": "profile_query_rescue_from_empty_hits", + "topic_bucket_profile_query_rescue_count": len(profile_rescue_hits), + }, + } + return { + "hits": [], + "metadata": {"topic_bucket_rerank_enabled": True, "topic_bucket_rerank_reason": "no_hits"}, + } + query_topic = _assign_topic_bucket_for_text(graph, query, turn_index=0, create=False) + query_bucket_id = _clean_text(query_topic.get("topic_bucket_id", "")) + adjacent_ids = _topic_adjacent_bucket_ids(graph, query_bucket_id) + dialogue_tunnel_ids = _dialogue_tunnel_bucket_ids(graph, query_bucket_id) + query_keywords = list(query_topic.get("topic_keywords", []) or []) + bridge_requested = _topic_bridge_requested(query) + dialogue_requested = _dialogue_tunnel_requested(query) + profile_query_requested = bool(infer_profile_query_intent(query).get("enabled")) + reranked: List[MemoryHit] = [] + stats = { + "same_bucket": 0, + "bridge_bucket": 0, + "blocked_bridge_bucket": 0, + "dialogue_tunnel_bucket": 0, + "blocked_dialogue_tunnel_bucket": 0, + "profile_route_preserved": 0, + "profile_query_rescue": 0, + "overlap_bucket": 0, + "off_topic": 0, + } + for hit in list(hits): + metadata = dict(hit.metadata or {}) + hit_bucket_id = _clean_text(metadata.get("topic_bucket_id", "")) + hit_keywords = _dedupe(metadata.get("topic_keywords", []) or [], max_items=32) + overlap = _topic_bucket_overlap_score(query_keywords, hit_keywords) + same_bucket = bool(query_bucket_id and hit_bucket_id and query_bucket_id == hit_bucket_id) + bridge_bucket = bool(hit_bucket_id and hit_bucket_id in adjacent_ids) + dialogue_tunnel_bucket = bool(hit_bucket_id and hit_bucket_id in dialogue_tunnel_ids) + overlap_bucket = overlap >= 0.22 + bridge_allowed = bridge_bucket and bridge_requested + dialogue_allowed = dialogue_tunnel_bucket and dialogue_requested + profile_route_preserved = bool( + profile_query_requested + and ( + bool(metadata.get("profile_layer")) + or "profile_route" in _normalize(metadata.get("match_reason", "")) + ) + ) + current_subject_preserved = bool(metadata.get("current_subject_resolver") or metadata.get("public_subject_match")) + delta = 0.0 + if same_bucket: + delta += 1.35 + stats["same_bucket"] += 1 + elif current_subject_preserved: + delta += 0.18 + stats["profile_route_preserved"] += 1 + elif bridge_allowed: + delta += 0.42 + stats["bridge_bucket"] += 1 + elif dialogue_allowed: + delta += 0.16 + stats["dialogue_tunnel_bucket"] += 1 + elif overlap_bucket: + delta += 0.28 + overlap + stats["overlap_bucket"] += 1 + elif bridge_bucket: + delta -= 1.05 + stats["blocked_bridge_bucket"] += 1 + elif dialogue_tunnel_bucket: + delta -= 1.35 + stats["blocked_dialogue_tunnel_bucket"] += 1 + elif profile_route_preserved: + delta += 0.08 + stats["profile_route_preserved"] += 1 + else: + delta -= 0.72 + stats["off_topic"] += 1 + memory_type = _normalize(metadata.get("memory_type", "")) + durability = _normalize(metadata.get("durability", "")) + conflict_policy = _normalize(metadata.get("conflict_policy", "")) + if (memory_type == "hard_constraint" or durability == "hard" or conflict_policy == "must_preserve") and ( + same_bucket or bridge_allowed or dialogue_allowed or overlap_bucket + ): + delta += 1.10 + match_reason = _clean_text(metadata.get("match_reason", "")) + if profile_route_preserved and "profile_route" not in _normalize(match_reason): + match_reason = ",".join(_dedupe([match_reason, "profile_route"], max_items=4)) + metadata.update( + { + "topic_bucket_rerank": True, + "topic_bucket_query_id": query_bucket_id, + "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")), + "topic_bucket_overlap": round(overlap, 6), + "topic_bucket_delta": round(delta, 6), + "topic_bucket_same": same_bucket, + "topic_bucket_bridge": bridge_bucket, + "topic_bucket_bridge_allowed": bridge_allowed, + "topic_bucket_dialogue_tunnel": dialogue_tunnel_bucket, + "topic_bucket_dialogue_tunnel_allowed": dialogue_allowed, + "topic_bucket_profile_route_preserved": profile_route_preserved, + "topic_bucket_current_subject_preserved": current_subject_preserved, + "match_reason": match_reason, + } + ) + hit.metadata = metadata + hit.score = float(hit.score) + delta + reranked.append(hit) + seen_ids = {hit.memory_id for hit in reranked if hit.memory_id} + rescue_records = [ + record + for record in getattr(graph, "records_by_id", {}).values() + if record.memory_id not in seen_ids + and record.state == "active" + and _clean_text((record.metadata or {}).get("topic_bucket_id", "")) == query_bucket_id + and _normalize(record.category) != "question" + ] + rescue_records.sort( + key=lambda record: ( + int(any(marker in _normalize(record.value) for marker in ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy"))), + int(_normalize(record.category) in {"constraint", "preference"}), + float(record.confidence), + float(record.salience), + -int(record.turn_index), + ), + reverse=True, + ) + rescue_hits = [ + _topic_bucket_record_to_hit(record, query_topic=query_topic, rank=index) + for index, record in enumerate(rescue_records[: max(2, min(12, int(top_k or 1) * 2))], start=1) + ] + reranked.extend(rescue_hits) + dialogue_rescue_hits: List[MemoryHit] = [] + if dialogue_requested and dialogue_tunnel_ids: + seen_ids.update(hit.memory_id for hit in rescue_hits if hit.memory_id) + dialogue_records = [ + record + for record in getattr(graph, "records_by_id", {}).values() + if record.memory_id not in seen_ids + and record.state == "active" + and _clean_text((record.metadata or {}).get("topic_bucket_id", "")) in dialogue_tunnel_ids + and _normalize(record.category) != "question" + ] + def dialogue_record_rank(record: SessionMemoryRecordV2) -> tuple[int, int, float, float, int]: + return ( + int(any(marker in _normalize(record.value) for marker in ("过敏", "必须", "避开", "禁止", "不能", "must", "avoid", "allergy"))), + int(_normalize(record.category) in {"constraint", "preference"}), + float(record.confidence), + float(record.salience), + -int(record.turn_index), + ) + + dialogue_records.sort(key=dialogue_record_rank, reverse=True) + selected_dialogue_records: List[SessionMemoryRecordV2] = [] + selected_dialogue_ids: set[str] = set() + for bucket_id in sorted(dialogue_tunnel_ids): + bucket_records = [ + record + for record in dialogue_records + if _clean_text((record.metadata or {}).get("topic_bucket_id", "")) == bucket_id + ] + if not bucket_records: + continue + best = bucket_records[0] + selected_dialogue_records.append(best) + selected_dialogue_ids.add(best.memory_id) + target_dialogue_rescue = max(len(selected_dialogue_records), max(1, min(4, int(top_k or 1) // 2 or 1))) + for record in dialogue_records: + if len(selected_dialogue_records) >= target_dialogue_rescue: + break + if record.memory_id in selected_dialogue_ids: + continue + selected_dialogue_records.append(record) + selected_dialogue_ids.add(record.memory_id) + dialogue_rescue_hits = [ + _topic_bucket_record_to_hit(record, query_topic=query_topic, rank=index, rescue_kind="dialogue_tunnel") + for index, record in enumerate(selected_dialogue_records, start=1) + ] + reranked.extend(dialogue_rescue_hits) + profile_rescue_hits: List[MemoryHit] = [] + if profile_query_requested: + seen_ids.update(hit.memory_id for hit in reranked if hit.memory_id) + for index, hit in enumerate(_profile_query_rescue_hits(graph, query, top_k=top_k), start=1): + if hit.memory_id in seen_ids: + continue + metadata = dict(hit.metadata or {}) + metadata.update( + { + "profile_query_rescue_rank": index, + "topic_bucket_query_id": query_bucket_id, + "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")), + "topic_bucket_profile_route_preserved": True, + } + ) + hit.metadata = metadata + hit.score = float(hit.score) + 0.08 + profile_rescue_hits.append(hit) + seen_ids.add(hit.memory_id) + stats["profile_query_rescue"] = len(profile_rescue_hits) + reranked.extend(profile_rescue_hits) + reranked.sort(key=lambda item: (float(item.score), int(item.turn_index)), reverse=True) + limit = max(1, int(top_k or 1)) + focused_hits = [ + hit + for hit in reranked + if bool((hit.metadata or {}).get("topic_bucket_same")) + or float((hit.metadata or {}).get("topic_bucket_overlap", 0.0) or 0.0) >= 0.22 + or bool((hit.metadata or {}).get("topic_bucket_bridge_allowed")) + or bool((hit.metadata or {}).get("topic_bucket_dialogue_tunnel_allowed")) + or bool((hit.metadata or {}).get("topic_bucket_profile_route_preserved")) + or bool((hit.metadata or {}).get("topic_bucket_current_subject_preserved")) + or bool((hit.metadata or {}).get("current_subject_resolver")) + ] + model_path_fallback = False + no_bucket_model_fallback = False + if not focused_hits: + model_supported_hits = [ + hit + for hit in reranked + if _clean_text((hit.metadata or {}).get("path_id", "")) + or _clean_text((hit.metadata or {}).get("hybrid_score_source", "")) + ] + if model_supported_hits: + focused_hits = model_supported_hits + model_path_fallback = True + generic_memory_fallback = False + if not focused_hits and any(not _is_public_dialog_hit(hit) for hit in reranked): + focused_hits = list(reranked) + generic_memory_fallback = True + if not focused_hits and not query_bucket_id: + focused_hits = list(reranked) + no_bucket_model_fallback = True + time_focused_model_path = any( + "speaker_event_time" in _clean_text((hit.metadata or {}).get("path_id", "")) + or _clean_text((hit.metadata or {}).get("model_focused_answer_type", "")) == "time" + for hit in focused_hits + ) + if time_focused_model_path: + focused_hits = [ + hit + for hit in focused_hits + if not ( + _clean_text(hit.source_kind) == "public_dialog_profile" + and not _clean_text((hit.metadata or {}).get("path_id", "")) + ) + ] + filtered_count = max(0, len(reranked) - len(focused_hits)) + query_subject_signature = _public_subject_signature(_public_query_subject(query)) + + def _hit_subject_signatures(hit: MemoryHit) -> set[str]: + metadata = dict(hit.metadata or {}) + signatures = { + _normalize(metadata.get("subject_signature", "")).replace("-", "_"), + _public_subject_signature(metadata.get("subject", "")), + } + canonical_slot_key = _clean_text(metadata.get("canonical_slot_key", "")) + if ".subject." in canonical_slot_key: + signatures.add(_public_subject_signature(canonical_slot_key.split(".subject.", 1)[-1])) + if ".subject." in hit.slot_key: + signatures.add(_public_subject_signature(hit.slot_key.split(".subject.", 1)[-1])) + signatures.discard("") + return signatures + + protected_current_hits = [ + hit + for hit in focused_hits + if bool((hit.metadata or {}).get("current_subject_resolver")) + ] + protected_model_path_hits = [ + hit + for hit in focused_hits + if _clean_text((hit.metadata or {}).get("path_id", "")) + and _clean_text((hit.metadata or {}).get("hybrid_score_source", "")) + ] + protected_generic_hits = [ + hit + for hit in focused_hits + if not _is_public_dialog_hit(hit) + and not bool((hit.metadata or {}).get("profile_layer")) + ] + if protected_current_hits and query_subject_signature: + protected_ids = {hit.memory_id for hit in protected_current_hits if hit.memory_id} + compacted_focused_hits: List[MemoryHit] = [] + for hit in focused_hits: + if hit.memory_id in protected_ids: + compacted_focused_hits.append(hit) + continue + same_current_subject = query_subject_signature in _hit_subject_signatures(hit) + inactive_state = _normalize(hit.state) in {"superseded", "evidence", "historical", "stale", "false"} + if same_current_subject and inactive_state: + continue + compacted_focused_hits.append(hit) + focused_hits = compacted_focused_hits + filtered_count = max(0, len(reranked) - len(focused_hits)) + + selected_hits: List[MemoryHit] = [] + selected_keys: set[str] = set() + for hit in [*protected_current_hits, *protected_model_path_hits, *protected_generic_hits, *focused_hits]: + key = hit.memory_id or f"{_hit_event_id(hit)}::{hit.slot_key}::{hit.value[:80]}" + if key in selected_keys: + continue + selected_hits.append(hit) + selected_keys.add(key) + if len(selected_hits) >= limit: + break + dialogue_reserved_count = 0 + if dialogue_requested and dialogue_tunnel_ids: + selected_ids = {hit.memory_id for hit in selected_hits if hit.memory_id} + selected_bucket_ids = { + _clean_text((hit.metadata or {}).get("topic_bucket_id", "")) + for hit in selected_hits + if _clean_text((hit.metadata or {}).get("topic_bucket_id", "")) + } + for bucket_id in sorted(dialogue_tunnel_ids): + if bucket_id in selected_bucket_ids: + continue + candidate = next( + ( + hit + for hit in focused_hits + if hit.memory_id not in selected_ids + and _clean_text((hit.metadata or {}).get("topic_bucket_id", "")) == bucket_id + and bool((hit.metadata or {}).get("topic_bucket_dialogue_tunnel_allowed")) + ), + None, + ) + if candidate is None: + continue + if len(selected_hits) < limit: + selected_hits.append(candidate) + else: + replace_index = len(selected_hits) - 1 + for index in range(len(selected_hits) - 1, -1, -1): + metadata = dict(selected_hits[index].metadata or {}) + hardish = ( + _normalize(metadata.get("memory_type", "")) == "hard_constraint" + or _normalize(metadata.get("durability", "")) == "hard" + or _normalize(metadata.get("conflict_policy", "")) == "must_preserve" + ) + if not hardish and not bool(metadata.get("topic_bucket_dialogue_tunnel_allowed")): + replace_index = index + break + selected_hits[replace_index] = candidate + selected_ids.add(candidate.memory_id) + selected_bucket_ids.add(bucket_id) + dialogue_reserved_count += 1 + selected_hits.sort( + key=lambda item: ( + int(bool((item.metadata or {}).get("current_subject_resolver"))), + float(item.score), + int(item.turn_index), + ), + reverse=True, + ) + return { + "hits": selected_hits[:limit], + "metadata": { + "topic_bucket_rerank_enabled": True, + "topic_bucket_no_fill_policy": True, + "topic_bucket_query_id": query_bucket_id, + "topic_bucket_query_label": _clean_text(query_topic.get("topic_label", "")), + "topic_bucket_query_keywords": query_keywords, + "topic_bucket_adjacent_ids": sorted(adjacent_ids), + "dialogue_tunnel_adjacent_ids": sorted(dialogue_tunnel_ids), + "topic_bucket_bridge_requested": bridge_requested, + "dialogue_tunnel_requested": dialogue_requested, + "topic_bucket_model_path_fallback": model_path_fallback, + "topic_bucket_generic_memory_fallback": generic_memory_fallback, + "topic_bucket_no_bucket_model_fallback": no_bucket_model_fallback, + "topic_bucket_candidate_count": len(hits), + "topic_bucket_rescue_count": len(rescue_hits), + "dialogue_tunnel_rescue_count": len(dialogue_rescue_hits), + "profile_query_rescue_count": len(profile_rescue_hits), + "dialogue_tunnel_reserved_count": dialogue_reserved_count, + "topic_bucket_focused_count": len(focused_hits), + "topic_bucket_final_count": len(selected_hits[:limit]), + "topic_bucket_filtered_count": filtered_count, + "topic_bucket_stats": stats, + }, + } + + +def _looks_like_write_turn(user_text: str, *, answer_payload: Dict[str, Any] | None = None) -> bool: + text = _normalize(user_text) + if not text: + return False + if "?" in text: + return False + if any(marker in text for marker in _WRITE_MARKERS): + return True + metadata = dict((answer_payload or {}).get("metadata", {}) or {}) + return bool(metadata.get("memory_write")) + + +def _explicit_overwrite_requested(user_text: str) -> bool: + text = _normalize(user_text) + return bool(text) and any(marker in text for marker in _OVERWRITE_MARKERS) + + +def _apply_turn_write_intent(records: List[SessionMemoryRecordV2], *, user_text: str) -> List[SessionMemoryRecordV2]: + if not records or not _explicit_overwrite_requested(user_text): + return records + for record in records: + metadata = dict(record.metadata or {}) + metadata.setdefault("write_intent", "overwrite") + metadata.setdefault("memory_gate_decision", "explicit_overwrite") + metadata["allow_parallel_state"] = False + record.metadata = metadata + return records + + +def _records_from_extractor( + extractor: SessionMemoryExtractor, + *, + query: str, + answer_payload: Dict[str, Any] | None, + extraction_result: Dict[str, Any] | None, + turn_index: int, + profile: TMCRAProfile | None = None, +) -> List[SessionMemoryRecordV2]: + profile = profile or TMCRAProfile() + raw_records = extractor.extract( + query=query, + extraction_result=extraction_result, + answer_bundle=answer_payload, + answer_mode=str((answer_payload or {}).get("answer_mode", "transparent")), + turn_index=turn_index, + ) + results: List[SessionMemoryRecordV2] = [] + for index, record in enumerate(raw_records): + slot_key = profile.stable_slot_key( + category=record.category, + value=record.value, + anchors=record.anchor_concepts, + slot_key=record.metadata.get("slot_key", "") if isinstance(record.metadata, dict) else "", + relation=record.relation, + metadata=dict(record.metadata or {}), + ) + metadata = { + **dict(record.metadata or {}), + "memory_role": _clean_text(dict(record.metadata or {}).get("memory_role", "")) or "user", + "authority": _clean_text(dict(record.metadata or {}).get("authority", "")) or "source", + "canonical_slot_key": _clean_text(dict(record.metadata or {}).get("canonical_slot_key", "")) or slot_key, + "writeback_class": _clean_text(dict(record.metadata or {}).get("writeback_class", "")), + "origin_query": _clean_text(dict(record.metadata or {}).get("origin_query", "")) or _clean_text(query), + "origin_answer_id": _clean_text(dict(record.metadata or {}).get("origin_answer_id", "")), + "support_memory_ids": _dedupe(dict(record.metadata or {}).get("support_memory_ids", []) or []), + "support_fact_refs": _dedupe(dict(record.metadata or {}).get("support_fact_refs", []) or []), + "support_path_refs": _dedupe(dict(record.metadata or {}).get("support_path_refs", []) or []), + "promotion_state": _clean_text(dict(record.metadata or {}).get("promotion_state", "")) or "none", + } + results.append( + SessionMemoryRecordV2( + memory_id=f"{slot_key}:{turn_index}:auto:{index}", + category=record.category, + slot_key=slot_key, + value=_clean_text(record.value), + relation=_clean_text(record.relation) or f"{record.category}_memory", + anchor_concepts=_dedupe(record.anchor_concepts, max_items=8), + evidence_anchors=_dedupe(record.anchor_concepts, max_items=8), + salience=float(record.salience), + confidence=float(record.confidence), + source_kind=_clean_text(record.source_kind) or "session_memory", + turn_index=int(record.turn_index), + state="active", + metadata=metadata, + ) + ) + return results + + +def _build_turn_records( + extractor: SessionMemoryExtractor, + *, + user_text: str, + answer_payload: Dict[str, Any] | None, + extraction_result: Dict[str, Any] | None, + turn_index: int, + allow_auto_extract: bool, + profile: TMCRAProfile | None = None, +) -> List[SessionMemoryRecordV2]: + profile = profile or TMCRAProfile() + structured_records = _parse_structured_records(answer_payload, turn_index=turn_index, profile=profile) + if structured_records: + return _apply_typed_tunnel_annotations( + _apply_turn_write_intent(structured_records, user_text=user_text), + source_text=user_text, + ) + structured_records = _parse_structured_records(extraction_result, turn_index=turn_index, profile=profile) + if structured_records: + return _apply_typed_tunnel_annotations( + _apply_turn_write_intent(structured_records, user_text=user_text), + source_text=user_text, + ) + if not allow_auto_extract and not _looks_like_write_turn(user_text, answer_payload=answer_payload): + return [] + return _apply_typed_tunnel_annotations( + _records_from_extractor( + extractor, + query=user_text, + answer_payload=None, + extraction_result=extraction_result, + turn_index=turn_index, + profile=profile, + ), + source_text=user_text, + ) + + +class NullMemoryAdapter(MemoryAdapter): + name = "null_memory" + + def reset(self) -> None: + return None + + def ingest_turn( + self, + user_text: str, + assistant_text: str = "", + *, + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> None: + _ = user_text, assistant_text, answer_payload, extraction_result + + def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval: + _ = query, top_k + return MemoryRetrieval() + + def stats(self) -> Dict[str, Any]: + return _state_stats(storage_bytes=0, retrieval_context_tokens=0, total_state_tokens=0, records=0) + + def storage_bytes(self) -> int: + return 0 + + def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]: + _ = top_k + return { + "mode": "null_memory", + "query": query, + "retrieval": MemoryRetrieval().to_dict(), + "stats": self.stats(), + "state": {}, + } + + +class GraphSessionMemoryAdapter(MemoryAdapter): + name = "graph_session_memory_v2" + + def __init__( + self, + *, + auto_extract: bool = False, + storage_backend: str = "sqlite", + storage_path: str = "", + scope_id: str = "", + audit_retention: int = 256, + lightweight_stats: bool = True, + retrieval_mode: str = "heuristic", + node_model_path: str = "", + path_model_path: str = "", + node_model_device: str = "", + candidate_event_k: int = 24, + support_path_k: int = 3, + path_tunnel_rescue_k: int = 0, + path_tunnel_rescue_score_floor: float = 0.0, + path_tunnel_rescue_min_age: int = 0, + path_tunnel_rescue_min_score_margin: float = 0.0, + event_rerank_mode: str = "matrix", + matrix_event_top_k: int = DEFAULT_MATRIX_EVENT_TOP_K, + memory_router_mode: str = "", + memory_router_threshold: float = _MEMORY_ROUTER_DEFAULT_THRESHOLD, + memory_router_margin: float = _MEMORY_ROUTER_DEFAULT_MARGIN, + injection_planner_mode: str = "", + injection_planner_model_path: str = "", + injection_planner_latest_path: str = "", + injection_planner_device: str = "", + injection_planner_selection_threshold: float = -1.0, + injection_planner_row_threshold: float = -1.0, + injection_planner_logic_threshold: float = -1.0, + temporal_layer_mode: str = "", + temporal_router_mode: str = "", + temporal_router_dir: str = "", + temporal_router_latest_path: str = "", + temporal_router_device: str = "", + ) -> None: + prewarm_embedder_mode = ( + _normalize(os.getenv("TMCRA_EMBEDDER_INDEX_RECALL_MODE", "")) + or _normalize(os.getenv("TMCRA_WRITE_EMBEDDER_INDEX_MODE", "")) + ) + self._embedder_prewarm_metadata = _prewarm_embedder_dense_if_requested(mode=prewarm_embedder_mode) + self.extractor = SessionMemoryExtractor() + self.profile = TMCRAProfile() + self.temporal_organizer = TemporalOrganizer() + self.temporal_query_planner = TemporalQueryPlanner() + self.timeline_evidence_builder = TimelineEvidencePackBuilder() + self.auto_extract = bool(auto_extract) + self.storage_backend = _normalize(storage_backend) or "sqlite" + self.audit_retention = max(1, int(audit_retention)) + self.lightweight_stats = bool(lightweight_stats) + self.scope_id = _clean_text(scope_id) or f"graph-session-{uuid.uuid4().hex}" + self._store: SQLiteSessionMemoryStore | None = None + self.storage_path = "" + if self.storage_backend == "sqlite": + resolved_storage_path = _clean_text(storage_path) or str((Path(tempfile.gettempdir()) / "tmcra_graph_session_memory.sqlite3").resolve()) + self._store = SQLiteSessionMemoryStore(resolved_storage_path, audit_retention=self.audit_retention) + self.storage_path = str(self._store.storage_path) + self.graph = self._store.load_graph(self.scope_id) + elif self.storage_backend == "memory": + self.storage_path = _clean_text(storage_path) + self.graph = SessionMemoryGraphV2( + audit_retention=self.audit_retention, + persistence_backend="memory", + persistence_path=self.storage_path, + ) + else: + raise ValueError(f"Unsupported storage backend: {self.storage_backend}") + self._last_retrieval_context_tokens = 0 + self._last_writeback_summary: Dict[str, Any] = {} + self.retrieval_mode = _normalize(retrieval_mode) or "heuristic" + self.node_model_path = _clean_text(node_model_path) + self.path_model_path = _clean_text(path_model_path) + self.node_model_device = _clean_text(node_model_device) + self.candidate_event_k = max(1, int(candidate_event_k)) + self.support_path_k = max(1, int(support_path_k)) + self.path_tunnel_rescue_k = max(0, int(path_tunnel_rescue_k or 0)) + self.path_tunnel_rescue_score_floor = max(0.0, float(path_tunnel_rescue_score_floor or 0.0)) + self.path_tunnel_rescue_min_age = max(0, int(path_tunnel_rescue_min_age or 0)) + self.path_tunnel_rescue_min_score_margin = max(0.0, float(path_tunnel_rescue_min_score_margin or 0.0)) + self.event_rerank_mode = _normalize(event_rerank_mode) or "matrix" + self.matrix_event_top_k = max(1, int(matrix_event_top_k or DEFAULT_MATRIX_EVENT_TOP_K)) + self.write_embedder_index_mode = _normalize(os.getenv("TMCRA_WRITE_EMBEDDER_INDEX_MODE", "")) + if not self.write_embedder_index_mode: + self.write_embedder_index_mode = "off" + try: + self.write_embedder_index_max_terms = max( + 8, + int(os.getenv("TMCRA_WRITE_EMBEDDER_INDEX_MAX_TERMS", "96") or 96), + ) + except (TypeError, ValueError): + self.write_embedder_index_max_terms = 96 + env_embedder_recall_mode = _normalize(os.getenv("TMCRA_EMBEDDER_INDEX_RECALL_MODE", "")) + self.embedder_index_recall_mode = env_embedder_recall_mode or self.write_embedder_index_mode + try: + self.embedder_index_recall_k = max( + 0, + int(os.getenv("TMCRA_EMBEDDER_INDEX_RECALL_K", "0") or 0), + ) + except (TypeError, ValueError): + self.embedder_index_recall_k = 0 + self.embedder_pre_recall_mode = _normalize(os.getenv("TMCRA_EMBEDDER_PRE_RECALL_MODE", "")) + if not self.embedder_pre_recall_mode: + self.embedder_pre_recall_mode = "off" + try: + self.embedder_pre_recall_k = max( + 0, + int(os.getenv("TMCRA_EMBEDDER_PRE_RECALL_K", "0") or 0), + ) + except (TypeError, ValueError): + self.embedder_pre_recall_k = 0 + self.embedder_fusion_mode = _normalize(os.getenv("TMCRA_EMBEDDER_FUSION_MODE", "")) + try: + self.embedder_fusion_weight = max( + 0.0, + float(os.getenv("TMCRA_EMBEDDER_FUSION_WEIGHT", "0.35") or 0.35), + ) + except (TypeError, ValueError): + self.embedder_fusion_weight = 0.35 + try: + self.embedder_fusion_score_floor = max( + 0.0, + float(os.getenv("TMCRA_EMBEDDER_FUSION_SCORE_FLOOR", "0.62") or 0.62), + ) + except (TypeError, ValueError): + self.embedder_fusion_score_floor = 0.62 + try: + self.embedder_fusion_top_k = max( + 0, + int(os.getenv("TMCRA_EMBEDDER_FUSION_TOP_K", "16") or 16), + ) + except (TypeError, ValueError): + self.embedder_fusion_top_k = 16 + try: + self.embedder_fusion_select_k = max( + 0, + int(os.getenv("TMCRA_EMBEDDER_FUSION_SELECT_K", "4") or 4), + ) + except (TypeError, ValueError): + self.embedder_fusion_select_k = 4 + try: + self.embedder_fusion_max_boost = max( + 0.0, + float(os.getenv("TMCRA_EMBEDDER_FUSION_MAX_BOOST", "0.42") or 0.42), + ) + except (TypeError, ValueError): + self.embedder_fusion_max_boost = 0.42 + env_router_mode = _clean_text(os.getenv("TMCRA_MEMORY_ROUTER_MODE", "")) + self.memory_router_mode = _normalize(memory_router_mode) or _normalize(env_router_mode) or "observe" + try: + self.memory_router_threshold = float( + os.getenv("TMCRA_MEMORY_ROUTER_THRESHOLD", "") + or memory_router_threshold + or _MEMORY_ROUTER_DEFAULT_THRESHOLD + ) + except (TypeError, ValueError): + self.memory_router_threshold = _MEMORY_ROUTER_DEFAULT_THRESHOLD + try: + self.memory_router_margin = float( + os.getenv("TMCRA_MEMORY_ROUTER_MARGIN", "") + or memory_router_margin + or _MEMORY_ROUTER_DEFAULT_MARGIN + ) + except (TypeError, ValueError): + self.memory_router_margin = _MEMORY_ROUTER_DEFAULT_MARGIN + self._loaded_node_scorer: LoadedNodeMemoryScorer | None = None + self._node_scorer_error = "" + env_planner_mode = _clean_text(os.getenv("TMCRA_INJECTION_PLANNER_MODE", "")) + self.injection_planner_mode = _normalize(injection_planner_mode) or _normalize(env_planner_mode) or "observe" + self.injection_planner_model_path = _clean_text( + injection_planner_model_path or os.getenv("TMCRA_INJECTION_PLANNER_MODEL_PATH", "") + ) + self.injection_planner_latest_path = _clean_text( + injection_planner_latest_path or os.getenv("TMCRA_INJECTION_PLANNER_LATEST", "") + ) + self.injection_planner_device = _clean_text( + injection_planner_device or os.getenv("TMCRA_INJECTION_PLANNER_DEVICE", "") + ) + self.injection_planner_selection_threshold_override = float(injection_planner_selection_threshold) + self.injection_planner_row_threshold_override = float(injection_planner_row_threshold) + self.injection_planner_logic_threshold_override = float(injection_planner_logic_threshold) + self._loaded_injection_planner: Any | None = None + self._injection_planner_config: Any | None = None + self._injection_planner_payload: Dict[str, Any] = {} + self._injection_planner_thresholds: Dict[str, float] = {} + self._injection_planner_resolved_path = "" + self._injection_planner_error = "" + self._injection_planner_evidence_role_supported = False + env_temporal_layer_mode = _clean_text(os.getenv("TMCRA_TEMPORAL_LAYER_MODE", "")) + self.temporal_layer_mode = _normalize(temporal_layer_mode) or _normalize(env_temporal_layer_mode) or "observe" + env_temporal_router_mode = _clean_text(os.getenv("TMCRA_TEMPORAL_ROUTER_MODE", "")) + self.temporal_router_mode = _normalize(temporal_router_mode) or _normalize(env_temporal_router_mode) or "observe" + self.temporal_router_dir = _clean_text( + temporal_router_dir or os.getenv("TMCRA_TEMPORAL_ROUTER_DIR", "") + ) + self.temporal_router_latest_path = _clean_text( + temporal_router_latest_path + or os.getenv("TMCRA_TEMPORAL_ROUTER_LATEST", "") + or "models/temporal_router_v1_latest.txt" + ) + self.temporal_router_device = _clean_text( + temporal_router_device or os.getenv("TMCRA_TEMPORAL_ROUTER_DEVICE", "") + ) or "cpu" + self._loaded_temporal_router: Any | None = None + self._temporal_router_resolved_dir = "" + self._temporal_router_error = "" + + def _empty_graph(self) -> SessionMemoryGraphV2: + return SessionMemoryGraphV2( + audit_retention=self.audit_retention, + persistence_backend=self.storage_backend, + persistence_path=self.storage_path, + ) + + def _reload_graph(self) -> None: + if self._store is not None: + self.graph = self._store.load_graph(self.scope_id) + else: + self.graph.configure_persistence( + backend=self.storage_backend, + path=self.storage_path, + audit_retention=self.audit_retention, + ) + + def _persist_graph(self) -> None: + self.graph.configure_persistence( + backend=self.storage_backend, + path=self.storage_path, + audit_retention=self.audit_retention, + ) + if self._store is not None: + self._store.save_graph(self.scope_id, self.graph) + + def replace_graph(self, graph: SessionMemoryGraphV2) -> None: + self.graph = graph + self._persist_graph() + + def _storage_breakdown(self) -> Dict[str, int]: + core_payload = self.graph._core_payload() + audit_payload_full = self.graph._audit_payload() + if self.lightweight_stats: + audit_token_payload = { + "totals": dict(self.graph.audit_event_totals), + "trimmed": dict(self.graph.audit_trimmed_counts), + "retained": { + "turn_log": len(self.graph.turn_log), + "retrieval_log": len(self.graph.retrieval_log), + "answer_support_log": len(self.graph.answer_support_log), + }, + "audit_retention": int(self.graph.audit_retention), + } + else: + audit_token_payload = audit_payload_full + core_storage_bytes = len(json.dumps(core_payload, ensure_ascii=False).encode("utf-8")) + audit_storage_bytes = len(json.dumps(audit_payload_full, ensure_ascii=False).encode("utf-8")) + core_state_token_estimate = _estimate_tokens(json.dumps(core_payload, ensure_ascii=False)) + audit_state_token_estimate = _estimate_tokens(json.dumps(audit_token_payload, ensure_ascii=False)) + return { + "core_storage_bytes": int(core_storage_bytes), + "audit_storage_bytes": int(audit_storage_bytes), + "storage_bytes": int(core_storage_bytes + audit_storage_bytes), + "core_state_token_estimate": int(core_state_token_estimate), + "audit_state_token_estimate": int(audit_state_token_estimate), + "total_state_token_estimate": int(core_state_token_estimate + audit_state_token_estimate), + } + + def reset(self) -> None: + if self._store is not None: + self._store.clear_scope(self.scope_id) + self.graph = self._store.load_graph(self.scope_id) + else: + self.graph = self._empty_graph() + self._last_retrieval_context_tokens = 0 + self._last_writeback_summary = {} + + def _temporal_layer_enabled(self) -> bool: + return _normalize(self.temporal_layer_mode) not in _TEMPORAL_LAYER_DISABLED_MODES + + def _temporal_router_enabled(self) -> bool: + return self._temporal_layer_enabled() and _normalize(self.temporal_router_mode) not in _TEMPORAL_LAYER_DISABLED_MODES + + def _resolve_temporal_router_dir(self) -> str: + direct_dir = _clean_text(self.temporal_router_dir) + if direct_dir: + root = Path(direct_dir) + if (root / "writer_temporal_router.pt").exists() and (root / "query_temporal_router.pt").exists(): + return str(root) + latest_path = _clean_text(self.temporal_router_latest_path) + if latest_path: + pointer = Path(latest_path) + if pointer.exists(): + lines = pointer.read_text(encoding="utf-8").splitlines() + candidate = _clean_text(lines[0] if lines else "") + if candidate: + root = Path(candidate) + if (root / "writer_temporal_router.pt").exists() and (root / "query_temporal_router.pt").exists(): + return str(root) + return "" + + def _load_temporal_router(self) -> LoadedTemporalRouter | None: + if not self._temporal_router_enabled(): + self._temporal_router_error = "disabled" + return None + model_dir = self._resolve_temporal_router_dir() + self._temporal_router_resolved_dir = model_dir + if not model_dir: + self._temporal_router_error = "model_dir_missing" + return None + if self._loaded_temporal_router is not None: + return self._loaded_temporal_router + try: + self._loaded_temporal_router = LoadedTemporalRouter.from_dir( + model_dir, + device=self.temporal_router_device or "cpu", + ) + self._temporal_router_error = "" + except Exception as exc: # pragma: no cover - defensive runtime path + self._loaded_temporal_router = None + self._temporal_router_error = f"{type(exc).__name__}: {exc}" + return self._loaded_temporal_router + + def _temporal_router_status_metadata(self) -> Dict[str, Any]: + router = self._load_temporal_router() + return { + "temporal_router_enabled": router is not None, + "temporal_router_mode": _normalize(self.temporal_router_mode) or "observe", + "temporal_router_model_dir": self._temporal_router_resolved_dir, + "temporal_router_error": self._temporal_router_error, + } + + def _session_timestamp_from_payloads(self, *payloads: Mapping[str, Any] | None) -> str: + for payload in payloads: + data = dict(payload or {}) + for key in ("session_timestamp", "timestamp", "created_at", "turn_timestamp"): + value = _clean_text(data.get(key, "")) + if value: + return value + metadata = data.get("metadata") + if isinstance(metadata, Mapping): + for key in ("session_timestamp", "timestamp", "created_at", "turn_timestamp"): + value = _clean_text(metadata.get(key, "")) + if value: + return value + return "" + + def _temporal_frame_for_turn( + self, + *, + user_text: str, + previous_turn: str = "", + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> TemporalFrame | None: + if not self._temporal_layer_enabled(): + return None + session_timestamp = self._session_timestamp_from_payloads(answer_payload, extraction_result) + model_frame = None + for payload in (answer_payload, extraction_result): + data = dict(payload or {}) + candidate = data.get("temporal_frame") or dict(data.get("metadata", {}) or {}).get("temporal_frame") + if isinstance(candidate, Mapping): + model_frame = candidate + break + fallback_frame = self.temporal_organizer.organize_turn( + current_turn=user_text, + previous_turn=previous_turn, + session_timestamp=session_timestamp, + speaker="user", + ) + if model_frame is None: + router = self._load_temporal_router() + if router is not None and router.writer_available(): + predicted = router.predict_writer_frame( + current_turn=user_text, + previous_turn=previous_turn, + session_timestamp=session_timestamp, + ) + writer_confidence = float(predicted.get("confidence", 0.0) or 0.0) if predicted else 0.0 + if predicted and writer_confidence >= _float_env("TMCRA_TEMPORAL_ROUTER_WRITER_MIN_CONFIDENCE", _TEMPORAL_ROUTER_DEFAULT_WRITER_MIN_CONFIDENCE): + frame_payload = fallback_frame.to_dict() + frame_payload.update( + { + key: value + for key, value in predicted.items() + if key in {"temporal_intent", "anchor_type", "granularity", "state_operation"} + and _clean_text(value) + } + ) + if "should_create_timeline_edge" in predicted: + frame_payload["should_create_timeline_edge"] = bool(predicted.get("should_create_timeline_edge", False)) + if writer_confidence > 0.0: + frame_payload["confidence"] = writer_confidence + frame_payload["metadata"] = { + **dict(frame_payload.get("metadata", {}) or {}), + **dict(predicted.get("metadata", {}) or {}), + **self._temporal_router_status_metadata(), + } + model_frame = frame_payload + if model_frame is None: + return fallback_frame + return self.temporal_organizer.organize_turn( + current_turn=user_text, + previous_turn=previous_turn, + session_timestamp=session_timestamp, + speaker="user", + model_frame=model_frame, + ) + + def _temporal_turn_metadata(self, frame: TemporalFrame | None) -> Dict[str, Any]: + if frame is None: + return { + "temporal_layer_enabled": False, + "temporal_layer_mode": _normalize(self.temporal_layer_mode) or "observe", + **self._temporal_router_status_metadata(), + } + return { + "temporal_layer_enabled": True, + "temporal_layer_mode": _normalize(self.temporal_layer_mode) or "observe", + "temporal_frame": frame.to_dict(), + "temporal_intent": frame.temporal_intent, + "temporal_subject_key": frame.subject_key, + "temporal_state_operation": frame.state_operation, + **self._temporal_router_status_metadata(), + } + + def _apply_temporal_frame_to_records(self, records: Sequence[SessionMemoryRecordV2], frame: TemporalFrame | None) -> None: + if frame is None: + return + if frame.temporal_intent == "non_temporal" and not frame.subject_key: + return + for record in records: + metadata = dict(record.metadata or {}) + metadata.update( + { + "temporal_layer": True, + "temporal_frame": frame.to_dict(), + "temporal_intent": frame.temporal_intent, + "temporal_subject": frame.subject, + "temporal_subject_key": frame.subject_key, + "temporal_state_operation": frame.state_operation, + "temporal_event_time": frame.event_time, + "temporal_state_valid_from": frame.state_valid_from, + "temporal_state_valid_to": frame.state_valid_to, + } + ) + record.metadata = metadata + + def _build_timeline_state_layer(self) -> TimelineStateLayer: + layer = TimelineStateLayer() + for turn in sorted(list(getattr(self.graph, "turn_log", []) or []), key=lambda item: int(item.get("turn_index", 0) or 0)): + metadata = dict(turn.get("metadata", {}) or {}) + frame_payload = metadata.get("temporal_frame") + if not isinstance(frame_payload, Mapping): + continue + frame = TemporalFrame.from_mapping(frame_payload) + if frame.temporal_intent == "non_temporal" and not frame.new_state: + continue + record_ids = [item for item in list(turn.get("record_ids", []) or []) if _clean_text(item)] + source_event_id = _clean_text(metadata.get("temporal_source_record_id", "")) or (record_ids[0] if record_ids else "") + frame.metadata = { + **dict(frame.metadata or {}), + "source_text": _clean_text(turn.get("text", "")) or frame.evidence_span, + } + layer.apply_frame( + frame, + source_event_id=source_event_id, + source_turn_id=_clean_text(turn.get("turn_id", "")), + state_type=_clean_text(metadata.get("temporal_state_type", "")) or "profile", + ) + return layer + + def _temporal_runtime_pack(self, query: str) -> Dict[str, Any]: + metadata: Dict[str, Any] = { + "temporal_runtime_enabled": False, + "temporal_layer_mode": _normalize(self.temporal_layer_mode) or "observe", + **self._temporal_router_status_metadata(), + } + if not self._temporal_layer_enabled(): + metadata["temporal_runtime_reason"] = "disabled" + return {"metadata": metadata} + plan = self.temporal_query_planner.plan(query) + router = self._load_temporal_router() + if router is not None and router.query_available(): + predicted_plan = router.predict_query_plan(query=query) + query_confidence = float(predicted_plan.get("confidence", 0.0) or 0.0) if predicted_plan else 0.0 + router_confidences = dict(dict(predicted_plan.get("metadata", {}) or {}).get("temporal_router_confidences", {}) or {}) if predicted_plan else {} + intent_confidence = float(router_confidences.get("query_temporal_intent", 0.0) or 0.0) + if ( + predicted_plan + and query_confidence >= _float_env("TMCRA_TEMPORAL_ROUTER_QUERY_MIN_CONFIDENCE", _TEMPORAL_ROUTER_DEFAULT_QUERY_MIN_CONFIDENCE) + and intent_confidence >= _float_env("TMCRA_TEMPORAL_ROUTER_QUERY_INTENT_MIN_CONFIDENCE", _TEMPORAL_ROUTER_DEFAULT_QUERY_INTENT_MIN_CONFIDENCE) + ): + plan_payload = plan.to_dict() + plan_payload.update( + { + key: value + for key, value in predicted_plan.items() + if key in {"query_temporal_intent", "timeline_operation"} and _clean_text(value) + } + ) + for key in ("prefer_current_state", "prefer_previous_state", "requires_ordered_chain", "requires_comparison"): + if key in predicted_plan: + plan_payload[key] = bool(predicted_plan.get(key, False)) + if float(predicted_plan.get("confidence", 0.0) or 0.0) > 0.0: + plan_payload["confidence"] = float(predicted_plan.get("confidence", 0.0) or 0.0) + plan_payload["metadata"] = { + **dict(plan_payload.get("metadata", {}) or {}), + **dict(predicted_plan.get("metadata", {}) or {}), + **self._temporal_router_status_metadata(), + } + plan = TemporalQueryPlan(**{key: value for key, value in plan_payload.items() if key in TemporalQueryPlan.__dataclass_fields__}) + metadata["temporal_query_plan"] = plan.to_dict() + if plan.query_temporal_intent == "non_temporal" or plan.timeline_operation == "none": + metadata["temporal_runtime_reason"] = "non_temporal_query" + return {"plan": plan, "metadata": metadata} + if plan.timeline_operation in {"query_current", "query_previous"} and not _clean_text(plan.target_subject_key): + metadata["temporal_runtime_reason"] = "missing_target_subject" + return {"plan": plan, "metadata": metadata} + timeline_layer = self._build_timeline_state_layer() + pack = self.timeline_evidence_builder.build(plan=plan, timeline_layer=timeline_layer) + metadata.update( + { + "temporal_runtime_enabled": True, + "temporal_runtime_reason": "ok", + "temporal_evidence_pack": pack.to_dict(), + "temporal_selected_answer_value": _clean_text(pack.selected_evidence.get("answer_value", "")), + "temporal_selected_state_id": _clean_text(pack.selected_evidence.get("state_id", "")), + "temporal_timeline_state_count": len(pack.timeline), + } + ) + return {"plan": plan, "pack": pack, "metadata": metadata} + + def _temporal_state_hit( + self, + *, + state_payload: Mapping[str, Any], + plan_payload: Mapping[str, Any], + selected: bool, + rank: int, + ) -> MemoryHit | None: + state_value = _clean_text(state_payload.get("state", "")) + state_id = _clean_text(state_payload.get("state_id", "")) + if not state_value or not state_id: + return None + source_event_id = _clean_text(state_payload.get("source_event_id", "")) + source_record = self.graph.records_by_id.get(source_event_id) + source_metadata = dict(source_record.metadata or {}) if source_record is not None else {} + score = 0.98 if selected else max(0.75, 0.92 - (rank * 0.03)) + return MemoryHit( + memory_id=f"temporal_state:{state_id}", + category="time", + value=state_value, + relation="temporal_state", + anchors=_dedupe([state_payload.get("time", ""), state_payload.get("source_text", ""), plan_payload.get("target_subject", "")], max_items=6), + score=round(score, 6), + source_kind="temporal_state_layer", + slot_key=f"temporal.{_clean_text(plan_payload.get('target_subject_key', 'general'))}", + state="active" if bool(state_payload.get("is_current", False)) else "history", + turn_index=int(source_record.turn_index) if source_record is not None else 0, + metadata={ + **source_metadata, + "temporal_runtime_hit": True, + "temporal_runtime_selected": bool(selected), + "temporal_state_id": state_id, + "temporal_state_value": state_value, + "temporal_state_time": _clean_text(state_payload.get("time", "")), + "temporal_state_valid_to": _clean_text(state_payload.get("valid_to", "")), + "temporal_state_is_current": bool(state_payload.get("is_current", False)), + "temporal_source_event_id": source_event_id, + "temporal_source_turn_id": _clean_text(state_payload.get("source_turn_id", "")), + "temporal_query_plan": dict(plan_payload), + }, + ) + + def _apply_temporal_evidence_pack_to_hits( + self, + hits: Sequence[MemoryHit], + temporal_payload: Mapping[str, Any], + *, + top_k: int, + ) -> Dict[str, Any]: + metadata = dict(temporal_payload.get("metadata", {}) or {}) + pack = temporal_payload.get("pack") + plan = temporal_payload.get("plan") + if pack is None or plan is None or not bool(metadata.get("temporal_runtime_enabled", False)): + return {"hits": list(hits), "metadata": metadata} + pack_payload = pack.to_dict() if hasattr(pack, "to_dict") else dict(pack) + plan_payload = plan.to_dict() if hasattr(plan, "to_dict") else dict(plan) + selected_state_id = _clean_text(dict(pack_payload.get("selected_evidence", {}) or {}).get("state_id", "")) + selected_answer_value = _clean_text(dict(pack_payload.get("selected_evidence", {}) or {}).get("answer_value", "")) + timeline = list(pack_payload.get("timeline", []) or []) + synthetic_hits: List[MemoryHit] = [] + if selected_state_id: + selected_state = next((dict(item) for item in timeline if _clean_text(dict(item).get("state_id", "")) == selected_state_id), None) + if selected_state is None and selected_answer_value: + selected_state = { + "state_id": selected_state_id, + "state": selected_answer_value, + "source_event_id": dict(pack_payload.get("selected_evidence", {}) or {}).get("source_event_id", ""), + "source_turn_id": dict(pack_payload.get("selected_evidence", {}) or {}).get("source_turn_id", ""), + "is_current": bool(plan_payload.get("prefer_current_state", False)), + } + if selected_state: + hit = self._temporal_state_hit(state_payload=selected_state, plan_payload=plan_payload, selected=True, rank=0) + if hit is not None: + synthetic_hits.append(hit) + if bool(plan_payload.get("requires_ordered_chain", False)) or bool(plan_payload.get("requires_comparison", False)): + for index, state_payload in enumerate(timeline): + if _clean_text(dict(state_payload).get("state_id", "")) == selected_state_id: + continue + hit = self._temporal_state_hit( + state_payload=dict(state_payload), + plan_payload=plan_payload, + selected=False, + rank=index + 1, + ) + if hit is not None: + synthetic_hits.append(hit) + existing: List[MemoryHit] = [] + synthetic_source_ids = { + _clean_text((hit.metadata or {}).get("temporal_source_event_id", "")) + for hit in synthetic_hits + if _clean_text((hit.metadata or {}).get("temporal_source_event_id", "")) + } + for hit in hits: + hit_metadata = dict(hit.metadata or {}) + source_event_id = _clean_text(hit_metadata.get("event_id", hit.memory_id)) + if source_event_id in synthetic_source_ids: + hit_metadata.update( + { + "temporal_runtime_support": True, + "temporal_query_plan": dict(plan_payload), + "temporal_selected_answer_value": selected_answer_value, + "temporal_selected_state_id": selected_state_id, + } + ) + hit = MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=hit.score, + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=hit_metadata, + ) + existing.append(hit) + merged: List[MemoryHit] = [] + seen = set() + for hit in [*synthetic_hits, *existing]: + key = _clean_text(hit.memory_id) + if key and key in seen: + continue + if key: + seen.add(key) + merged.append(hit) + metadata.update( + { + "temporal_runtime_injected_hit_count": len(synthetic_hits), + "temporal_runtime_injected_hit_ids": [hit.memory_id for hit in synthetic_hits], + } + ) + return {"hits": merged[: max(len(merged), top_k)], "metadata": metadata} + + def ingest_turn( + self, + user_text: str, + assistant_text: str = "", + *, + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> None: + self._reload_graph() + turn_index = self.graph.next_turn() + previous_topic = _last_topic_turn(self.graph) + previous_turn_text = _clean_text(self.graph.turn_log[-1].get("text", "")) if self.graph.turn_log else "" + temporal_frame = self._temporal_frame_for_turn( + user_text=user_text, + previous_turn=previous_turn_text, + answer_payload=answer_payload, + extraction_result=extraction_result, + ) + topic_bucket = _assign_topic_bucket_for_text( + self.graph, + user_text, + answer_payload=answer_payload, + turn_index=turn_index, + create=True, + ) + records = _build_turn_records( + self.extractor, + user_text=user_text, + answer_payload=answer_payload, + extraction_result=extraction_result, + turn_index=turn_index, + allow_auto_extract=self.auto_extract, + profile=self.profile, + ) + self._apply_temporal_frame_to_records(records, temporal_frame) + _apply_topic_bucket_to_records(records, topic_bucket) + stored_ids = self.graph.add_records(records) + write_embedder_metadata = _apply_write_embedder_index_to_graph( + self.graph, + stored_ids=stored_ids, + turn_text=user_text, + turn_index=turn_index, + mode=self.write_embedder_index_mode, + max_terms=self.write_embedder_index_max_terms, + ) + topic_bridge_metadata = _add_topic_bridge_edges( + self.graph, + previous_topic=previous_topic, + current_topic=topic_bucket, + current_record_ids=stored_ids, + turn_index=turn_index, + evidence=user_text, + ) + dialogue_tunnel_metadata = _add_dialogue_tunnel_edges( + self.graph, + current_topic=topic_bucket, + current_record_ids=stored_ids, + turn_index=turn_index, + evidence=user_text, + ) + turn_kind = "memory_write" if stored_ids else "noise" + self.graph.record_turn( + turn_kind=turn_kind, + text=user_text, + turn_index=turn_index, + record_ids=stored_ids, + speaker="user", + assistant_text=assistant_text, + metadata={ + "source": "user_turn", + "auto_extract": bool(self.auto_extract), + "topic_bucket_id": _clean_text(topic_bucket.get("topic_bucket_id", "")), + "topic_label": _clean_text(topic_bucket.get("topic_label", "")), + "topic_keywords": list(topic_bucket.get("topic_keywords", []) or []), + "temporal_source_record_id": stored_ids[0] if stored_ids else "", + **self._temporal_turn_metadata(temporal_frame), + **write_embedder_metadata, + **topic_bridge_metadata, + **dialogue_tunnel_metadata, + }, + ) + self._persist_graph() + + def ingest_answer_writeback( + self, + *, + query_text: str, + answer_text: str, + answer_id: str, + writeback_records: List[Dict[str, Any]], + trace: Dict[str, Any] | None = None, + ) -> List[str]: + self._reload_graph() + if not writeback_records: + self._last_writeback_summary = {"stored_record_ids": [], "promotion_events": []} + return [] + turn_index = self.graph.next_turn() + records: List[SessionMemoryRecordV2] = [] + writeback_classes: List[str] = [] + for index, raw in enumerate(writeback_records): + if not isinstance(raw, dict): + continue + category = _clean_text(raw.get("category", "fact")) or "fact" + value = _clean_text(raw.get("value", "")) + raw_slot_key = _clean_text(raw.get("slot_key", "")) or _clean_text(raw.get("slot", "")) + anchors = _dedupe(raw.get("anchors", []) or [], max_items=8) + slot_key = self.profile.stable_slot_key( + category=category, + value=value, + anchors=anchors, + slot_key=raw_slot_key, + relation=_clean_text(raw.get("relation", "")), + metadata=dict(raw.get("metadata", {}) or {}), + ) + if not value or not slot_key: + continue + raw_metadata = dict(raw.get("metadata", {}) or {}) + writeback_class = _clean_text(raw_metadata.get("writeback_class", "")) or "fact" + writeback_classes.append(writeback_class) + metadata = { + **raw_metadata, + "memory_role": _clean_text(raw_metadata.get("memory_role", "")) or "assistant", + "authority": _clean_text(raw_metadata.get("authority", "")) or "derived", + "canonical_slot_key": _clean_text(raw_metadata.get("canonical_slot_key", "")) or slot_key.removeprefix("assistant.").split(".fact", 1)[0].split(".state_change", 1)[0].split(".high_conf_conclusion", 1)[0], + "writeback_class": writeback_class, + "origin_query": _clean_text(raw_metadata.get("origin_query", "")) or _clean_text(query_text), + "origin_answer_id": _clean_text(raw_metadata.get("origin_answer_id", "")) or answer_id, + "origin_answer_ids": _dedupe([*(raw_metadata.get("origin_answer_ids", []) or []), _clean_text(raw_metadata.get("origin_answer_id", "")) or answer_id]), + "support_memory_ids": _dedupe(raw_metadata.get("support_memory_ids", []) or []), + "support_fact_refs": _dedupe(raw_metadata.get("support_fact_refs", []) or []), + "support_path_refs": _dedupe(raw_metadata.get("support_path_refs", []) or []), + "promotion_state": _clean_text(raw_metadata.get("promotion_state", "")) or "candidate", + "answer_id": answer_id, + } + record = SessionMemoryRecordV2( + memory_id=f"{slot_key}:{turn_index}:assistant:{index}", + category=category, + slot_key=slot_key, + value=value, + relation=_clean_text(raw.get("relation", "")) or "assistant_memory", + anchor_concepts=anchors, + evidence_anchors=anchors, + salience=float(raw.get("salience", 0.62) or 0.62), + confidence=float(raw.get("confidence", 0.88) or 0.88), + source_kind=_clean_text(raw.get("source_kind", "")) or "assistant_writeback", + turn_index=turn_index, + state=_clean_text(raw.get("state", "")) or "active", + metadata=metadata, + ) + records.append(record) + stored_ids = self.graph.add_records(records) + write_embedder_metadata = _apply_write_embedder_index_to_graph( + self.graph, + stored_ids=stored_ids, + turn_text=" ".join([_clean_text(query_text), _clean_text(answer_text)]), + turn_index=turn_index, + mode=self.write_embedder_index_mode, + max_terms=self.write_embedder_index_max_terms, + ) + promotion_events = self._apply_writeback_promotions(stored_ids) + writeback_class = writeback_classes[0] if len(set(writeback_classes)) == 1 and writeback_classes else ("mixed" if writeback_classes else "") + self.graph.record_turn( + turn_kind="assistant_writeback" if stored_ids else "assistant_writeback_empty", + text=query_text, + turn_index=turn_index, + record_ids=stored_ids, + speaker="assistant", + assistant_text=answer_text, + writeback_class=writeback_class, + metadata={ + "source": "assistant_writeback", + "answer_id": answer_id, + "trace": dict(trace or {}), + **write_embedder_metadata, + }, + ) + self._last_writeback_summary = { + "stored_record_ids": list(stored_ids), + "promotion_events": list(promotion_events), + **write_embedder_metadata, + } + self._persist_graph() + return stored_ids + + def last_writeback_summary(self) -> Dict[str, Any]: + return dict(self._last_writeback_summary) + + def _resolve_injection_planner_model_path(self) -> str: + explicit_path = _clean_text(self.injection_planner_model_path) + if explicit_path: + path = Path(explicit_path).expanduser() + if path.is_dir(): + path = path / "injection_planner.pt" + return str(path) + latest_path = _clean_text(self.injection_planner_latest_path) + if not latest_path: + return "" + latest = Path(latest_path).expanduser() + if latest.is_dir(): + return str(latest / "injection_planner.pt") + if latest.suffix == ".pt": + return str(latest) + try: + target = Path(latest.read_text(encoding="utf-8").strip()).expanduser() + except Exception as exc: + self._injection_planner_error = f"latest_pointer_read_failed: {exc}" + return "" + if target.is_dir(): + target = target / "injection_planner.pt" + return str(target) + + def _load_injection_planner(self) -> Any | None: + normalized_mode = _normalize(self.injection_planner_mode) + if normalized_mode in _INJECTION_PLANNER_DISABLED_MODES: + self._injection_planner_error = "disabled" + return None + if self._loaded_injection_planner is not None: + return self._loaded_injection_planner + torch_module = getattr(injection_planner_runtime, "torch", None) + if torch_module is None: + self._injection_planner_error = "torch_unavailable" + return None + model_path_text = self._resolve_injection_planner_model_path() + if not model_path_text: + if not self._injection_planner_error: + self._injection_planner_error = "model_path_missing" + return None + model_path = Path(model_path_text) + if not model_path.exists(): + self._injection_planner_error = f"model_path_not_found: {model_path}" + return None + try: + device = torch_module.device(self.injection_planner_device or "cpu") + payload = torch_module.load(model_path, map_location=device, weights_only=False) + config = injection_planner_runtime.InjectionPlannerConfig.from_dict(dict(payload.get("config", {}) or {})) + model = injection_planner_runtime.InjectionPlannerModel(config).to(device) + state_dict = dict(payload.get("state_dict", {}) or {}) + model.load_state_dict(state_dict, strict=False) + model.eval() + self._loaded_injection_planner = model + self._injection_planner_config = config + self._injection_planner_payload = dict(payload) + self._injection_planner_resolved_path = str(model_path) + self._injection_planner_error = "" + self._injection_planner_evidence_role_supported = any( + str(key).startswith("evidence_role_head.") for key in state_dict + ) + thresholds = {"selection_threshold": 0.5, "row_threshold": 0.5, "logic_threshold": 0.5} + summary_path = model_path.parent / "train_summary.json" + if summary_path.exists(): + try: + summary = json.loads(summary_path.read_text(encoding="utf-8")) + calibration = dict(summary.get("calibration", {}) or {}) + logic_calibration = dict(summary.get("logic_calibration", {}) or {}) + if calibration.get("selection_threshold") is not None: + thresholds["selection_threshold"] = float(calibration.get("selection_threshold")) + if calibration.get("row_threshold") is not None: + thresholds["row_threshold"] = float(calibration.get("row_threshold")) + if logic_calibration.get("logic_threshold") is not None: + thresholds["logic_threshold"] = float(logic_calibration.get("logic_threshold")) + except Exception: + pass + if self.injection_planner_selection_threshold_override >= 0.0: + thresholds["selection_threshold"] = float(self.injection_planner_selection_threshold_override) + if self.injection_planner_row_threshold_override >= 0.0: + thresholds["row_threshold"] = float(self.injection_planner_row_threshold_override) + if self.injection_planner_logic_threshold_override >= 0.0: + thresholds["logic_threshold"] = float(self.injection_planner_logic_threshold_override) + self._injection_planner_thresholds = thresholds + except Exception as exc: + self._loaded_injection_planner = None + self._injection_planner_config = None + self._injection_planner_payload = {} + self._injection_planner_resolved_path = str(model_path) + self._injection_planner_error = str(exc) + self._injection_planner_evidence_role_supported = False + return self._loaded_injection_planner + + def _injection_planner_candidate_from_hit( + self, + query: str, + hit: MemoryHit, + *, + index: int, + current_turn_index: int, + ) -> Dict[str, Any]: + metadata = dict(hit.metadata or {}) + category = _normalize(hit.category) + source_kind = _normalize(hit.source_kind) + logic_roles = [ + _normalize(item) + for item in ( + metadata.get("logic_roles", []) + if isinstance(metadata.get("logic_roles", []), list) + else [metadata.get("logic_roles", "")] + ) + if _normalize(item) in set(injection_planner_runtime.LOGIC_ROLES) + ] + if category in set(injection_planner_runtime.LOGIC_ROLES): + logic_roles.append(category) + if _normalize(metadata.get("writeback_class", "")) in set(injection_planner_runtime.LOGIC_ROLES): + logic_roles.append(_normalize(metadata.get("writeback_class", ""))) + if not logic_roles: + logic_roles = ["negative"] if _normalize(hit.state) in {"stale", "superseded", "false"} else ["evidence"] + evidence_snippet_role = _normalize(metadata.get("evidence_snippet_role", "")) + if bool(metadata.get("profile_layer")) or category == "profile" or source_kind.endswith("profile"): + layer = "profile" + elif category == "time" or float(metadata.get("temporal_score", 0.0) or 0.0) > 0.0: + layer = "temporal" + elif "resource" in logic_roles or _clean_text(metadata.get("resource_key", "")): + layer = "resource" + elif ( + source_kind in {"path_tunnel", "path_support", "public_dialog_path"} + or evidence_snippet_role in {"selected_path_support", "path_tunnel_support"} + or bool(metadata.get("path_tunnel_node")) + ): + layer = "path_tunnel" + elif bool(metadata.get("topic_bucket_rerank")) or bool(metadata.get("topic_bucket_dialogue_tunnel_allowed")): + layer = "topic_tunnel" + else: + layer = "event" + normalized_state = _normalize(hit.state) + if normalized_state in {"stale", "superseded", "false"}: + temporal_state = "superseded" + elif bool(metadata.get("topic_bucket_current_subject_preserved")) or bool(metadata.get("current_subject_protected")): + temporal_state = "current" + elif current_turn_index and int(hit.turn_index or 0) and current_turn_index - int(hit.turn_index) <= 1: + temporal_state = "current" + elif normalized_state == "active": + temporal_state = "stable" + else: + temporal_state = "historical" + query_tokens = set(_tokenize(query)) + hit_tokens = set(_tokenize(" ".join([hit.value, " ".join(hit.anchors), _clean_text(metadata.get("topic_label", ""))]))) + overlap = len(query_tokens & hit_tokens) / max(1, len(query_tokens | hit_tokens)) + candidate_id = _clean_text(hit.memory_id) or f"hit:{index}" + topic_label = _clean_text(metadata.get("topic_label", "")) or _clean_text(metadata.get("topic_bucket_query_label", "")) + profile_key = _clean_text(metadata.get("profile_type", "")) or _clean_text(metadata.get("semantic_slot", "")) + if not profile_key and layer in {"profile", "temporal", "topic_tunnel"}: + profile_key = topic_label + semantic_similarity = max( + 0.0, + min( + 1.0, + float( + metadata.get( + "semantic_similarity", + metadata.get( + "answer_window_semantic_similarity", + metadata.get( + "embedder_similarity", + metadata.get("dense_similarity", metadata.get("bge_m3_similarity", hit.score)), + ), + ), + ) + or 0.0 + ), + ), + ) + return { + "id": candidate_id, + "text": hit.value, + "summary": _clean_text(metadata.get("source_turn_text", "")) or _clean_text(metadata.get("event_summary", "")), + "topic": topic_label, + "profile_key": profile_key, + "resource_key": _clean_text(metadata.get("resource_key", "")) or (_clean_text(hit.slot_key) if layer == "resource" else ""), + "layer": layer, + "temporal_state": temporal_state, + "logic_roles": _dedupe(logic_roles), + "query_overlap": round(float(overlap), 6), + "retrieval_score": max(0.0, min(1.0, float(hit.score or 0.0))), + "graph_score": max( + 0.0, + min( + 1.0, + float( + metadata.get( + "event_score", + metadata.get("recall_score", metadata.get("hybrid_score", hit.score)), + ) + or 0.0 + ), + ), + ), + "tunnel_score": max( + 0.0, + min( + 1.0, + max( + float(metadata.get("path_tunnel_support_score", 0.0) or 0.0), + float(metadata.get("event_tunnel_support_score", 0.0) or 0.0), + float(metadata.get("path_chain_extension_delta_score", 0.0) or 0.0), + ), + ), + ), + "topic_similarity": max(0.0, min(1.0, float(metadata.get("topic_bucket_overlap", 0.0) or 0.0))), + "semantic_similarity": semantic_similarity, + "confidence": max(0.0, min(1.0, float(metadata.get("confidence", hit.score or 0.0) or 0.0))), + "rank_score": round(1.0 / float(index + 1), 6), + "age_turns": max(0, int(current_turn_index) - int(hit.turn_index or 0)) if current_turn_index else 0, + "branch_depth": max(0, len(list(metadata.get("selected_path_ids", []) or []))), + "contradicts_current": normalized_state in {"stale", "superseded", "false"} or bool(metadata.get("contradicts_current")), + "is_current": temporal_state == "current", + } + + def _apply_injection_planner_to_hits( + self, + query: str, + hits: Sequence[MemoryHit], + *, + top_k: int, + ) -> Dict[str, Any]: + normalized_mode = _normalize(self.injection_planner_mode) + base_metadata: Dict[str, Any] = { + "injection_planner_enabled": False, + "injection_planner_mode": normalized_mode or "observe", + } + if normalized_mode in _INJECTION_PLANNER_DISABLED_MODES: + base_metadata["injection_planner_reason"] = "disabled" + return {"hits": list(hits), "metadata": base_metadata} + if not hits: + base_metadata["injection_planner_reason"] = "no_hits" + return {"hits": list(hits), "metadata": base_metadata} + model = self._load_injection_planner() + if model is None or self._injection_planner_config is None: + base_metadata.update( + { + "injection_planner_reason": self._injection_planner_error or "load_failed", + "injection_planner_model_path": self._injection_planner_resolved_path, + } + ) + return {"hits": list(hits), "metadata": base_metadata} + torch_module = getattr(injection_planner_runtime, "torch", None) + if torch_module is None: + base_metadata["injection_planner_reason"] = "torch_unavailable" + return {"hits": list(hits), "metadata": base_metadata} + current_turn_index = max( + [int(getattr(self.graph, "turn_index", 0) or 0), *[int(hit.turn_index or 0) for hit in hits]], + default=0, + ) + candidates = [ + self._injection_planner_candidate_from_hit(query, hit, index=index, current_turn_index=current_turn_index) + for index, hit in enumerate(hits) + ] + row = {"id": "runtime_injection_plan", "query": query, "candidates": candidates, "gold": {}} + try: + dataset = injection_planner_runtime.InjectionPlannerDataset([row], self._injection_planner_config) + batch = injection_planner_runtime.collate_injection_batch([dataset[0]]) + device = next(model.parameters()).device + model_batch = { + key: value.to(device) if hasattr(value, "to") else value + for key, value in dict(batch).items() + } + with torch_module.no_grad(): + outputs = model(model_batch["features"], model_batch["valid_mask"]) + selection_scores = torch_module.sigmoid(outputs["selection_logits"])[0].detach().cpu().tolist() + should_inject_score = float(torch_module.sigmoid(outputs["should_inject_logits"])[0].detach().cpu().item()) + mode_probs = torch_module.softmax(outputs["injection_mode_logits"], dim=-1)[0].detach().cpu() + mode_index = int(torch_module.argmax(mode_probs).item()) + temporal_indices = torch_module.argmax(outputs["temporal_logits"], dim=-1)[0].detach().cpu().tolist() + logic_scores = torch_module.sigmoid(outputs["logic_logits"])[0].detach().cpu().tolist() + if bool(self._injection_planner_evidence_role_supported) and "evidence_role_logits" in outputs: + evidence_role_indices = torch_module.argmax(outputs["evidence_role_logits"], dim=-1)[0].detach().cpu().tolist() + else: + evidence_role_indices = [ + injection_planner_runtime.EVIDENCE_ROLES.index("direct_answer") + for _ in candidates + ] + except Exception as exc: + base_metadata.update( + { + "injection_planner_reason": f"inference_failed: {exc}", + "injection_planner_model_path": self._injection_planner_resolved_path, + } + ) + return {"hits": list(hits), "metadata": base_metadata} + thresholds = { + "selection_threshold": float(self._injection_planner_thresholds.get("selection_threshold", 0.5)), + "row_threshold": float(self._injection_planner_thresholds.get("row_threshold", 0.5)), + "logic_threshold": float(self._injection_planner_thresholds.get("logic_threshold", 0.5)), + } + injection_mode = injection_planner_runtime.INJECTION_MODES[mode_index] + row_allows_injection = should_inject_score >= thresholds["row_threshold"] and injection_mode != "none" + predictions: Dict[str, Dict[str, Any]] = {} + for index, candidate in enumerate(candidates): + logic_roles = [ + role + for role, score in zip(injection_planner_runtime.LOGIC_ROLES, logic_scores[index]) + if float(score) >= thresholds["logic_threshold"] + ] + evidence_role = injection_planner_runtime.EVIDENCE_ROLES[int(evidence_role_indices[index])] + role_allows_selection = evidence_role not in {"noise", "negative_evidence"} + predictions[candidate["id"]] = { + "selection_score": float(selection_scores[index]), + "selected": bool( + row_allows_injection + and role_allows_selection + and float(selection_scores[index]) >= thresholds["selection_threshold"] + ), + "temporal_state": injection_planner_runtime.TEMPORAL_STATES[int(temporal_indices[index])], + "evidence_role": evidence_role, + "logic_roles": logic_roles or ["evidence"], + "candidate_layer": candidate.get("layer", ""), + "role_allows_selection": bool(role_allows_selection), + } + annotated_hits: List[MemoryHit] = [] + for index, hit in enumerate(hits): + candidate_id = _clean_text(hit.memory_id) or f"hit:{index}" + prediction = predictions.get(candidate_id, {}) + metadata = dict(hit.metadata or {}) + metadata.update( + { + "injection_planner_enabled": True, + "injection_planner_mode": normalized_mode or "observe", + "injection_planner_model_path": self._injection_planner_resolved_path, + "injection_planner_candidate_id": candidate_id, + "injection_planner_score": round(float(prediction.get("selection_score", 0.0)), 6), + "injection_planner_selected": bool(prediction.get("selected", False)), + "injection_planner_temporal_state": _clean_text(prediction.get("temporal_state", "")), + "injection_planner_evidence_role": _clean_text(prediction.get("evidence_role", "")), + "injection_planner_role_allows_selection": bool(prediction.get("role_allows_selection", False)), + "injection_planner_logic_roles": list(prediction.get("logic_roles", []) or []), + "injection_planner_candidate_layer": _clean_text(prediction.get("candidate_layer", "")), + "injection_planner_injection_mode": injection_mode, + "injection_planner_should_inject_score": round(float(should_inject_score), 6), + "injection_planner_thresholds": dict(thresholds), + "injection_planner_evidence_role_supported": bool(self._injection_planner_evidence_role_supported), + } + ) + planner_score = float(prediction.get("selection_score", 0.0)) + next_score = max(float(hit.score), planner_score) if bool(prediction.get("selected", False)) else float(hit.score) + annotated_hits.append( + MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=round(next_score, 6), + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=metadata, + ) + ) + selected_count = sum(1 for item in annotated_hits if bool((item.metadata or {}).get("injection_planner_selected"))) + if normalized_mode in _INJECTION_PLANNER_FORCE_MODES and selected_count: + selected_ids = {hit.memory_id for hit in annotated_hits if bool((hit.metadata or {}).get("injection_planner_selected"))} + planned_hits = [hit for hit in annotated_hits if hit.memory_id in selected_ids] + elif normalized_mode in _INJECTION_PLANNER_GUIDED_MODES and selected_count: + planned_hits = sorted( + annotated_hits, + key=lambda hit: ( + not bool((hit.metadata or {}).get("injection_planner_selected")), + -float((hit.metadata or {}).get("injection_planner_score", 0.0) or 0.0), + -float(hit.score or 0.0), + ), + ) + else: + planned_hits = annotated_hits + base_metadata.update( + { + "injection_planner_enabled": True, + "injection_planner_reason": "ok", + "injection_planner_model_path": self._injection_planner_resolved_path, + "injection_planner_candidate_count": len(candidates), + "injection_planner_selected_count": int(selected_count), + "injection_planner_should_inject_score": round(float(should_inject_score), 6), + "injection_planner_injection_mode": injection_mode, + "injection_planner_thresholds": dict(thresholds), + "injection_planner_evidence_role_supported": bool(self._injection_planner_evidence_role_supported), + "injection_planner_guided": bool(normalized_mode in _INJECTION_PLANNER_GUIDED_MODES | _INJECTION_PLANNER_FORCE_MODES), + "injection_planner_prediction_ids": [ + candidate_id + for candidate_id, prediction in predictions.items() + if bool(prediction.get("selected", False)) + ], + } + ) + return {"hits": planned_hits[: max(len(planned_hits), top_k)], "metadata": base_metadata} + + def _node_scorer(self) -> LoadedNodeMemoryScorer | None: + if self.retrieval_mode != "hybrid_node_scored": + return None + if self._loaded_node_scorer is not None: + return self._loaded_node_scorer + if not self.node_model_path: + self._node_scorer_error = "node_model_path_missing" + return None + try: + self._loaded_node_scorer = LoadedNodeMemoryScorer( + node_model_path=Path(self.node_model_path), + path_model_path=Path(self.path_model_path) if self.path_model_path else None, + device=self.node_model_device or None, + ) + except Exception as exc: + self._node_scorer_error = str(exc) + self._loaded_node_scorer = None + return self._loaded_node_scorer + + def _apply_writeback_promotions(self, stored_ids: Sequence[str]) -> List[Dict[str, Any]]: + promotion_events: List[Dict[str, Any]] = [] + for memory_id in stored_ids: + record = self.graph.records_by_id.get(memory_id) + if record is None or not isinstance(record.metadata, dict): + continue + if _normalize(record.metadata.get("memory_role", "")) != "assistant" or _normalize(record.metadata.get("authority", "")) != "derived": + continue + canonical_slot_key = _clean_text(record.metadata.get("canonical_slot_key", "")) + writeback_class = _clean_text(record.metadata.get("writeback_class", "")) + if not canonical_slot_key or not writeback_class: + continue + support_refs = self._support_ref_union(record.metadata) + confidence = float(record.confidence or 0.0) + same_records = [ + item + for item in self.graph.records_by_id.values() + if isinstance(item.metadata, dict) + and _normalize(item.metadata.get("memory_role", "")) == "assistant" + and _normalize(item.metadata.get("canonical_slot_key", "")) == _normalize(canonical_slot_key) + and _normalize(item.metadata.get("writeback_class", "")) == _normalize(writeback_class) + and _normalize(item.value) == _normalize(record.value) + ] + qualifying = [item for item in same_records if float(item.confidence or 0.0) >= 0.9] + distinct_answers = { + _clean_text(answer_id) + for item in qualifying + for answer_id in [*list(item.metadata.get("origin_answer_ids", []) or []), _clean_text(item.metadata.get("origin_answer_id", ""))] + if _clean_text(answer_id) + } + aggregated_support = set() + for item in same_records: + aggregated_support.update(self._support_ref_union(item.metadata)) + fast_promotion = writeback_class in {"fact", "state_change"} and confidence >= 0.97 and len(support_refs) >= 3 + standard_promotion = len(distinct_answers) >= 2 and len(aggregated_support) >= 2 + if not (fast_promotion or standard_promotion): + record.metadata["promotion_state"] = "candidate" + continue + source_head = self._source_head_for_canonical(canonical_slot_key) + blocked_conflict = source_head is not None and _normalize(source_head.value) != _normalize(record.value) + promoted_slot = f"promoted.{canonical_slot_key}" + promoted_metadata = { + **dict(record.metadata or {}), + "memory_role": "assistant", + "authority": "promoted", + "canonical_slot_key": canonical_slot_key, + "writeback_class": writeback_class, + "promotion_state": "blocked_conflict" if blocked_conflict else "promoted", + "support_memory_ids": sorted({*list(record.metadata.get("support_memory_ids", []) or []), *[ref for ref in aggregated_support if ref.startswith("fact:") is False and ref.startswith("path:") is False]}), + "support_fact_refs": sorted({*list(record.metadata.get("support_fact_refs", []) or []), *[ref for ref in aggregated_support if ref.startswith("fact:")]}), + "support_path_refs": sorted({*list(record.metadata.get("support_path_refs", []) or []), *[ref for ref in aggregated_support if ref.startswith("path:")]}), + } + promoted_record = SessionMemoryRecordV2( + memory_id=f"{promoted_slot}:{record.turn_index}:promoted", + category=record.category, + slot_key=promoted_slot, + value=record.value, + relation=record.relation, + anchor_concepts=list(record.anchor_concepts), + evidence_anchors=list(record.evidence_anchors), + salience=max(float(record.salience), 0.78), + confidence=max(float(record.confidence), 0.9), + source_kind=f"promoted_{record.source_kind}", + turn_index=int(record.turn_index), + state="active", + metadata=promoted_metadata, + ) + promoted_ids = self.graph.add_records([promoted_record]) + record.metadata["promotion_state"] = "blocked_conflict" if blocked_conflict else "promoted" + promotion_events.append( + { + "source_memory_id": record.memory_id, + "promoted_record_ids": list(promoted_ids), + "canonical_slot_key": canonical_slot_key, + "writeback_class": writeback_class, + "promotion_state": record.metadata["promotion_state"], + "blocked_conflict": bool(blocked_conflict), + } + ) + return promotion_events + + def _source_head_for_canonical(self, canonical_slot_key: str) -> SessionMemoryRecordV2 | None: + head_id = self.graph.slot_heads.get(canonical_slot_key) + if not head_id: + return None + record = self.graph.records_by_id.get(head_id) + if record is None or not isinstance(record.metadata, dict): + return None + if _normalize(record.metadata.get("memory_role", "")) != "user" or _normalize(record.metadata.get("authority", "")) != "source": + return None + return record + + def _support_ref_union(self, metadata: Dict[str, Any]) -> set[str]: + return { + *[_clean_text(item) for item in list(metadata.get("support_memory_ids", []) or []) if _clean_text(item)], + *[_clean_text(item) for item in list(metadata.get("support_fact_refs", []) or []) if _clean_text(item)], + *[_clean_text(item) for item in list(metadata.get("support_path_refs", []) or []) if _clean_text(item)], + } + + def _hybrid_node_scored_hits( + self, + query: str, + hits: Sequence[MemoryHit], + *, + top_k: int, + public_hits: Sequence[MemoryHit] | None = None, + ) -> Dict[str, Any]: + scorer = self._node_scorer() + if scorer is None: + return { + "hits": list(hits), + "metadata": { + "retrieval_mode": "heuristic", + "hybrid_enabled": False, + "hybrid_error": self._node_scorer_error, + }, + } + source_hits = _learnable_graph_hits(self.graph) + if not source_hits: + return { + "hits": list(hits), + "metadata": { + "retrieval_mode": "heuristic", + "hybrid_enabled": False, + "hybrid_error": "no_learnable_hits", + }, + } + has_public_hits = any( + _is_public_dialog_hit(hit) + and not bool((hit.metadata or {}).get("profile_layer")) + for hit in source_hits + ) + has_generic_hits = any(not _is_public_dialog_hit(hit) for hit in source_hits) + hybrid_source = "mixed_full_graph" + if has_public_hits and not has_generic_hits: + hybrid_source = "public_full_graph" + elif has_generic_hits and not has_public_hits: + hybrid_source = "generic_full_graph" + runtime_graph = _build_runtime_graph_from_hits(query, source_hits) + grouped_hits = dict(runtime_graph.get("grouped_hits", {}) or {}) + profile_first_payload = _profile_first_hybrid_rescue( + self.graph, + query, + grouped_hits=grouped_hits, + top_k=top_k, + ) + profile_first_hits = list(profile_first_payload.get("hits", []) or []) + profile_first_event_ids = list(profile_first_payload.get("event_ids", []) or []) + profile_first_memory_ids = list(profile_first_payload.get("memory_ids", []) or []) + candidate_event_ids = sorted(grouped_hits.keys()) + if not candidate_event_ids: + return { + "hits": list(hits), + "metadata": { + "retrieval_mode": "heuristic", + "hybrid_enabled": False, + "hybrid_error": "no_event_candidates", + }, + } + question_analysis = extract_question_features(query) + hybrid_candidate_limit = min( + len(candidate_event_ids), + max( + _HYBRID_SELECTED_EVENT_FLOOR, + int(self.candidate_event_k) * 4, + int(self.support_path_k) * 4, + int(top_k) * 4, + ), + ) + embedder_pre_recall_mode = _normalize(getattr(self, "embedder_pre_recall_mode", "")) + embedder_pre_recall_enabled = embedder_pre_recall_mode not in _EMBEDDER_INDEX_DISABLED_MODES + embedder_pre_recall_index_mode = self.embedder_index_recall_mode + if embedder_pre_recall_mode not in {"1", "true", "yes", "on", "auto", "seed", "candidate", "candidates"}: + embedder_pre_recall_index_mode = embedder_pre_recall_mode + pre_embedder_index_payload: Dict[str, Any] = { + "event_ids": [], + "metadata": { + "embedder_pre_recall_enabled": False, + "embedder_pre_recall_mode": embedder_pre_recall_mode or "off", + "embedder_pre_recall_index_mode": embedder_pre_recall_index_mode or "off", + "embedder_pre_recall_event_ids": [], + }, + } + pre_embedder_event_ids: List[str] = [] + pre_candidate_event_ids: List[str] = [] + if embedder_pre_recall_enabled: + pre_embedder_index_payload = _embedder_index_recall_event_ids( + query, + grouped_hits=grouped_hits, + mode=embedder_pre_recall_index_mode, + limit=self.embedder_pre_recall_k or self.embedder_index_recall_k or hybrid_candidate_limit, + max_terms=self.write_embedder_index_max_terms, + ) + pre_embedder_event_ids = list(pre_embedder_index_payload.get("event_ids", []) or []) + pre_candidate_event_ids = _bounded_event_id_union( + pre_embedder_event_ids, + max_items=hybrid_candidate_limit, + ) + pre_score_kwargs = { + "graph": runtime_graph, + "question": query, + "question_features": question_analysis, + "rerank_top_k": self.candidate_event_k, + "event_rerank_mode": self.event_rerank_mode, + "matrix_event_top_k": self.matrix_event_top_k, + "support_path_k": self.support_path_k, + "top_k": top_k, + } + if pre_candidate_event_ids: + pre_score_kwargs["candidate_event_ids"] = pre_candidate_event_ids + scored = _call_with_supported_kwargs( + scorer.score_runtime, + **pre_score_kwargs, + ) + memory_router_decision = _memory_router_decision( + scored, + mode=self.memory_router_mode, + threshold=self.memory_router_threshold, + margin=self.memory_router_margin, + ) + profile_first_router_suppressed = False + if not _memory_router_allows(memory_router_decision, "profile", "resource"): + profile_first_hits = [] + profile_first_event_ids = [] + profile_first_memory_ids = [] + profile_first_router_suppressed = True + initial_recall_event_scores = dict(scored.get("recall_event_scores", {}) or {}) + model_recall_event_ids = _bounded_event_id_union( + list(scored.get("recall_event_ids", []) or []), + [ + event_id + for event_id, _ in sorted( + initial_recall_event_scores.items(), + key=lambda item: (-float(item[1]), item[0]), + ) + ], + list(scored.get("rerank_candidate_event_ids", []) or []), + max_items=max(1, len(candidate_event_ids)), + ) + learned_recall_event_ids = list(model_recall_event_ids) + symbolic_recall_event_ids = _symbolic_recall_event_ids( + query, + runtime_graph, + grouped_hits=grouped_hits, + limit=hybrid_candidate_limit, + ) + symbolic_recall_event_ids = _bounded_event_id_union( + profile_first_event_ids, + symbolic_recall_event_ids, + max_items=hybrid_candidate_limit, + ) + if pre_embedder_event_ids: + embedder_index_payload = pre_embedder_index_payload + else: + embedder_index_payload = _embedder_index_recall_event_ids( + query, + grouped_hits=grouped_hits, + mode=self.embedder_index_recall_mode, + limit=self.embedder_index_recall_k or hybrid_candidate_limit, + max_terms=self.write_embedder_index_max_terms, + ) + embedder_index_event_ids = list(embedder_index_payload.get("event_ids", []) or []) + embedder_index_metadata = dict(embedder_index_payload.get("metadata", {}) or {}) + embedder_index_metadata.update( + { + "embedder_pre_recall_enabled": bool(pre_candidate_event_ids), + "embedder_pre_recall_mode": embedder_pre_recall_mode or "off", + "embedder_pre_recall_index_mode": embedder_pre_recall_index_mode or "off", + "embedder_pre_recall_event_ids": list(pre_embedder_event_ids), + "embedder_pre_recall_candidate_event_ids": list(pre_candidate_event_ids), + "embedder_pre_recall_candidate_count": int(len(pre_candidate_event_ids)), + } + ) + hybrid_candidate_event_ids = _bounded_event_id_union( + profile_first_event_ids, + embedder_index_event_ids, + learned_recall_event_ids, + symbolic_recall_event_ids, + max_items=hybrid_candidate_limit, + ) + learned_candidate_event_ids = _bounded_event_id_union( + learned_recall_event_ids, + max_items=hybrid_candidate_limit, + ) + learned_recall_event_id_set = set(learned_candidate_event_ids) + hybrid_candidate_union_added_event_ids = [ + event_id + for event_id in hybrid_candidate_event_ids + if event_id not in learned_recall_event_id_set + ] + hybrid_candidate_union_priority_changed = list(hybrid_candidate_event_ids) != list(learned_candidate_event_ids) + hybrid_candidate_union_rescored = False + if hybrid_candidate_union_added_event_ids or (embedder_index_event_ids and hybrid_candidate_union_priority_changed): + scored = _call_with_supported_kwargs( + scorer.score_runtime, + graph=runtime_graph, + question=query, + question_features=question_analysis, + candidate_event_ids=hybrid_candidate_event_ids, + rerank_top_k=self.candidate_event_k, + event_rerank_mode=self.event_rerank_mode, + matrix_event_top_k=self.matrix_event_top_k, + support_path_k=self.support_path_k, + top_k=top_k, + ) + hybrid_candidate_union_rescored = True + memory_router_decision = _memory_router_decision( + scored, + mode=self.memory_router_mode, + threshold=self.memory_router_threshold, + margin=self.memory_router_margin, + ) + recall_event_scores = dict(scored.get("recall_event_scores", {}) or {}) + rerank_candidate_event_ids = list(scored.get("rerank_candidate_event_ids", []) or []) + base_event_scores = dict(scored.get("base_event_scores", {}) or {}) + rerank_event_scores = dict(scored.get("rerank_event_scores", {}) or {}) + calibrated_event_scores = dict(scored.get("calibrated_event_scores", {}) or {}) + matrix_event_scores = dict(scored.get("matrix_event_scores", {}) or {}) + event_fusion_delta_scores = dict(scored.get("event_fusion_delta_scores", {}) or {}) + event_tunnel_support_scores = dict(scored.get("event_tunnel_support_scores", {}) or {}) + event_tunnel_delta_scores = dict(scored.get("event_tunnel_delta_scores", {}) or {}) + tri_maze_event_reverse_scores = dict(scored.get("tri_maze_event_reverse_scores", {}) or {}) + tri_maze_event_boundary_scores = dict(scored.get("tri_maze_event_boundary_scores", {}) or {}) + tri_maze_event_reverse_relations = dict(scored.get("tri_maze_event_reverse_relations", {}) or {}) + matrix_rerank_event_ids = list(scored.get("matrix_rerank_event_ids", []) or []) + matrix_enabled = bool(scored.get("matrix_enabled", False)) + rerank_path_scores = dict(scored.get("rerank_path_scores", {}) or {}) + matrix_path_scores = dict(scored.get("matrix_path_scores", {}) or {}) + tri_maze_path_reverse_scores = dict(scored.get("tri_maze_path_reverse_scores", {}) or {}) + tri_maze_path_boundary_scores = dict(scored.get("tri_maze_path_boundary_scores", {}) or {}) + tri_maze_path_reverse_relations = dict(scored.get("tri_maze_path_reverse_relations", {}) or {}) + matrix_path_rerank_ids = list(scored.get("matrix_path_rerank_ids", []) or []) + matrix_path_enabled = bool(scored.get("matrix_path_enabled", False)) + fusion_enabled = bool(scored.get("fusion_enabled", False)) + event_fusion_enabled = bool(scored.get("event_fusion_enabled", fusion_enabled)) + path_fusion_enabled = bool(scored.get("path_fusion_enabled", fusion_enabled)) + event_calibration_enabled = bool(scored.get("event_calibration_enabled", False)) + path_calibration_enabled = bool(scored.get("path_calibration_enabled", False)) + event_tunnel_enabled = bool(scored.get("event_tunnel_enabled", False)) + path_tunnel_enabled = bool(scored.get("path_tunnel_enabled", False)) + final_event_fusion_enabled = bool(scored.get("final_event_fusion_enabled", False)) + final_path_fusion_enabled = bool(scored.get("final_path_fusion_enabled", False)) + decision_fusion_enabled = bool(scored.get("decision_fusion_enabled", False)) + decision_score_source = _clean_text(scored.get("decision_score_source", "")) + event_scores = dict(scored.get("event_scores", {}) or {}) + base_path_scores = dict(scored.get("base_path_scores", {}) or {}) + calibrated_path_scores = dict(scored.get("calibrated_path_scores", {}) or {}) + path_fusion_delta_scores = dict(scored.get("path_fusion_delta_scores", {}) or {}) + path_tunnel_support_scores = dict(scored.get("path_tunnel_support_scores", {}) or {}) + path_tunnel_delta_scores = dict(scored.get("path_tunnel_delta_scores", {}) or {}) + path_model_scores = dict(scored.get("path_model_scores", {}) or {}) + path_chain_extension_delta_scores = dict(scored.get("path_chain_extension_delta_scores", {}) or {}) + path_chain_extended_scores = dict(scored.get("path_chain_extended_scores", {}) or {}) + path_chain_extension_enabled = bool(scored.get("path_chain_extension_enabled", False)) + answer_type_scores = dict(scored.get("answer_type_scores", {}) or {}) + selected_event_ids_from_model = list(scored.get("selected_event_ids", []) or []) + selected_path_ids_from_model = list(scored.get("selected_path_ids", []) or []) + focused_answer_type_from_model = _clean_text(scored.get("focused_answer_type", "")) + path_scores = dict(scored.get("path_scores", {}) or {}) + temporal_scores = dict(scored.get("temporal_scores", {}) or {}) + runtime_paths = {_clean_text(path.get("id", "")): dict(path) for path in list(runtime_graph.get("paths", []) or [])} + answer_plan_scores_raw = dict(scored.get("answer_plan_scores", {}) or {}) + + def _answer_plan_score_map(role: str) -> Dict[str, float]: + raw_scores = answer_plan_scores_raw.get(role, {}) + if not isinstance(raw_scores, Mapping): + return {} + normalized: Dict[str, float] = {} + for raw_event_id, raw_score in raw_scores.items(): + event_id = _clean_text(raw_event_id) + if not event_id: + continue + try: + normalized[event_id] = float(raw_score or 0.0) + except (TypeError, ValueError): + normalized[event_id] = 0.0 + return normalized + + answer_plan_selected_scores = _answer_plan_score_map("selected") + answer_plan_current_scores = _answer_plan_score_map("current") + answer_plan_historical_scores = _answer_plan_score_map("historical") + answer_plan_suppressed_scores = _answer_plan_score_map("suppressed") + answer_plan_scores = { + "selected": dict(answer_plan_selected_scores), + "current": dict(answer_plan_current_scores), + "historical": dict(answer_plan_historical_scores), + "suppressed": dict(answer_plan_suppressed_scores), + } + try: + answer_plan_event_selection_threshold = float( + os.getenv("TMCRA_ANSWER_PLAN_EVENT_SELECTION_THRESHOLD", "0.50") or 0.50 + ) + except (TypeError, ValueError): + answer_plan_event_selection_threshold = 0.50 + try: + answer_plan_event_selection_top_k = int(os.getenv("TMCRA_ANSWER_PLAN_EVENT_SELECTION_TOP_K", "0") or 0) + except (TypeError, ValueError): + answer_plan_event_selection_top_k = 0 + if answer_plan_event_selection_top_k <= 0: + answer_plan_event_selection_top_k = max(top_k, self.support_path_k * 2, _HYBRID_SELECTED_EVENT_FLOOR) + answer_plan_available_event_ids = { + _clean_text(event_id) + for event_id in [ + *list(grouped_hits.keys()), + *[_clean_text(path.get("event_id", "")) for path in runtime_paths.values()], + ] + if _clean_text(event_id) + } + answer_plan_event_rows: List[tuple[str, float, float, float, float, float]] = [] + answer_plan_event_ids = set(answer_plan_selected_scores) | set(answer_plan_current_scores) | set(answer_plan_historical_scores) + for event_id in answer_plan_event_ids: + if answer_plan_available_event_ids and event_id not in answer_plan_available_event_ids: + continue + selected_score = float(answer_plan_selected_scores.get(event_id, 0.0) or 0.0) + current_score = float(answer_plan_current_scores.get(event_id, 0.0) or 0.0) + historical_score = float(answer_plan_historical_scores.get(event_id, 0.0) or 0.0) + suppressed_score = float(answer_plan_suppressed_scores.get(event_id, 0.0) or 0.0) + support_score = max(selected_score, current_score) + adjusted_score = support_score - max(0.0, suppressed_score - support_score) * 0.5 + if support_score < answer_plan_event_selection_threshold or suppressed_score > support_score: + continue + answer_plan_event_rows.append( + (event_id, adjusted_score, support_score, selected_score, current_score, historical_score) + ) + answer_plan_event_rows.sort( + key=lambda row: (-float(row[1]), -float(row[2]), -float(row[3]), -float(row[4]), row[0]) + ) + answer_plan_raw_ranked_event_ids = [ + event_id for event_id, *_ in answer_plan_event_rows[: max(1, answer_plan_event_selection_top_k)] + ] + try: + answer_plan_promotion_min_margin = float( + os.getenv("TMCRA_ANSWER_PLAN_PROMOTION_MIN_MARGIN", "0.02") or 0.02 + ) + except (TypeError, ValueError): + answer_plan_promotion_min_margin = 0.02 + answer_plan_promotion_score_margin = 0.0 + if len(answer_plan_event_rows) == 1: + answer_plan_promotion_score_margin = float(answer_plan_event_rows[0][1]) + answer_plan_promotion_enabled = True + elif len(answer_plan_event_rows) > 1: + comparison_index = min(len(answer_plan_event_rows) - 1, max(1, min(answer_plan_event_selection_top_k, 5) - 1)) + answer_plan_promotion_score_margin = float(answer_plan_event_rows[0][1]) - float(answer_plan_event_rows[comparison_index][1]) + answer_plan_promotion_enabled = answer_plan_promotion_score_margin >= answer_plan_promotion_min_margin + else: + answer_plan_promotion_enabled = False + answer_plan_ranked_event_ids = list(answer_plan_raw_ranked_event_ids if answer_plan_promotion_enabled else []) + answer_plan_selected_event_ids = [ + event_id + for event_id in answer_plan_ranked_event_ids + if float(answer_plan_selected_scores.get(event_id, 0.0) or 0.0) + >= answer_plan_event_selection_threshold + ] + answer_plan_current_event_ids = [ + event_id + for event_id in answer_plan_ranked_event_ids + if float(answer_plan_current_scores.get(event_id, 0.0) or 0.0) + >= answer_plan_event_selection_threshold + ] + answer_plan_support_scores = { + event_id: round(float(support_score), 6) + for event_id, _, support_score, *_ in answer_plan_event_rows + } + answer_plan_adjusted_scores = { + event_id: round(float(adjusted_score), 6) + for event_id, adjusted_score, *_ in answer_plan_event_rows + } + answer_plan_rank_lookup = { + event_id: rank for rank, event_id in enumerate(answer_plan_ranked_event_ids, start=1) + } + answer_plan_ranked_event_id_set = set(answer_plan_ranked_event_ids) + + def _answer_plan_hit_metadata(event_id: str) -> Dict[str, Any]: + clean_event_id = _clean_text(event_id) + selected_score = float(answer_plan_selected_scores.get(clean_event_id, 0.0) or 0.0) + current_score = float(answer_plan_current_scores.get(clean_event_id, 0.0) or 0.0) + historical_score = float(answer_plan_historical_scores.get(clean_event_id, 0.0) or 0.0) + suppressed_score = float(answer_plan_suppressed_scores.get(clean_event_id, 0.0) or 0.0) + support_score = max(selected_score, current_score) + return { + "answer_plan_score": round(float(support_score), 6), + "answer_plan_selected_score": round(float(selected_score), 6), + "answer_plan_current_score": round(float(current_score), 6), + "answer_plan_historical_score": round(float(historical_score), 6), + "answer_plan_suppressed_score": round(float(suppressed_score), 6), + "answer_plan_adjusted_score": round(float(answer_plan_adjusted_scores.get(clean_event_id, 0.0)), 6), + "answer_plan_selected": bool(clean_event_id in answer_plan_ranked_event_id_set), + "answer_plan_rank": int(answer_plan_rank_lookup.get(clean_event_id, 0) or 0), + } + + embedder_fusion_mode = _normalize(self.embedder_fusion_mode) + embedder_fusion_enabled = ( + embedder_fusion_mode not in _EMBEDDER_INDEX_DISABLED_MODES + and bool(embedder_index_event_ids) + and bool(self.embedder_fusion_top_k > 0) + and bool(self.embedder_fusion_weight > 0.0) + ) + embedder_fusion_applied_event_scores: Dict[str, float] = {} + embedder_fusion_boosts: Dict[str, float] = {} + embedder_event_scores = dict(embedder_index_metadata.get("embedder_index_event_scores", {}) or {}) + if embedder_fusion_enabled and embedder_event_scores: + ranked_embedder_events = [ + (event_id, float(score)) + for event_id, score in sorted( + embedder_event_scores.items(), + key=lambda item: (-float(item[1]), item[0]), + ) + if _clean_text(event_id) + ][: max(1, int(self.embedder_fusion_top_k))] + for rank, (event_id, embedder_score) in enumerate(ranked_embedder_events, start=1): + if embedder_score < float(self.embedder_fusion_score_floor): + continue + rank_bonus = max(0.0, 0.08 - (rank - 1) * 0.006) + boost = min( + float(self.embedder_fusion_max_boost), + (float(self.embedder_fusion_weight) * embedder_score) + rank_bonus, + ) + current_score = max( + float(recall_event_scores.get(event_id, 0.0) or 0.0), + float(event_scores.get(event_id, 0.0) or 0.0), + float(base_event_scores.get(event_id, 0.0) or 0.0), + float(rerank_event_scores.get(event_id, 0.0) or 0.0), + float(calibrated_event_scores.get(event_id, 0.0) or 0.0), + ) + fused_score = current_score + boost + recall_event_scores[event_id] = max(float(recall_event_scores.get(event_id, 0.0) or 0.0), fused_score) + event_scores[event_id] = max(float(event_scores.get(event_id, 0.0) or 0.0), fused_score) + base_event_scores[event_id] = max(float(base_event_scores.get(event_id, 0.0) or 0.0), fused_score) + rerank_event_scores[event_id] = max(float(rerank_event_scores.get(event_id, 0.0) or 0.0), fused_score) + calibrated_event_scores[event_id] = max(float(calibrated_event_scores.get(event_id, 0.0) or 0.0), fused_score) + embedder_fusion_applied_event_scores[event_id] = round(fused_score, 6) + embedder_fusion_boosts[event_id] = round(boost, 6) + if embedder_fusion_applied_event_scores: + decision_score_source = f"{decision_score_source or 'learned_decision_fusion'}+embedder_fusion" + recall_event_ids = [ + event_id + for event_id, _ in sorted(recall_event_scores.items(), key=lambda item: (-float(item[1]), item[0])) + ][: self.candidate_event_k] + if profile_first_event_ids: + profile_first_score_by_event = { + _clean_text((hit.metadata or {}).get("profile_first_hybrid_event_id", "")) or _runtime_event_key(hit): float(hit.score) + for hit in profile_first_hits + } + for rank, event_id in enumerate(profile_first_event_ids, start=1): + if not event_id: + continue + floor = max(0.92, float(profile_first_score_by_event.get(event_id, 0.0))) + max(0.0, 0.04 - (rank * 0.005)) + recall_event_scores[event_id] = max(float(recall_event_scores.get(event_id, 0.0)), floor) + event_scores[event_id] = max(float(event_scores.get(event_id, 0.0)), floor) + base_event_scores[event_id] = max(float(base_event_scores.get(event_id, 0.0)), floor) + rerank_event_scores[event_id] = max(float(rerank_event_scores.get(event_id, 0.0)), floor) + calibrated_event_scores[event_id] = max(float(calibrated_event_scores.get(event_id, 0.0)), floor) + for path_id, path in runtime_paths.items(): + event_id = _clean_text(path.get("event_id", "")) + path_type = _clean_text(path.get("type", "")) + if event_id not in set(profile_first_event_ids): + continue + if path_type not in {"speaker_event_profile", "speaker_event_source_turn", "speaker_event_status"}: + continue + floor = max(0.90, float(event_scores.get(event_id, 0.0))) + path_scores[path_id] = max(float(path_scores.get(path_id, 0.0)), floor) + base_path_scores[path_id] = max(float(base_path_scores.get(path_id, 0.0)), floor) + calibrated_path_scores[path_id] = max(float(calibrated_path_scores.get(path_id, 0.0)), floor) + path_model_scores[path_id] = max(float(path_model_scores.get(path_id, 0.0)), floor) + selection_consistency_repaired = False + selection_consistency_reason = "" + model_focused_answer_type = _clean_text(focused_answer_type_from_model) + embedder_fusion_selected_event_ids: List[str] = [] + embedder_fusion_selected_path_ids: List[str] = [] + if decision_fusion_enabled and (selected_path_ids_from_model or selected_event_ids_from_model): + base_selected_path_limit = max(1, min(max(1, self.support_path_k), max(1, top_k))) + selected_path_ids = [ + path_id + for path_id in selected_path_ids_from_model + if _clean_text(path_id) in runtime_paths + ] + if not selected_path_ids: + selected_path_ids = [ + path_id + for path_id, _ in sorted( + ((path_id, float(score or 0.0)) for path_id, score in path_scores.items()), + key=lambda item: (-float(item[1]), item[0]), + ) + ][:base_selected_path_limit] + focused_answer_type = _reconciled_focused_answer_type(question_analysis, answer_type_scores, focused_answer_type_from_model) + tunnel_rescue_path_ids: List[str] = [] + tunnel_rescue_pre_filter_path_ids: List[str] = [] + path_utility_direct_support_path_ids: List[str] = [] + path_utility_contrast_support_path_ids: List[str] = [] + path_utility_latent_context_path_ids: List[str] = [] + path_utility_drift_noise_path_ids: List[str] = [] + path_utility_roles: Dict[str, str] = {} + path_utility_reasons: Dict[str, str] = {} + path_utility_scores: Dict[str, float] = {} + path_utility_overlap_tokens: Dict[str, List[str]] = {} + path_utility_anchor_event_ids: List[str] = [] + path_utility_anchor_subject_signatures: List[str] = [] + tunnel_rescue_score_threshold = float(self.path_tunnel_rescue_score_floor) + tunnel_rescue_candidate_count = 0 + tunnel_rescue_filtered_count = 0 + if ( + self.path_tunnel_rescue_k > 0 + and path_tunnel_support_scores + and _memory_router_allows(memory_router_decision, "path_tunnel", "topic_tunnel") + ): + selected_path_id_set = set(selected_path_ids) + candidate_scores = [ + float(score or 0.0) + for path_id, score in path_tunnel_support_scores.items() + if _clean_text(path_id) in runtime_paths + ] + tunnel_rescue_candidate_count = len(candidate_scores) + if candidate_scores and self.path_tunnel_rescue_min_score_margin > 0.0: + sorted_scores = sorted(candidate_scores) + median_score = sorted_scores[len(sorted_scores) // 2] + tunnel_rescue_score_threshold = max( + tunnel_rescue_score_threshold, + median_score + float(self.path_tunnel_rescue_min_score_margin), + ) + event_turn_indices = [ + _runtime_event_turn_index_from_id(_clean_text(path.get("event_id", ""))) + for path in runtime_paths.values() + ] + current_turn_index = max( + [int(getattr(self.graph, "turn_index", 0) or 0), *[turn for turn in event_turn_indices if turn > 0]], + default=0, + ) + ranked_tunnel_path_ids = [] + for path_id, score in sorted( + ( + (path_id, float(score or 0.0)) + for path_id, score in path_tunnel_support_scores.items() + if _clean_text(path_id) in runtime_paths + ), + key=lambda item: (-float(item[1]), item[0]), + ): + if path_id in selected_path_id_set or score < tunnel_rescue_score_threshold: + continue + path_event_turn = _runtime_event_turn_index_from_id( + _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + ) + path_age = max(0, current_turn_index - path_event_turn) if current_turn_index and path_event_turn else 0 + if self.path_tunnel_rescue_min_age > 0 and path_age < self.path_tunnel_rescue_min_age: + continue + ranked_tunnel_path_ids.append(path_id) + tunnel_rescue_filtered_count = len(ranked_tunnel_path_ids) + tunnel_rescue_pre_filter_path_ids = ranked_tunnel_path_ids[: max(self.path_tunnel_rescue_k, self.path_tunnel_rescue_k * 4)] + utility_gate = _path_utility_gate( + tunnel_rescue_pre_filter_path_ids, + query=query, + runtime_graph=runtime_graph, + runtime_paths=runtime_paths, + grouped_hits=grouped_hits, + selected_path_ids=selected_path_ids, + selected_event_ids_from_model=selected_event_ids_from_model, + path_scores=path_scores, + path_tunnel_support_scores=path_tunnel_support_scores, + question_analysis=question_analysis, + focused_answer_type=focused_answer_type, + score_threshold=tunnel_rescue_score_threshold, + limit=self.path_tunnel_rescue_k, + ) + tunnel_rescue_path_ids = list(utility_gate.get("injected_path_ids", []) or []) + path_utility_direct_support_path_ids = list(utility_gate.get("direct_support_path_ids", []) or []) + path_utility_contrast_support_path_ids = list(utility_gate.get("contrast_support_path_ids", []) or []) + path_utility_latent_context_path_ids = list(utility_gate.get("latent_context_path_ids", []) or []) + path_utility_drift_noise_path_ids = list(utility_gate.get("drift_noise_path_ids", []) or []) + path_utility_roles = dict(utility_gate.get("roles", {}) or {}) + path_utility_reasons = dict(utility_gate.get("reasons", {}) or {}) + path_utility_scores = dict(utility_gate.get("scores", {}) or {}) + path_utility_overlap_tokens = dict(utility_gate.get("overlap_tokens", {}) or {}) + path_utility_anchor_event_ids = list(utility_gate.get("anchor_event_ids", []) or []) + path_utility_anchor_subject_signatures = list(utility_gate.get("anchor_subject_signatures", []) or []) + if tunnel_rescue_path_ids: + selected_path_ids = _dedupe([*selected_path_ids, *tunnel_rescue_path_ids]) + effective_selected_path_limit = base_selected_path_limit + len(tunnel_rescue_path_ids) + selected_event_ids = _dedupe( + [ + *[ + _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + for path_id in selected_path_ids + if _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + ], + *[ + _clean_text(event_id) + for event_id in selected_event_ids_from_model + if _clean_text(event_id) + ], + ] + ) + if answer_plan_ranked_event_ids: + selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids]) + if not selected_event_ids: + selected_event_ids = [ + event_id + for event_id, _ in sorted( + ((event_id, float(score or 0.0)) for event_id, score in event_scores.items()), + key=lambda item: (-float(item[1]), item[0]), + ) + ][: max(1, min(self.candidate_event_k, max(top_k, _HYBRID_SELECTED_EVENT_FLOOR)))] + repaired_path_ids, selection_consistency_repaired, selection_consistency_reason = _repair_selected_paths_for_focus( + selected_path_ids, + runtime_paths=runtime_paths, + selected_event_ids=selected_event_ids, + path_scores=path_scores, + event_scores=event_scores, + temporal_scores=temporal_scores, + question_analysis=question_analysis, + answer_type_scores=answer_type_scores, + focused_answer_type=focused_answer_type, + limit=max(1, effective_selected_path_limit), + ) + if selection_consistency_repaired: + selected_path_ids = repaired_path_ids + selected_event_ids = _dedupe( + [ + *[ + _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + for path_id in selected_path_ids + if _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + ], + *[ + _clean_text(event_id) + for event_id in selected_event_ids_from_model + if _clean_text(event_id) + ], + ] + ) + if answer_plan_ranked_event_ids: + selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids]) + if profile_first_event_ids: + selected_event_ids = _dedupe([*profile_first_event_ids, *selected_event_ids]) + if embedder_fusion_applied_event_scores and self.embedder_fusion_select_k > 0: + ranked_fusion_event_ids = [ + event_id + for event_id in embedder_index_event_ids + if event_id in embedder_fusion_applied_event_scores + ][: max(1, int(self.embedder_fusion_select_k))] + selected_event_ids = _dedupe([*ranked_fusion_event_ids, *selected_event_ids]) + selected_path_id_set = set(selected_path_ids) + selected_embedder_path_ids: List[str] = [] + for event_id in ranked_fusion_event_ids: + candidate_paths = [ + (path_id, path) + for path_id, path in runtime_paths.items() + if _clean_text(path.get("event_id", "")) == event_id + ] + candidate_paths.sort( + key=lambda item: ( + int(_clean_text(item[1].get("type", "")) in {"speaker_event_source_turn", "speaker_event_profile", "speaker_event_status", "speaker_event_time"}), + float(path_scores.get(item[0], 0.0) or 0.0), + item[0], + ), + reverse=True, + ) + for path_id, _ in candidate_paths: + if path_id in selected_path_id_set: + continue + selected_embedder_path_ids.append(path_id) + selected_path_id_set.add(path_id) + embedder_fusion_selected_path_ids.append(path_id) + break + if selected_embedder_path_ids: + selected_path_ids = _dedupe([*selected_embedder_path_ids, *selected_path_ids]) + embedder_fusion_selected_event_ids = list(ranked_fusion_event_ids) + if answer_plan_ranked_event_ids: + selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids]) + final_hits: List[MemoryHit] = [] + seen_memory_ids = set() + selected_event_id_set = set(selected_event_ids) + for path_id in selected_path_ids: + path = runtime_paths.get(path_id, {}) + event_id = _clean_text(path.get("event_id", "")) + if event_id not in selected_event_id_set: + continue + path_type = _clean_text(path.get("type", "")) + support_node_id = _path_support_node_id(path) + support_hit = _support_hit_for_path(path_type, grouped_hits.get(event_id, [])) + event_hit = _representative_event_hit( + [*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)], + query=query, + ) + decision_score = round(float(path_scores.get(path_id, 0.0)), 6) + candidate_hits: List[tuple[MemoryHit | None, float]] = [(support_hit, decision_score)] + if event_hit is not None and (support_hit is None or event_hit.memory_id != support_hit.memory_id): + candidate_hits.append((event_hit, max(0.0, decision_score - 0.0001))) + for raw_hit, hit_score in candidate_hits: + if raw_hit is None or raw_hit.memory_id in seen_memory_ids: + continue + seen_memory_ids.add(raw_hit.memory_id) + recall_score = float(recall_event_scores.get(event_id, event_scores.get(event_id, 0.0))) + event_score = float(event_scores.get(event_id, raw_hit.score)) + temporal_score = float(temporal_scores.get(support_node_id, 0.0)) + metadata = dict(raw_hit.metadata or {}) + metadata.update( + { + "event_id": event_id, + **_answer_plan_hit_metadata(event_id), + "path_id": path_id, + "base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6), + "rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6), + "calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6), + "matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6), + "event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6), + "event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6), + "event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6), + "tri_maze_event_reverse_score": round(float(tri_maze_event_reverse_scores.get(event_id, 0.0)), 6), + "tri_maze_event_boundary_score": round(float(tri_maze_event_boundary_scores.get(event_id, 0.0)), 6), + "tri_maze_event_reverse_relation": round(float(tri_maze_event_reverse_relations.get(event_id, 0.0)), 6), + "matrix_enabled": matrix_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": True, + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "event_score": round(event_score, 6), + "recall_score": round(recall_score, 6), + "path_score": round(float(path_scores.get(path_id, hit_score)), 6), + "base_path_score": round(float(base_path_scores.get(path_id, path_scores.get(path_id, hit_score))), 6), + "calibrated_path_score": round(float(calibrated_path_scores.get(path_id, path_scores.get(path_id, hit_score))), 6), + "path_fusion_delta_score": round(float(path_fusion_delta_scores.get(path_id, 0.0)), 6), + "path_tunnel_support_score": round(float(path_tunnel_support_scores.get(path_id, 0.0)), 6), + "path_tunnel_delta_score": round(float(path_tunnel_delta_scores.get(path_id, 0.0)), 6), + "path_model_score": round(float(path_model_scores.get(path_id, path_scores.get(path_id, hit_score))), 6), + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_score": round(float(path_chain_extension_delta_scores.get(path_id, 0.0)), 6), + "path_chain_extended_score": round(float(path_chain_extended_scores.get(path_id, path_scores.get(path_id, hit_score))), 6), + "tri_maze_path_reverse_score": round(float(tri_maze_path_reverse_scores.get(path_id, 0.0)), 6), + "tri_maze_path_boundary_score": round(float(tri_maze_path_boundary_scores.get(path_id, 0.0)), 6), + "tri_maze_path_reverse_relation": round(float(tri_maze_path_reverse_relations.get(path_id, 0.0)), 6), + "effective_path_score": round(float(path_scores.get(path_id, hit_score)), 6), + "temporal_score": round(temporal_score, 6), + "raw_public_score": round(float(raw_hit.score), 6), + "hybrid_score": round(hit_score, 6), + "hybrid_score_source": decision_score_source or "learned_decision_fusion", + "evidence_snippet_role": "selected_path_support" if raw_hit is support_hit else "selected_path_event", + "selected_event_ids": list(selected_event_ids), + "selected_path_ids": list(selected_path_ids), + "path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_tunnel_rescue_k": int(self.path_tunnel_rescue_k), + "path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6), + "path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age), + "path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6), + "path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6), + "path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count), + "path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count), + "path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids), + "path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids), + "path_utility_injected_path_ids": list(tunnel_rescue_path_ids), + "path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids), + "path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids), + "path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids), + "path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids), + "path_utility_roles": dict(path_utility_roles), + "path_utility_reasons": dict(path_utility_reasons), + "path_utility_scores": dict(path_utility_scores), + "path_utility_overlap_tokens": dict(path_utility_overlap_tokens), + "path_utility_anchor_event_ids": list(path_utility_anchor_event_ids), + "path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures), + "tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)), + "tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)), + "model_focused_answer_type": model_focused_answer_type, + "selection_consistency_repaired": bool(selection_consistency_repaired), + "selection_consistency_reason": selection_consistency_reason, + } + ) + final_hits.append( + MemoryHit( + memory_id=raw_hit.memory_id, + category=raw_hit.category, + value=raw_hit.value, + relation=raw_hit.relation, + anchors=list(raw_hit.anchors), + score=round(hit_score, 6), + source_kind=raw_hit.source_kind, + slot_key=raw_hit.slot_key, + state=raw_hit.state, + turn_index=int(raw_hit.turn_index), + metadata=metadata, + ) + ) + for event_rank, event_id in enumerate(selected_event_ids, start=1): + event_hit = _representative_event_hit( + [*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)], + query=query, + ) + if event_hit is None or event_hit.memory_id in seen_memory_ids: + continue + seen_memory_ids.add(event_hit.memory_id) + recall_score = float(recall_event_scores.get(event_id, event_hit.score)) + event_score = float(event_scores.get(event_id, event_hit.score)) + metadata = dict(event_hit.metadata or {}) + metadata.update( + { + "event_id": event_id, + **_answer_plan_hit_metadata(event_id), + "path_id": "", + "base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6), + "rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6), + "calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6), + "matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6), + "event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6), + "event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6), + "event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6), + "tri_maze_event_reverse_score": round(float(tri_maze_event_reverse_scores.get(event_id, 0.0)), 6), + "tri_maze_event_boundary_score": round(float(tri_maze_event_boundary_scores.get(event_id, 0.0)), 6), + "tri_maze_event_reverse_relation": round(float(tri_maze_event_reverse_relations.get(event_id, 0.0)), 6), + "matrix_enabled": matrix_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": True, + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "event_score": round(event_score, 6), + "recall_score": round(recall_score, 6), + "path_score": 0.0, + "base_path_score": 0.0, + "calibrated_path_score": 0.0, + "path_fusion_delta_score": 0.0, + "path_tunnel_support_score": 0.0, + "path_tunnel_delta_score": 0.0, + "path_model_score": 0.0, + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_score": 0.0, + "path_chain_extended_score": 0.0, + "effective_path_score": 0.0, + "temporal_score": 0.0, + "raw_public_score": round(float(event_hit.score), 6), + "hybrid_score": round(event_score, 6), + "hybrid_score_source": decision_score_source or "learned_final_event_fusion", + "evidence_snippet_role": "selected_event_representative", + "selected_event_rank": int(event_rank), + "selected_event_ids": list(selected_event_ids), + "selected_path_ids": list(selected_path_ids), + "path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_tunnel_rescue_k": int(self.path_tunnel_rescue_k), + "path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6), + "path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age), + "path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6), + "path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6), + "path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count), + "path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count), + "path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids), + "path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids), + "path_utility_injected_path_ids": list(tunnel_rescue_path_ids), + "path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids), + "path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids), + "path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids), + "path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids), + "path_utility_roles": dict(path_utility_roles), + "path_utility_reasons": dict(path_utility_reasons), + "path_utility_scores": dict(path_utility_scores), + "path_utility_overlap_tokens": dict(path_utility_overlap_tokens), + "path_utility_anchor_event_ids": list(path_utility_anchor_event_ids), + "path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures), + "tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)), + "tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)), + "model_focused_answer_type": model_focused_answer_type, + "selection_consistency_repaired": bool(selection_consistency_repaired), + "selection_consistency_reason": selection_consistency_reason, + } + ) + final_hits.append( + MemoryHit( + memory_id=event_hit.memory_id, + category=event_hit.category, + value=event_hit.value, + relation=event_hit.relation, + anchors=list(event_hit.anchors), + score=round(event_score, 6), + source_kind=event_hit.source_kind, + slot_key=event_hit.slot_key, + state=event_hit.state, + turn_index=int(event_hit.turn_index), + metadata=metadata, + ) + ) + if not final_hits: + final_hits = list(hits) + if profile_first_hits: + final_hits = _inject_profile_first_hits( + final_hits, + profile_first_hits, + selected_event_ids=selected_event_ids, + selected_path_ids=selected_path_ids, + ) + final_hits = _coverage_preserving_final_hits(final_hits, selected_event_ids=selected_event_ids, top_k=top_k) + final_hit_event_ids = _event_ids_from_hits(final_hits) + final_missing_selected_event_ids = [event_id for event_id in selected_event_ids if event_id not in set(final_hit_event_ids)] + return { + "hits": final_hits, + "metadata": { + "retrieval_mode": "hybrid_node_scored", + "hybrid_enabled": True, + "hybrid_source": hybrid_source, + **memory_router_decision, + "profile_first_router_suppressed": bool(profile_first_router_suppressed), + "recall_event_ids": list(recall_event_ids), + "learned_recall_event_ids": list(learned_recall_event_ids), + "model_recall_event_ids": list(model_recall_event_ids), + "symbolic_recall_event_ids": list(symbolic_recall_event_ids), + "embedder_index_recall_event_ids": list(embedder_index_event_ids), + **embedder_index_metadata, + "embedder_fusion_mode": embedder_fusion_mode or "off", + "embedder_fusion_enabled": bool(embedder_fusion_enabled), + "embedder_fusion_weight": round(float(self.embedder_fusion_weight), 6), + "embedder_fusion_score_floor": round(float(self.embedder_fusion_score_floor), 6), + "embedder_fusion_top_k": int(self.embedder_fusion_top_k), + "embedder_fusion_select_k": int(self.embedder_fusion_select_k), + "embedder_fusion_max_boost": round(float(self.embedder_fusion_max_boost), 6), + "embedder_fusion_event_scores": dict(embedder_fusion_applied_event_scores), + "embedder_fusion_boosts": dict(embedder_fusion_boosts), + "embedder_fusion_selected_event_ids": list(embedder_fusion_selected_event_ids), + "embedder_fusion_selected_path_ids": list(embedder_fusion_selected_path_ids), + "profile_first_hybrid_enabled": bool(profile_first_event_ids), + "profile_first_event_ids": list(profile_first_event_ids), + "profile_first_memory_ids": list(profile_first_memory_ids), + "hybrid_candidate_event_ids": list(hybrid_candidate_event_ids), + "hybrid_candidate_union_enabled": True, + "hybrid_candidate_union_rescored": bool(hybrid_candidate_union_rescored), + "hybrid_candidate_union_added_event_ids": list(hybrid_candidate_union_added_event_ids), + "hybrid_candidate_union_priority_changed": bool(hybrid_candidate_union_priority_changed), + "rerank_candidate_event_ids": list(rerank_candidate_event_ids), + "base_event_scores": dict(base_event_scores), + "rerank_event_scores": dict(rerank_event_scores), + "calibrated_event_scores": dict(calibrated_event_scores), + "matrix_event_scores": dict(matrix_event_scores), + "event_fusion_delta_scores": dict(event_fusion_delta_scores), + "event_tunnel_support_scores": dict(event_tunnel_support_scores), + "event_tunnel_delta_scores": dict(event_tunnel_delta_scores), + "tri_maze_event_reverse_scores": dict(tri_maze_event_reverse_scores), + "tri_maze_event_boundary_scores": dict(tri_maze_event_boundary_scores), + "tri_maze_event_reverse_relations": dict(tri_maze_event_reverse_relations), + "matrix_rerank_event_ids": list(matrix_rerank_event_ids), + "matrix_enabled": matrix_enabled, + "rerank_path_scores": dict(rerank_path_scores), + "matrix_path_scores": dict(matrix_path_scores), + "tri_maze_path_reverse_scores": dict(tri_maze_path_reverse_scores), + "tri_maze_path_boundary_scores": dict(tri_maze_path_boundary_scores), + "tri_maze_path_reverse_relations": dict(tri_maze_path_reverse_relations), + "matrix_path_rerank_ids": list(matrix_path_rerank_ids), + "matrix_path_enabled": matrix_path_enabled, + "fusion_enabled": fusion_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": True, + "decision_score_source": decision_score_source or "learned_decision_fusion", + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "selected_event_ids": list(selected_event_ids), + "path_rescue_event_ids": [], + "selected_path_ids": list(selected_path_ids), + "path_tunnel_rescue_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_tunnel_rescue_k": int(self.path_tunnel_rescue_k), + "path_tunnel_rescue_score_floor": round(float(self.path_tunnel_rescue_score_floor), 6), + "path_tunnel_rescue_min_age": int(self.path_tunnel_rescue_min_age), + "path_tunnel_rescue_min_score_margin": round(float(self.path_tunnel_rescue_min_score_margin), 6), + "path_tunnel_rescue_score_threshold": round(float(tunnel_rescue_score_threshold), 6), + "path_tunnel_rescue_candidate_count": int(tunnel_rescue_candidate_count), + "path_tunnel_rescue_filtered_count": int(tunnel_rescue_filtered_count), + "path_tunnel_rescue_path_ids": list(tunnel_rescue_path_ids), + "path_utility_gate_enabled": bool(self.path_tunnel_rescue_k > 0), + "path_utility_pre_filter_path_ids": list(tunnel_rescue_pre_filter_path_ids), + "path_utility_injected_path_ids": list(tunnel_rescue_path_ids), + "path_utility_direct_support_path_ids": list(path_utility_direct_support_path_ids), + "path_utility_contrast_support_path_ids": list(path_utility_contrast_support_path_ids), + "path_utility_latent_context_path_ids": list(path_utility_latent_context_path_ids), + "path_utility_drift_noise_path_ids": list(path_utility_drift_noise_path_ids), + "path_utility_roles": dict(path_utility_roles), + "path_utility_reasons": dict(path_utility_reasons), + "path_utility_scores": dict(path_utility_scores), + "path_utility_overlap_tokens": dict(path_utility_overlap_tokens), + "path_utility_anchor_event_ids": list(path_utility_anchor_event_ids), + "path_utility_anchor_subject_signatures": list(path_utility_anchor_subject_signatures), + "tunnel_recall_pre_filter_count": int(len(tunnel_rescue_pre_filter_path_ids)), + "tunnel_usable_post_filter_count": int(len(tunnel_rescue_path_ids)), + "final_hit_event_ids": list(final_hit_event_ids), + "final_hit_dia_ids": _dia_ids_from_hits(final_hits), + "final_missing_selected_event_ids": list(final_missing_selected_event_ids), + "selected_event_count": int(len(selected_event_ids)), + "selected_path_count": int(len(selected_path_ids)), + "temporal_scores": dict(temporal_scores), + "recall_event_scores": dict(recall_event_scores), + "event_scores": dict(event_scores), + "base_path_scores": dict(base_path_scores), + "calibrated_path_scores": dict(calibrated_path_scores), + "path_fusion_delta_scores": dict(path_fusion_delta_scores), + "path_tunnel_support_scores": dict(path_tunnel_support_scores), + "path_tunnel_delta_scores": dict(path_tunnel_delta_scores), + "path_model_scores": dict(path_model_scores), + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_scores": dict(path_chain_extension_delta_scores), + "path_chain_extended_scores": dict(path_chain_extended_scores), + "effective_path_scores": dict(path_scores), + "path_scores": dict(path_scores), + "answer_type_scores": dict(answer_type_scores), + "answer_plan_scores": dict(answer_plan_scores), + "answer_plan_support_scores": dict(answer_plan_support_scores), + "answer_plan_adjusted_scores": dict(answer_plan_adjusted_scores), + "answer_plan_raw_ranked_event_ids": list(answer_plan_raw_ranked_event_ids), + "answer_plan_ranked_event_ids": list(answer_plan_ranked_event_ids), + "answer_plan_selected_event_ids": list(answer_plan_selected_event_ids), + "answer_plan_current_event_ids": list(answer_plan_current_event_ids), + "answer_plan_promotion_enabled": bool(answer_plan_promotion_enabled), + "answer_plan_promotion_score_margin": round(float(answer_plan_promotion_score_margin), 6), + "answer_plan_promotion_min_margin": round(float(answer_plan_promotion_min_margin), 6), + "answer_plan_event_selection_threshold": round(float(answer_plan_event_selection_threshold), 6), + "answer_plan_event_selection_top_k": int(answer_plan_event_selection_top_k), + "focused_answer_type": focused_answer_type, + "model_focused_answer_type": model_focused_answer_type, + "selection_consistency_repaired": bool(selection_consistency_repaired), + "selection_consistency_reason": selection_consistency_reason, + "preferred_path_types": [], + }, + } + dominant_answer_type = _dominant_answer_type(question_analysis, answer_type_scores) + preferred_path_types = _answer_type_preferred_path_types(question_analysis, answer_type_scores) + if bool(memory_router_decision.get("memory_router_guided")): + temporal_focus = _memory_router_allows(memory_router_decision, "temporal") + else: + temporal_focus = dominant_answer_type == "time" or bool(question_analysis.get("is_temporal", False)) + router_profile_focus = bool(memory_router_decision.get("memory_router_guided")) and _memory_router_allows( + memory_router_decision, + "profile", + "resource", + ) + learned_event_available = bool(event_scores) + learned_path_available = bool(path_scores) + ranked_events = [ + event_id + for event_id, _ in sorted( + ( + (event_id, event_scores.get(event_id, recall_event_scores.get(event_id, 0.0))) + for event_id in recall_event_ids + ), + key=lambda item: (-float(item[1]), item[0]), + ) + ] + base_selected_event_count = min( + len(ranked_events), + max( + self.support_path_k * 2, + min(self.candidate_event_k, max(top_k, _HYBRID_SELECTED_EVENT_FLOOR)), + ), + ) + effective_path_scores = dict(path_scores) + if not learned_path_available: + effective_path_scores = { + path_id: _calibrated_path_score( + path=runtime_paths.get(path_id, {}), + base_score=float(score), + temporal_scores=temporal_scores, + question_analysis=question_analysis, + answer_type_scores=answer_type_scores, + ) + for path_id, score in path_scores.items() + } + ranked_path_ids_all = [ + path_id + for path_id, _ in sorted( + ((path_id, float(score)) for path_id, score in effective_path_scores.items()), + key=lambda item: (-float(item[1]), item[0]), + ) + ] + path_rescue_count = min( + len(ranked_path_ids_all), + max(self.support_path_k * 2, min(self.candidate_event_k, max(1, top_k))), + ) + path_rescue_event_ids = _dedupe( + _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) + for path_id in ranked_path_ids_all[: max(1, path_rescue_count)] + ) + selected_event_ids = _dedupe( + [ + *path_rescue_event_ids, + *ranked_events[: max(1, base_selected_event_count)], + ] + ) + if profile_first_event_ids: + selected_event_ids = _dedupe([*profile_first_event_ids, *selected_event_ids]) + if answer_plan_ranked_event_ids: + selected_event_ids = _dedupe([*selected_event_ids, *answer_plan_ranked_event_ids]) + selected_event_id_set = set(selected_event_ids) + ranked_path_ids = [ + path_id + for path_id in ranked_path_ids_all + if _clean_text(runtime_paths.get(path_id, {}).get("event_id", "")) in selected_event_id_set + ] + if temporal_focus: + focused_time_path_ids = [ + path_id + for path_id in ranked_path_ids + if _clean_text(runtime_paths.get(path_id, {}).get("type", "")) == "speaker_event_time" + ] + if focused_time_path_ids: + ranked_path_ids = focused_time_path_ids + path_limit_cap = _HYBRID_SELECTED_PATH_CAP + if temporal_focus: + path_limit_cap = _HYBRID_TEMPORAL_PATH_CAP + elif dominant_answer_type == "profile" or router_profile_focus: + path_limit_cap = _HYBRID_PROFILE_PATH_CAP + selected_path_count = min( + len(ranked_path_ids), + max(1, min(max(1, self.support_path_k), min(max(1, top_k), path_limit_cap))), + ) + selected_path_ids = ranked_path_ids[: max(1, selected_path_count)] + final_hits: List[MemoryHit] = [] + seen_memory_ids = set() + temporal_event_hits_added = 0 + for path_id in selected_path_ids: + path = runtime_paths.get(path_id, {}) + event_id = _clean_text(path.get("event_id", "")) + support_node_id = _path_support_node_id(path) + path_type = _clean_text(path.get("type", "")) + support_hit = _support_hit_for_path(path_type, grouped_hits.get(event_id, [])) + event_hit = _representative_event_hit( + [*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)], + query=query, + ) + hit_pairs: List[tuple[MemoryHit | None, float]] = [(support_hit, path_scores.get(path_id, 0.0))] + if not temporal_focus: + hit_pairs.append((event_hit, path_scores.get(path_id, 0.0))) + elif support_hit is None and event_hit is not None: + hit_pairs.append((event_hit, path_scores.get(path_id, 0.0))) + elif path_type == "speaker_event_time" and event_hit is not None and temporal_event_hits_added < 1: + hit_pairs.append((event_hit, path_scores.get(path_id, 0.0))) + temporal_event_hits_added += 1 + for raw_hit, path_score in hit_pairs: + if raw_hit is None or raw_hit.memory_id in seen_memory_ids: + continue + seen_memory_ids.add(raw_hit.memory_id) + recall_score = float(recall_event_scores.get(event_id, event_scores.get(event_id, 0.0))) + event_score = float(event_scores.get(event_id, raw_hit.score)) + temporal_score = float(temporal_scores.get(support_node_id, 0.0)) + effective_path_score = float(effective_path_scores.get(path_id, path_score)) + hybrid_score = round(effective_path_score, 6) if learned_path_available else round( + (0.55 * event_score) + + (0.20 * float(path_score)) + + (0.15 * recall_score) + + (0.10 * temporal_score), + 6, + ) + metadata = dict(raw_hit.metadata or {}) + metadata.update( + { + "event_id": event_id, + **_answer_plan_hit_metadata(event_id), + "path_id": path_id, + "base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6), + "rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6), + "calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6), + "matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6), + "event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6), + "event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6), + "event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6), + "matrix_enabled": matrix_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": False, + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "event_score": round(event_score, 6), + "recall_score": round(recall_score, 6), + "path_score": round(float(path_score), 6), + "base_path_score": round(float(base_path_scores.get(path_id, path_score)), 6), + "calibrated_path_score": round(float(calibrated_path_scores.get(path_id, effective_path_score)), 6), + "path_fusion_delta_score": round(float(path_fusion_delta_scores.get(path_id, 0.0)), 6), + "path_tunnel_support_score": round(float(path_tunnel_support_scores.get(path_id, 0.0)), 6), + "path_tunnel_delta_score": round(float(path_tunnel_delta_scores.get(path_id, 0.0)), 6), + "path_model_score": round(float(path_model_scores.get(path_id, path_score)), 6), + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_score": round(float(path_chain_extension_delta_scores.get(path_id, 0.0)), 6), + "path_chain_extended_score": round(float(path_chain_extended_scores.get(path_id, effective_path_score)), 6), + "effective_path_score": round(effective_path_score, 6), + "temporal_score": round(temporal_score, 6), + "raw_public_score": round(float(raw_hit.score), 6), + "hybrid_score": hybrid_score, + "hybrid_score_source": ( + "learned_path_fusion" + if path_fusion_enabled + else "learned_path_score" + if learned_path_available + else "heuristic_mix" + ), + "evidence_snippet_role": "selected_path_support" if raw_hit is support_hit else "selected_path_event", + "selected_event_ids": list(selected_event_ids), + "selected_path_ids": list(selected_path_ids), + } + ) + final_hits.append( + MemoryHit( + memory_id=raw_hit.memory_id, + category=raw_hit.category, + value=raw_hit.value, + relation=raw_hit.relation, + anchors=list(raw_hit.anchors), + score=hybrid_score, + source_kind=raw_hit.source_kind, + slot_key=raw_hit.slot_key, + state=raw_hit.state, + turn_index=int(raw_hit.turn_index), + metadata=metadata, + ) + ) + for event_rank, event_id in enumerate(selected_event_ids, start=1): + event_hit = _representative_event_hit( + [*grouped_hits.get(event_id, []), *_event_record_hits_from_graph(self.graph, event_id)], + query=query, + ) + if event_hit is None or event_hit.memory_id in seen_memory_ids: + continue + seen_memory_ids.add(event_hit.memory_id) + recall_score = float(recall_event_scores.get(event_id, event_hit.score)) + event_score = float(event_scores.get(event_id, event_hit.score)) + hybrid_score = round(event_score, 6) if learned_event_available else round( + (0.7 * event_score) + (0.3 * recall_score), + 6, + ) + metadata = dict(event_hit.metadata or {}) + metadata.update( + { + "event_id": event_id, + **_answer_plan_hit_metadata(event_id), + "path_id": "", + "base_event_score": round(float(base_event_scores.get(event_id, event_score)), 6), + "rerank_event_score": round(float(rerank_event_scores.get(event_id, event_score)), 6), + "calibrated_event_score": round(float(calibrated_event_scores.get(event_id, event_score)), 6), + "matrix_event_score": round(float(matrix_event_scores.get(event_id, 0.0)), 6), + "event_fusion_delta_score": round(float(event_fusion_delta_scores.get(event_id, 0.0)), 6), + "event_tunnel_support_score": round(float(event_tunnel_support_scores.get(event_id, 0.0)), 6), + "event_tunnel_delta_score": round(float(event_tunnel_delta_scores.get(event_id, 0.0)), 6), + "matrix_enabled": matrix_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": False, + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "event_score": round(event_score, 6), + "recall_score": round(recall_score, 6), + "path_score": 0.0, + "base_path_score": 0.0, + "calibrated_path_score": 0.0, + "path_fusion_delta_score": 0.0, + "path_tunnel_support_score": 0.0, + "path_tunnel_delta_score": 0.0, + "path_model_score": 0.0, + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_score": 0.0, + "path_chain_extended_score": 0.0, + "effective_path_score": 0.0, + "temporal_score": 0.0, + "raw_public_score": round(float(event_hit.score), 6), + "hybrid_score": hybrid_score, + "hybrid_score_source": ( + "learned_event_fusion" + if event_fusion_enabled + else "learned_event_score" + if learned_event_available + else "heuristic_event_mix" + ), + "evidence_snippet_role": "selected_event_representative", + "selected_event_rank": int(event_rank), + "selected_event_ids": list(selected_event_ids), + "selected_path_ids": list(selected_path_ids), + } + ) + final_hits.append( + MemoryHit( + memory_id=event_hit.memory_id, + category=event_hit.category, + value=event_hit.value, + relation=event_hit.relation, + anchors=list(event_hit.anchors), + score=hybrid_score, + source_kind=event_hit.source_kind, + slot_key=event_hit.slot_key, + state=event_hit.state, + turn_index=int(event_hit.turn_index), + metadata=metadata, + ) + ) + if not final_hits: + final_hits = list(hits) + if profile_first_hits: + final_hits = _inject_profile_first_hits( + final_hits, + profile_first_hits, + selected_event_ids=selected_event_ids, + selected_path_ids=selected_path_ids, + ) + final_hits = _coverage_preserving_final_hits(final_hits, selected_event_ids=selected_event_ids, top_k=top_k) + final_hit_event_ids = _event_ids_from_hits(final_hits) + final_missing_selected_event_ids = [event_id for event_id in selected_event_ids if event_id not in set(final_hit_event_ids)] + return { + "hits": final_hits, + "metadata": { + "retrieval_mode": "hybrid_node_scored", + "hybrid_enabled": True, + "hybrid_source": hybrid_source, + **memory_router_decision, + "profile_first_router_suppressed": bool(profile_first_router_suppressed), + "recall_event_ids": list(recall_event_ids), + "learned_recall_event_ids": list(learned_recall_event_ids), + "model_recall_event_ids": list(model_recall_event_ids), + "symbolic_recall_event_ids": list(symbolic_recall_event_ids), + "embedder_index_recall_event_ids": list(embedder_index_event_ids), + **embedder_index_metadata, + "embedder_fusion_mode": embedder_fusion_mode or "off", + "embedder_fusion_enabled": bool(embedder_fusion_enabled), + "embedder_fusion_weight": round(float(self.embedder_fusion_weight), 6), + "embedder_fusion_score_floor": round(float(self.embedder_fusion_score_floor), 6), + "embedder_fusion_top_k": int(self.embedder_fusion_top_k), + "embedder_fusion_select_k": int(self.embedder_fusion_select_k), + "embedder_fusion_max_boost": round(float(self.embedder_fusion_max_boost), 6), + "embedder_fusion_event_scores": dict(embedder_fusion_applied_event_scores), + "embedder_fusion_boosts": dict(embedder_fusion_boosts), + "embedder_fusion_selected_event_ids": list(embedder_fusion_selected_event_ids), + "embedder_fusion_selected_path_ids": list(embedder_fusion_selected_path_ids), + "profile_first_hybrid_enabled": bool(profile_first_event_ids), + "profile_first_event_ids": list(profile_first_event_ids), + "profile_first_memory_ids": list(profile_first_memory_ids), + "hybrid_candidate_event_ids": list(hybrid_candidate_event_ids), + "hybrid_candidate_union_enabled": True, + "hybrid_candidate_union_rescored": bool(hybrid_candidate_union_rescored), + "hybrid_candidate_union_added_event_ids": list(hybrid_candidate_union_added_event_ids), + "hybrid_candidate_union_priority_changed": bool(hybrid_candidate_union_priority_changed), + "rerank_candidate_event_ids": list(rerank_candidate_event_ids), + "base_event_scores": dict(base_event_scores), + "rerank_event_scores": dict(rerank_event_scores), + "calibrated_event_scores": dict(calibrated_event_scores), + "matrix_event_scores": dict(matrix_event_scores), + "event_fusion_delta_scores": dict(event_fusion_delta_scores), + "event_tunnel_support_scores": dict(event_tunnel_support_scores), + "event_tunnel_delta_scores": dict(event_tunnel_delta_scores), + "tri_maze_event_reverse_scores": dict(tri_maze_event_reverse_scores), + "tri_maze_event_boundary_scores": dict(tri_maze_event_boundary_scores), + "tri_maze_event_reverse_relations": dict(tri_maze_event_reverse_relations), + "matrix_rerank_event_ids": list(matrix_rerank_event_ids), + "matrix_enabled": matrix_enabled, + "rerank_path_scores": dict(rerank_path_scores), + "matrix_path_scores": dict(matrix_path_scores), + "tri_maze_path_reverse_scores": dict(tri_maze_path_reverse_scores), + "tri_maze_path_boundary_scores": dict(tri_maze_path_boundary_scores), + "tri_maze_path_reverse_relations": dict(tri_maze_path_reverse_relations), + "matrix_path_rerank_ids": list(matrix_path_rerank_ids), + "matrix_path_enabled": matrix_path_enabled, + "fusion_enabled": fusion_enabled, + "event_calibration_enabled": event_calibration_enabled, + "path_calibration_enabled": path_calibration_enabled, + "event_tunnel_enabled": event_tunnel_enabled, + "path_tunnel_enabled": path_tunnel_enabled, + "final_event_fusion_enabled": final_event_fusion_enabled, + "final_path_fusion_enabled": final_path_fusion_enabled, + "decision_fusion_enabled": False, + "decision_score_source": "", + "event_fusion_enabled": event_fusion_enabled, + "path_fusion_enabled": path_fusion_enabled, + "selected_event_ids": list(selected_event_ids), + "path_rescue_event_ids": list(path_rescue_event_ids), + "selected_path_ids": list(selected_path_ids), + "final_hit_event_ids": list(final_hit_event_ids), + "final_hit_dia_ids": _dia_ids_from_hits(final_hits), + "final_missing_selected_event_ids": list(final_missing_selected_event_ids), + "selected_event_count": int(len(selected_event_ids)), + "selected_path_count": int(len(selected_path_ids)), + "temporal_scores": dict(temporal_scores), + "recall_event_scores": dict(recall_event_scores), + "event_scores": dict(event_scores), + "base_path_scores": dict(base_path_scores), + "calibrated_path_scores": dict(calibrated_path_scores), + "path_fusion_delta_scores": dict(path_fusion_delta_scores), + "path_tunnel_support_scores": dict(path_tunnel_support_scores), + "path_tunnel_delta_scores": dict(path_tunnel_delta_scores), + "path_model_scores": dict(path_model_scores), + "path_chain_extension_enabled": path_chain_extension_enabled, + "path_chain_extension_delta_scores": dict(path_chain_extension_delta_scores), + "path_chain_extended_scores": dict(path_chain_extended_scores), + "effective_path_scores": dict(effective_path_scores), + "path_scores": dict(path_scores), + "answer_type_scores": dict(answer_type_scores), + "answer_plan_scores": dict(answer_plan_scores), + "answer_plan_support_scores": dict(answer_plan_support_scores), + "answer_plan_adjusted_scores": dict(answer_plan_adjusted_scores), + "answer_plan_raw_ranked_event_ids": list(answer_plan_raw_ranked_event_ids), + "answer_plan_ranked_event_ids": list(answer_plan_ranked_event_ids), + "answer_plan_selected_event_ids": list(answer_plan_selected_event_ids), + "answer_plan_current_event_ids": list(answer_plan_current_event_ids), + "answer_plan_promotion_enabled": bool(answer_plan_promotion_enabled), + "answer_plan_promotion_score_margin": round(float(answer_plan_promotion_score_margin), 6), + "answer_plan_promotion_min_margin": round(float(answer_plan_promotion_min_margin), 6), + "answer_plan_event_selection_threshold": round(float(answer_plan_event_selection_threshold), 6), + "answer_plan_event_selection_top_k": int(answer_plan_event_selection_top_k), + "focused_answer_type": dominant_answer_type, + "preferred_path_types": list(preferred_path_types), + }, + } + + def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval: + self._reload_graph() + start = time.perf_counter() + candidate_top_k = max(top_k, self.candidate_event_k if self.retrieval_mode == "hybrid_node_scored" else min(max(top_k, 18), 48)) + payload = self.graph.retrieve(query, top_k=candidate_top_k) + hits = [_raw_hit_to_memory_hit(item) for item in payload.get("hits", []) or []] + scored_lookup = {hit.memory_id: hit for hit in hits if hit.memory_id} + active_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("active_hits", []) or []], scored_lookup) + history_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("history_hits", []) or []], scored_lookup) + stale_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("stale_hits", []) or []], scored_lookup) + overwrite_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("overwrite_hits", []) or []], scored_lookup) + false_hits = _restore_hit_scores([_raw_hit_to_memory_hit(item) for item in payload.get("false_hits", []) or []], scored_lookup) + public_hits = _public_graph_hits(self.graph) if self.retrieval_mode == "hybrid_node_scored" else [] + hybrid_payload = self._hybrid_node_scored_hits(query, hits, top_k=top_k, public_hits=public_hits) + hits = list(hybrid_payload.get("hits", []) or hits) + hybrid_metadata = dict(hybrid_payload.get("metadata", {}) or {}) + memory_router_decision = _memory_router_decision( + hybrid_metadata, + mode=self.memory_router_mode, + threshold=self.memory_router_threshold, + margin=self.memory_router_margin, + ) + current_subject_payload = _current_subject_protected_hits( + query=query, + graph=self.graph, + final_hits=hits, + top_k=top_k, + ) + hits = list(current_subject_payload.get("hits", hits) or hits) + current_subject_metadata = dict(current_subject_payload.get("metadata", {}) or {}) + audit_anchor_payload = _audit_anchor_protected_hits( + query=query, + final_hits=hits, + candidate_hits=list(public_hits) + list(active_hits) + list(history_hits), + metadata={**dict(payload.get("metadata", {}) or {}), **hybrid_metadata, **current_subject_metadata}, + top_k=top_k, + ) + hits = list(audit_anchor_payload.get("hits", hits) or hits) + audit_anchor_metadata = dict(audit_anchor_payload.get("metadata", {}) or {}) + identifier_payload = _identifier_protected_hits( + query=query, + final_hits=hits, + candidate_hits=list(public_hits) + list(active_hits) + list(history_hits), + top_k=top_k, + ) + hits = list(identifier_payload.get("hits", hits) or hits) + identifier_metadata = dict(identifier_payload.get("metadata", {}) or {}) + if _memory_router_allows(memory_router_decision, "path_tunnel", "topic_tunnel"): + depth_chain_payload = _depth_chain_protected_hits( + query=query, + graph=self.graph, + final_hits=hits, + top_k=top_k, + ) + hits = list(depth_chain_payload.get("hits", hits) or hits) + depth_chain_metadata = dict(depth_chain_payload.get("metadata", {}) or {}) + else: + depth_chain_metadata = { + "depth_chain_protected_enabled": False, + "depth_chain_router_suppressed": True, + } + if _memory_router_allows(memory_router_decision, "profile", "resource"): + profile_focused_payload = _profile_focused_pack_hits( + self.graph, + query, + hits, + top_k=top_k, + ) + hits = list(profile_focused_payload.get("hits", hits) or hits) + profile_focused_metadata = dict(profile_focused_payload.get("metadata", {}) or {}) + else: + profile_focused_metadata = { + "profile_focused_pack_enabled": False, + "profile_focused_router_suppressed": True, + } + if _memory_router_allows(memory_router_decision, "topic_tunnel"): + topic_bucket_payload = _topic_bucket_rerank_hits(self.graph, query, hits, top_k=top_k) + topic_bucket_hits = topic_bucket_payload.get("hits", hits) + hits = list(hits if topic_bucket_hits is None else topic_bucket_hits) + topic_bucket_metadata = dict(topic_bucket_payload.get("metadata", {}) or {}) + else: + topic_bucket_metadata = { + "topic_bucket_rerank_enabled": False, + "topic_bucket_router_suppressed": True, + } + temporal_runtime_payload = self._temporal_runtime_pack(query) + temporal_evidence_payload = self._apply_temporal_evidence_pack_to_hits( + hits, + temporal_runtime_payload, + top_k=top_k, + ) + hits = list(temporal_evidence_payload.get("hits", hits) or hits) + temporal_runtime_metadata = dict(temporal_evidence_payload.get("metadata", {}) or {}) + injection_planner_payload = self._apply_injection_planner_to_hits(query, hits, top_k=top_k) + hits = list(injection_planner_payload.get("hits", hits) or hits) + injection_planner_metadata = dict(injection_planner_payload.get("metadata", {}) or {}) + facet_query_pack_payload = _facet_query_pack_hits( + self.graph, + query, + hits, + top_k=top_k, + ) + hits = list(facet_query_pack_payload.get("hits", hits) or hits) + facet_query_pack_metadata = dict(facet_query_pack_payload.get("metadata", {}) or {}) + unit_coverage_mode = _normalize(os.getenv("TMCRA_UNIT_COVERAGE_PACK_MODE", "on")) + if unit_coverage_mode in _MULTI_UNIT_CHAIN_DISABLED_MODES: + unit_coverage_metadata = { + "unit_coverage_pack_enabled": False, + "unit_coverage_reason": "disabled", + } + else: + unit_coverage_payload = _unit_coverage_pack_hits( + self.graph, + query, + hits, + top_k=top_k, + ) + hits = list(unit_coverage_payload.get("hits", hits) or hits) + unit_coverage_metadata = dict(unit_coverage_payload.get("metadata", {}) or {}) + if _normalize(os.getenv("TMCRA_MULTI_UNIT_CHAIN_SLOT_MODE", "on")) not in _MULTI_UNIT_CHAIN_DISABLED_MODES: + multi_unit_chain_slot_payload = _multi_unit_chain_slot_hits( + self.graph, + query, + hits, + top_k=top_k, + ) + hits = list(multi_unit_chain_slot_payload.get("hits", hits) or hits) + multi_unit_chain_slot_metadata = dict(multi_unit_chain_slot_payload.get("metadata", {}) or {}) + else: + multi_unit_chain_slot_metadata = { + "multi_unit_chain_slot_enabled": False, + "multi_unit_chain_slot_reason": "disabled", + } + profile_protected_reinserted_count = 0 + profile_protected_ids = _dedupe( + [ + *list(profile_focused_metadata.get("profile_focused_pack_memory_ids", []) or []), + *list(profile_focused_metadata.get("profile_first_memory_ids", []) or []), + ], + max_items=max(1, min(6, int(top_k or 1))), + ) + if profile_protected_ids: + existing_hits_by_id = {hit.memory_id: hit for hit in hits if hit.memory_id} + protected_hits: List[MemoryHit] = [] + for memory_id in profile_protected_ids: + hit = existing_hits_by_id.get(memory_id) + if hit is None: + record = getattr(self.graph, "records_by_id", {}).get(memory_id) + if record is None: + continue + hit = _memory_hit_from_record(record) + metadata = dict(hit.metadata or {}) + metadata.update( + { + "profile_first_hybrid_rescue": True, + "profile_protected_slot": True, + "evidence_snippet_role": "profile_protected_slot", + } + ) + protected_hits.append( + MemoryHit( + memory_id=hit.memory_id, + category=hit.category, + value=hit.value, + relation=hit.relation, + anchors=list(hit.anchors), + score=max(float(hit.score), 4.8), + source_kind=hit.source_kind, + slot_key=hit.slot_key, + state=hit.state, + turn_index=int(hit.turn_index), + metadata=metadata, + ) + ) + if protected_hits: + profile_protected_reinserted_count = len(protected_hits) + seen_profile_ids = {hit.memory_id for hit in protected_hits if hit.memory_id} + hits = [*protected_hits, *[hit for hit in hits if hit.memory_id not in seen_profile_ids]] + embedder_fusion_output_event_ids = [ + _clean_text(event_id) + for event_id in list(hybrid_metadata.get("embedder_fusion_selected_event_ids", []) or []) + if _clean_text(event_id) + ] + embedder_fusion_output_reordered = False + if embedder_fusion_output_event_ids: + before_order = _event_ids_from_hits(hits) + seen_output_memory_ids = {hit.memory_id for hit in hits} + for event_id in embedder_fusion_output_event_ids: + event_hit = _representative_event_hit(_event_record_hits_from_graph(self.graph, event_id), query=query) + if event_hit is None or event_hit.memory_id in seen_output_memory_ids: + continue + metadata = dict(event_hit.metadata or {}) + metadata.update( + { + "event_id": event_id, + "evidence_snippet_role": "embedder_fusion_event_representative", + "embedder_fusion_event_representative": True, + } + ) + hits.append( + MemoryHit( + memory_id=event_hit.memory_id, + category=event_hit.category, + value=event_hit.value, + relation=event_hit.relation, + anchors=list(event_hit.anchors), + score=float(event_hit.score), + source_kind=event_hit.source_kind, + slot_key=event_hit.slot_key, + state=event_hit.state, + turn_index=int(event_hit.turn_index), + metadata=metadata, + ) + ) + seen_output_memory_ids.add(event_hit.memory_id) + hits = _coverage_preserving_final_hits( + hits, + selected_event_ids=_dedupe([*embedder_fusion_output_event_ids, *before_order]), + top_k=top_k, + ) + embedder_fusion_output_reordered = before_order != _event_ids_from_hits(hits) + retrieval_context_tokens = int(payload.get("context_token_estimate", _estimate_tokens_from_hits(hits))) + self._last_retrieval_context_tokens = retrieval_context_tokens + result = MemoryRetrieval( + concepts=list(payload.get("concepts", []) or []), + relations=list(payload.get("relations", []) or []), + hits=hits, + active_hits=active_hits, + history_hits=history_hits, + stale_hits=stale_hits, + overwrite_hits=overwrite_hits, + false_hits=false_hits, + retrieval_seconds=time.perf_counter() - start, + context_token_estimate=retrieval_context_tokens, + retrieval_context_token_estimate=retrieval_context_tokens, + metadata={ + "query_id": payload.get("query_id", ""), + **dict(payload.get("metadata", {}) or {}), + **hybrid_metadata, + **memory_router_decision, + **current_subject_metadata, + **audit_anchor_metadata, + **identifier_metadata, + **depth_chain_metadata, + **profile_focused_metadata, + **topic_bucket_metadata, + **temporal_runtime_metadata, + **injection_planner_metadata, + **facet_query_pack_metadata, + **unit_coverage_metadata, + **multi_unit_chain_slot_metadata, + "profile_protected_reinserted_count": profile_protected_reinserted_count, + "embedder_fusion_output_event_ids": list(embedder_fusion_output_event_ids), + "embedder_fusion_output_reordered": bool(embedder_fusion_output_reordered), + }, + ) + self._persist_graph() + return result + + def export_dialog_graph(self, *, mode: str = "light") -> Dict[str, Any]: + self._reload_graph() + return self.graph.export_graph( + snapshot_points=(1000, 5000, 10000, 20000, 50000, 100000, 200000, 300000, 500000), + mode=mode, + ) + + def export_dialog_graph_mermaid(self) -> str: + self._reload_graph() + return self.graph.export_mermaid() + + def register_answer_support(self, *, answer_id: str, memory_ids: List[str], query_id: str = "", answer_text: str = "") -> None: + self._reload_graph() + self.graph.register_answer_support(answer_id=answer_id, memory_ids=memory_ids, query_id=query_id, answer_text=answer_text) + self._persist_graph() + + def telemetry_snapshot(self) -> Dict[str, Any]: + self._reload_graph() + return self.graph.summary() + + def stats(self) -> Dict[str, Any]: + self._reload_graph() + storage = self._storage_breakdown() + return _state_stats( + storage_bytes=storage["storage_bytes"], + retrieval_context_tokens=self._last_retrieval_context_tokens, + total_state_tokens=storage["total_state_token_estimate"], + core_storage_bytes=storage["core_storage_bytes"], + audit_storage_bytes=storage["audit_storage_bytes"], + core_state_token_estimate=storage["core_state_token_estimate"], + audit_state_token_estimate=storage["audit_state_token_estimate"], + lightweight_stats=bool(self.lightweight_stats), + **self.graph.summary(), + ) + + def storage_bytes(self) -> int: + self._reload_graph() + return self._storage_breakdown()["storage_bytes"] + + def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]: + retrieval = self.retrieve(query, top_k=top_k) + retrieval_payload = retrieval.to_dict() + stats = self.stats() + context_summary = _graph_prompt_state_summary(self.graph, retrieval) + prompt_context_payload = { + "query": query, + "retrieval": retrieval_payload, + "context_summary": context_summary, + } + prompt_context_chars = len(json.dumps(prompt_context_payload, ensure_ascii=False)) + prompt_context_tokens_est = _estimate_tokens(json.dumps(prompt_context_payload, ensure_ascii=False)) + return { + "mode": "graph_session_memory_v2", + "query": query, + "retrieval": retrieval_payload, + "stats": stats, + "state": context_summary, + "context_summary": context_summary, + "prompt_context_chars": int(prompt_context_chars), + "prompt_context_tokens_est": int(prompt_context_tokens_est), + "context_truncated": bool(context_summary.get("context_truncated", False)), + "truncation_reason": _clean_text(context_summary.get("truncation_reason", "")), + } + + +class SummaryWindowMemoryAdapter(MemoryAdapter): + name = "summary_window_memory" + + def __init__(self, *, window_size: int = 24, auto_extract: bool = False) -> None: + self.extractor = SessionMemoryExtractor() + self.window_size = max(4, int(window_size)) + self.turn_index = 0 + self.active_slots: Dict[str, SessionMemoryRecordV2] = {} + self.recent_turns: deque[Dict[str, Any]] = deque(maxlen=self.window_size) + self.auto_extract = bool(auto_extract) + self._last_retrieval_context_tokens = 0 + + def reset(self) -> None: + self.turn_index = 0 + self.active_slots = {} + self.recent_turns = deque(maxlen=self.window_size) + self._last_retrieval_context_tokens = 0 + + def ingest_turn( + self, + user_text: str, + assistant_text: str = "", + *, + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> None: + self.turn_index += 1 + records = _build_turn_records( + self.extractor, + user_text=user_text, + answer_payload=answer_payload, + extraction_result=extraction_result, + turn_index=self.turn_index, + allow_auto_extract=self.auto_extract, + ) + for record in records: + previous = self.active_slots.get(record.slot_key) + if previous: + previous.state = "superseded" + record.supersedes.append(previous.memory_id) + record.state = "active" + self.active_slots[record.slot_key] = record + self.recent_turns.append( + { + "turn_index": self.turn_index, + "text": _clean_text(user_text), + "assistant": _clean_text(assistant_text), + } + ) + + def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval: + start = time.perf_counter() + query_tokens = set(_tokenize(query)) + hints = set(infer_category_hints(query)) + scored: List[tuple[float, SessionMemoryRecordV2]] = [] + for record in self.active_slots.values(): + token_set = record.token_set() + overlap = len(query_tokens & token_set) if query_tokens and token_set else 0 + score = overlap / max(1, len(query_tokens | token_set)) if query_tokens and token_set else 0.0 + if hints and record.category in hints: + score += 0.22 + if record.slot_key.lower() in _normalize(query): + score += 0.12 + score += min(0.08, record.turn_index * 0.0004) + score += 0.2 + if score > 0: + scored.append((score, record)) + scored.sort(key=lambda item: (item[0], item[1].turn_index), reverse=True) + selected_records = [record for _, record in scored[:top_k]] + hits = [ + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + score=float(score), + source_kind=record.source_kind, + slot_key=record.slot_key, + state=record.state, + turn_index=record.turn_index, + metadata={"window_size": self.window_size}, + ) + for score, record in scored[:top_k] + ] + concepts = [] + relations = [] + for hit in hits: + concepts.append({"concept": hit.value, "type": hit.category, "source_kind": hit.source_kind}) + for anchor in hit.anchors[:2]: + concepts.append({"concept": anchor, "type": "context", "source_kind": hit.source_kind}) + relation = _relation_hit(hit, weight_bias=0.06) + if relation: + relations.append(relation) + retrieval_context_tokens = _estimate_tokens_from_hits(hits) + self._last_retrieval_context_tokens = retrieval_context_tokens + return MemoryRetrieval( + concepts=concepts, + relations=relations, + hits=hits, + active_hits=list(hits), + retrieval_seconds=time.perf_counter() - start, + context_token_estimate=retrieval_context_tokens, + retrieval_context_token_estimate=retrieval_context_tokens, + metadata={ + "records": len(self.active_slots), + "window_size": self.window_size, + "recent_turns": len(self.recent_turns), + }, + ) + + def stats(self) -> Dict[str, Any]: + payload = { + "active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()}, + "recent_turns": list(self.recent_turns), + } + total_state_tokens = _estimate_tokens(json.dumps(payload, ensure_ascii=False)) + return _state_stats( + storage_bytes=self.storage_bytes(), + retrieval_context_tokens=self._last_retrieval_context_tokens, + total_state_tokens=total_state_tokens, + records=len(self.active_slots), + active_slots=len(self.active_slots), + recent_turns=len(self.recent_turns), + ) + + def storage_bytes(self) -> int: + payload = { + "active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()}, + "recent_turns": list(self.recent_turns), + } + return len(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]: + return { + "mode": "summary_window_memory", + "query": query, + "retrieval": self.retrieve(query, top_k=top_k).to_dict(), + "stats": self.stats(), + "state": { + "active_slots": {slot: record.to_dict() for slot, record in self.active_slots.items()}, + "recent_turns": list(self.recent_turns), + }, + } + + +@dataclass(slots=True) +class _VectorRecord: + memory_id: str + category: str + value: str + relation: str + anchors: List[str] + tokens: List[str] + turn_index: int + slot_key: str = "" + active: bool = True + source_kind: str = "vector_memory" + metadata: Dict[str, Any] = field(default_factory=dict) + + +class VectorRAGMemoryAdapter(MemoryAdapter): + name = "vector_rag_memory" + + def __init__(self, *, auto_extract: bool = False) -> None: + self.extractor = SessionMemoryExtractor() + self.records: List[_VectorRecord] = [] + self.turn_index = 0 + self.auto_extract = bool(auto_extract) + self._last_retrieval_context_tokens = 0 + + def reset(self) -> None: + self.records = [] + self.turn_index = 0 + self._last_retrieval_context_tokens = 0 + + def ingest_turn( + self, + user_text: str, + assistant_text: str = "", + *, + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> None: + _ = assistant_text + self.turn_index += 1 + records = _build_turn_records( + self.extractor, + user_text=user_text, + answer_payload=answer_payload, + extraction_result=extraction_result, + turn_index=self.turn_index, + allow_auto_extract=self.auto_extract, + ) + for record in records: + if record.slot_key: + for previous in self.records: + if previous.slot_key == record.slot_key and previous.active: + previous.active = False + self.records.append( + _VectorRecord( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchor_concepts), + tokens=list(record.token_set()), + turn_index=record.turn_index, + slot_key=record.slot_key, + active=record.state == "active", + source_kind=record.source_kind, + metadata=dict(record.metadata), + ) + ) + + def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval: + start = time.perf_counter() + query_tokens = set(_tokenize(query)) + hints = set(infer_category_hints(query)) + scored: List[tuple[float, _VectorRecord]] = [] + for record in self.records: + token_set = set(record.tokens) + overlap = len(query_tokens & token_set) if query_tokens and token_set else 0 + score = overlap / max(1, len(query_tokens | token_set)) if query_tokens and token_set else 0.0 + if hints and record.category in hints: + score += 0.18 + if record.slot_key and record.slot_key.lower() in _normalize(query): + score += 0.1 + score += min(0.12, record.turn_index * 0.0004) + score += 0.18 if record.active else -0.25 + if score > 0: + scored.append((score, record)) + if not scored: + for record in self.records[-top_k:]: + scored.append((0.05 + (0.15 if record.active else 0.0), record)) + scored.sort(key=lambda item: (item[0], item[1].active, item[1].turn_index), reverse=True) + hits = [ + MemoryHit( + memory_id=record.memory_id, + category=record.category, + value=record.value, + relation=record.relation, + anchors=list(record.anchors), + score=float(score), + source_kind=record.source_kind, + slot_key=record.slot_key, + state="active" if record.active else "superseded", + turn_index=record.turn_index, + metadata=dict(record.metadata), + ) + for score, record in scored[:top_k] + ] + concepts = [] + relations = [] + for hit in hits: + concepts.append({"concept": hit.value, "type": hit.category, "source_kind": hit.source_kind}) + for anchor in hit.anchors[:2]: + concepts.append({"concept": anchor, "type": "context", "source_kind": hit.source_kind}) + relation = _relation_hit(hit) + if relation: + relations.append(relation) + retrieval_context_tokens = _estimate_tokens_from_hits(hits) + self._last_retrieval_context_tokens = retrieval_context_tokens + return MemoryRetrieval( + concepts=concepts, + relations=relations, + hits=hits, + active_hits=[hit for hit in hits if hit.state == "active"], + history_hits=[hit for hit in hits if hit.state != "active"], + retrieval_seconds=time.perf_counter() - start, + context_token_estimate=retrieval_context_tokens, + retrieval_context_token_estimate=retrieval_context_tokens, + metadata={ + "records": len(self.records), + "active_records": sum(1 for item in self.records if item.active), + }, + ) + + def stats(self) -> Dict[str, Any]: + payload = [ + { + "memory_id": record.memory_id, + "category": record.category, + "value": record.value, + "relation": record.relation, + "anchors": record.anchors, + "slot_key": record.slot_key, + "active": record.active, + } + for record in self.records + ] + total_state_tokens = _estimate_tokens(json.dumps(payload, ensure_ascii=False)) + return _state_stats( + storage_bytes=self.storage_bytes(), + retrieval_context_tokens=self._last_retrieval_context_tokens, + total_state_tokens=total_state_tokens, + records=len(self.records), + active_records=sum(1 for item in self.records if item.active), + ) + + def storage_bytes(self) -> int: + payload = [ + { + "memory_id": record.memory_id, + "category": record.category, + "value": record.value, + "relation": record.relation, + "anchors": record.anchors, + "slot_key": record.slot_key, + "active": record.active, + } + for record in self.records + ] + return len(json.dumps(payload, ensure_ascii=False).encode("utf-8")) + + def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]: + state = [ + { + "memory_id": record.memory_id, + "category": record.category, + "value": record.value, + "relation": record.relation, + "anchors": list(record.anchors), + "slot_key": record.slot_key, + "active": bool(record.active), + "turn_index": int(record.turn_index), + } + for record in self.records + ] + return { + "mode": "vector_rag_memory", + "query": query, + "retrieval": self.retrieve(query, top_k=top_k).to_dict(), + "stats": self.stats(), + "state": state, + } + + +class FullHistoryMemoryAdapter(MemoryAdapter): + name = "full_history_memory" + + def __init__(self) -> None: + self.turns: List[Dict[str, str]] = [] + self._last_retrieval_context_tokens = 0 + + def reset(self) -> None: + self.turns = [] + self._last_retrieval_context_tokens = 0 + + def ingest_turn( + self, + user_text: str, + assistant_text: str = "", + *, + answer_payload: Dict[str, Any] | None = None, + extraction_result: Dict[str, Any] | None = None, + ) -> None: + _ = answer_payload, extraction_result + self.turns.append({"user": _clean_text(user_text), "assistant": _clean_text(assistant_text)}) + + def retrieve(self, query: str, *, top_k: int = 6) -> MemoryRetrieval: + start = time.perf_counter() + query_tokens = set(_tokenize(query)) + scored: List[tuple[float, Dict[str, str], int]] = [] + for index, turn in enumerate(self.turns): + combined = f"{turn.get('user', '')} {turn.get('assistant', '')}" + token_set = set(_tokenize(combined)) + if not token_set: + continue + overlap = len(query_tokens & token_set) if query_tokens else 0 + score = overlap / max(1, len(query_tokens | token_set)) if query_tokens else 0.0 + scored.append((score, turn, index)) + scored.sort(key=lambda item: (item[0], item[2]), reverse=True) + hits = [ + MemoryHit( + memory_id=f"turn:{index}", + category="history_turn", + value=turn.get("user", ""), + relation="conversation_context", + anchors=[turn.get("assistant", "")] if turn.get("assistant") else [], + score=float(score), + source_kind="full_history", + slot_key=f"turn.{index}", + state="active", + turn_index=index + 1, + ) + for score, turn, index in scored[:top_k] + if turn.get("user", "") + ] + retrieval_context_tokens = _estimate_tokens(json.dumps(self.turns, ensure_ascii=False)) + self._last_retrieval_context_tokens = retrieval_context_tokens + return MemoryRetrieval( + hits=hits, + active_hits=list(hits), + retrieval_seconds=time.perf_counter() - start, + context_token_estimate=retrieval_context_tokens, + retrieval_context_token_estimate=retrieval_context_tokens, + metadata={"records": len(self.turns)}, + ) + + def stats(self) -> Dict[str, Any]: + total_state_tokens = _estimate_tokens(json.dumps(self.turns, ensure_ascii=False)) + return _state_stats( + storage_bytes=self.storage_bytes(), + retrieval_context_tokens=self._last_retrieval_context_tokens, + total_state_tokens=total_state_tokens, + records=len(self.turns), + ) + + def storage_bytes(self) -> int: + return len(json.dumps(self.turns, ensure_ascii=False).encode("utf-8")) + + def build_prompt_context(self, query: str, *, top_k: int = 8) -> Dict[str, Any]: + return { + "mode": "full_history_memory", + "query": query, + "retrieval": self.retrieve(query, top_k=top_k).to_dict(), + "stats": self.stats(), + "state": { + "turns": list(self.turns), + }, + }