# -*- coding: utf-8 -*- """ Session memory compression and retrieval. PostgreSQL is initialized lazily. If it is not configured or unavailable, writes fall back to local JSON files under the corpus directory. """ from __future__ import annotations from datetime import datetime, timezone import json import logging from pathlib import Path from typing import Any from pydantic import BaseModel, Field from pluto.db import _get_connection from pluto.utils import extract_json_from_response logger = logging.getLogger("pluto") LOCAL_MEMORY_DIR = ".session_memory" RAW_ARCHIVE_DIR = ".session_archive" class CompressedSession(BaseModel): session_id: str doc_id: str timestamp: str queries_resolved: list[dict] = Field(default_factory=list) key_findings: list[str] = Field(default_factory=list) open_questions: list[str] = Field(default_factory=list) links_to_prior_sessions: list[str] = Field(default_factory=list) def compress_session( session_id: str, doc_id: str, session_result: dict, corpus_dir: str | Path, ) -> CompressedSession: """Compress and store a session result without raising on storage failure.""" corpus_path = Path(corpus_dir) raw_path = _write_raw_session(corpus_path, session_id, session_result) try: raw = _call_compression_llm(session_id=session_id, doc_id=doc_id, session_result=session_result) compressed = _parse_compressed_session(session_id, doc_id, raw) except Exception as exc: logger.warning("Session compression LLM failed for %s: %s", session_id, exc) compressed = _fallback_compressed_session(session_id, doc_id, session_result) try: _store_postgres(compressed, raw_path) except Exception as exc: logger.warning("PostgreSQL session memory unavailable; writing local fallback: %s", exc) _store_local(corpus_path, compressed) return compressed def list_session_context( doc_id: str, corpus_dir: str | Path, limit: int = 10, ) -> list[dict]: """Return compressed sessions for one document, newest first.""" try: return _list_postgres(doc_id, limit) except Exception as exc: logger.warning("PostgreSQL session memory unavailable; reading local fallback: %s", exc) return _list_local(Path(corpus_dir), doc_id, limit) def _call_compression_llm(session_id: str, doc_id: str, session_result: dict) -> str: from pluto.dispatcher import dispatch from pluto.modes import get_mode get_mode("MODE_QUICK") prompt = f"""Compress this QA session as JSON only. Schema: {{ "queries_resolved": [ {{"query": "...", "answer_summary": "...", "chunks_used": 0, "confidence": 0.0}} ], "key_findings": ["finding"], "open_questions": ["question"], "links_to_prior_sessions": [] }} Session id: {session_id} Document id: {doc_id} Session result: {json.dumps(session_result, ensure_ascii=False)[:14000]} """ return dispatch("MODE_QUICK", prompt) def _parse_compressed_session(session_id: str, doc_id: str, raw: str) -> CompressedSession: data = json.loads(extract_json_from_response(raw)) return CompressedSession( session_id=session_id, doc_id=doc_id, timestamp=_utc_now(), queries_resolved=data.get("queries_resolved", []) if isinstance(data.get("queries_resolved"), list) else [], key_findings=_string_list(data.get("key_findings")), open_questions=_string_list(data.get("open_questions")), links_to_prior_sessions=_string_list(data.get("links_to_prior_sessions")), ) def _fallback_compressed_session(session_id: str, doc_id: str, session_result: dict) -> CompressedSession: final_answer = session_result.get("final_answer", {}) if isinstance(session_result, dict) else {} trace = session_result.get("trace_summary", {}) if isinstance(session_result, dict) else {} query = session_result.get("query", "") if isinstance(session_result, dict) else "" answer = final_answer.get("response", "") if isinstance(final_answer, dict) else "" return CompressedSession( session_id=session_id, doc_id=doc_id, timestamp=_utc_now(), queries_resolved=[ { "query": query, "answer_summary": str(answer)[:500], "chunks_used": trace.get("chunks_processed", 0) if isinstance(trace, dict) else 0, "confidence": session_result.get("confidence", 0.0) if isinstance(session_result, dict) else 0.0, } ], key_findings=[], open_questions=session_result.get("missing_info", []) if isinstance(session_result, dict) else [], links_to_prior_sessions=[], ) def _store_postgres(compressed: CompressedSession, raw_path: str) -> None: conn = _get_connection() try: with conn.cursor() as cur: cur.execute( """ INSERT INTO session_memory (session_id, doc_id, compressed_json, raw_path) VALUES (%s, %s, %s::jsonb, %s) ON CONFLICT (session_id) DO UPDATE SET doc_id = EXCLUDED.doc_id, compressed_json = EXCLUDED.compressed_json, raw_path = EXCLUDED.raw_path """, ( compressed.session_id, compressed.doc_id, json.dumps(compressed.model_dump(), ensure_ascii=False), raw_path, ), ) conn.commit() finally: conn.close() def _list_postgres(doc_id: str, limit: int) -> list[dict]: conn = _get_connection() try: with conn.cursor() as cur: cur.execute( """ SELECT compressed_json FROM session_memory WHERE doc_id = %s ORDER BY created_at DESC LIMIT %s """, (doc_id, limit), ) rows = cur.fetchall() finally: conn.close() results = [] for row in rows: value = row[0] if isinstance(value, str): value = json.loads(value) results.append(value) return results def _store_local(corpus_dir: Path, compressed: CompressedSession) -> None: memory_dir = corpus_dir / LOCAL_MEMORY_DIR memory_dir.mkdir(parents=True, exist_ok=True) path = memory_dir / f"{compressed.session_id}.json" path.write_text(json.dumps(compressed.model_dump(), ensure_ascii=False, indent=1), encoding="utf-8") def _list_local(corpus_dir: Path, doc_id: str, limit: int) -> list[dict]: memory_dir = corpus_dir / LOCAL_MEMORY_DIR if not memory_dir.exists(): return [] sessions = [] for path in memory_dir.glob("*.json"): try: data = json.loads(path.read_text(encoding="utf-8")) except Exception: continue if data.get("doc_id") == doc_id: sessions.append(data) sessions.sort(key=lambda item: item.get("timestamp", ""), reverse=True) return sessions[:limit] def _write_raw_session(corpus_dir: Path, session_id: str, session_result: dict) -> str: archive_dir = corpus_dir / RAW_ARCHIVE_DIR archive_dir.mkdir(parents=True, exist_ok=True) path = archive_dir / f"{session_id}.json" path.write_text(json.dumps(session_result, ensure_ascii=False, indent=1), encoding="utf-8") return str(path) def _string_list(value: Any) -> list[str]: if not isinstance(value, list): return [] return [str(item) for item in value if str(item).strip()] def _utc_now() -> str: return datetime.now(timezone.utc).isoformat()