"""
Session memory compression and retrieval.

PostgreSQL is initialized lazily. If it is not configured or unavailable, writes
fall back to local JSON files under the corpus directory.
"""
|
|
| from __future__ import annotations |
|
|
| from datetime import datetime, timezone |
| import json |
| import logging |
| from pathlib import Path |
| from typing import Any |
|
|
| from pydantic import BaseModel, Field |
|
|
| from pluto.db import _get_connection |
| from pluto.utils import extract_json_from_response |
|
|
|
|
logger = logging.getLogger("pluto")
# Local fallback store (relative to the corpus dir) for compressed session
# summaries, used only when PostgreSQL is unavailable.
LOCAL_MEMORY_DIR = ".session_memory"
# Raw session results are always archived here, regardless of DB availability.
RAW_ARCHIVE_DIR = ".session_archive"
|
|
|
|
class CompressedSession(BaseModel):
    """Compact, structured summary of a single QA session."""

    # Identifier of the session this summary was derived from.
    session_id: str
    # Document the session ran against.
    doc_id: str
    # ISO-8601 UTC timestamp of when the compression happened (see _utc_now).
    timestamp: str
    # One entry per answered query; expected keys per the compression prompt:
    # "query", "answer_summary", "chunks_used", "confidence".
    queries_resolved: list[dict] = Field(default_factory=list)
    # Notable facts extracted during the session.
    key_findings: list[str] = Field(default_factory=list)
    # Questions left unanswered — candidates for follow-up sessions.
    open_questions: list[str] = Field(default_factory=list)
    # References to related earlier sessions, if the LLM identified any.
    links_to_prior_sessions: list[str] = Field(default_factory=list)
|
|
|
|
def compress_session(
    session_id: str,
    doc_id: str,
    session_result: dict,
    corpus_dir: str | Path,
) -> CompressedSession:
    """Compress a session result and persist it; never raises on storage failure.

    The raw result is always archived to disk first. Compression is attempted
    via the LLM and degrades to a deterministic fallback summary; storage
    prefers PostgreSQL and degrades to a local JSON file.
    """
    corpus_path = Path(corpus_dir)
    # Archive the uncompressed result up front so nothing is lost even if
    # every later step fails.
    archive_path = _write_raw_session(corpus_path, session_id, session_result)

    try:
        llm_output = _call_compression_llm(
            session_id=session_id, doc_id=doc_id, session_result=session_result
        )
        summary = _parse_compressed_session(session_id, doc_id, llm_output)
    except Exception as exc:
        logger.warning("Session compression LLM failed for %s: %s", session_id, exc)
        summary = _fallback_compressed_session(session_id, doc_id, session_result)

    try:
        _store_postgres(summary, archive_path)
    except Exception as exc:
        logger.warning("PostgreSQL session memory unavailable; writing local fallback: %s", exc)
        _store_local(corpus_path, summary)

    return summary
|
|
|
|
def list_session_context(
    doc_id: str,
    corpus_dir: str | Path,
    limit: int = 10,
) -> list[dict]:
    """Return compressed sessions for one document, newest first.

    Args:
        doc_id: Document whose session history to fetch.
        corpus_dir: Corpus root; only consulted for the local-file fallback.
        limit: Maximum number of sessions to return.

    Returns:
        Compressed-session dicts ordered most recent first. Reads from
        PostgreSQL when reachable, otherwise from local fallback files.
    """
    try:
        return _list_postgres(doc_id, limit)
    except Exception as exc:
        logger.warning("PostgreSQL session memory unavailable; reading local fallback: %s", exc)
        return _list_local(Path(corpus_dir), doc_id, limit)
|
|
|
|
def _call_compression_llm(session_id: str, doc_id: str, session_result: dict) -> str:
    """Ask the LLM to compress a session result into the JSON schema below.

    Returns the raw model response (parsed later by _parse_compressed_session).
    May raise; the caller handles failures with a deterministic fallback.
    """
    # Local imports avoid a circular dependency at module load time —
    # TODO confirm; dispatcher/modes are only needed here.
    from pluto.dispatcher import dispatch
    from pluto.modes import get_mode

    # Return value deliberately discarded — presumably this validates that
    # MODE_QUICK exists (raising early if not) before dispatching; confirm.
    get_mode("MODE_QUICK")
    # Session result is truncated to ~14k chars to bound the prompt size.
    prompt = f"""Compress this QA session as JSON only.

Schema:
{{
  "queries_resolved": [
    {{"query": "...", "answer_summary": "...", "chunks_used": 0, "confidence": 0.0}}
  ],
  "key_findings": ["finding"],
  "open_questions": ["question"],
  "links_to_prior_sessions": []
}}

Session id: {session_id}
Document id: {doc_id}
Session result:
{json.dumps(session_result, ensure_ascii=False)[:14000]}
"""
    return dispatch("MODE_QUICK", prompt)
|
|
|
|
def _parse_compressed_session(session_id: str, doc_id: str, raw: str) -> CompressedSession:
    """Build a CompressedSession from the LLM reply; raises on malformed JSON."""
    payload = json.loads(extract_json_from_response(raw))
    resolved = payload.get("queries_resolved")
    return CompressedSession(
        session_id=session_id,
        doc_id=doc_id,
        timestamp=_utc_now(),
        # Only trust a well-formed list; anything else collapses to empty.
        queries_resolved=resolved if isinstance(resolved, list) else [],
        key_findings=_string_list(payload.get("key_findings")),
        open_questions=_string_list(payload.get("open_questions")),
        links_to_prior_sessions=_string_list(payload.get("links_to_prior_sessions")),
    )
|
|
|
|
def _fallback_compressed_session(session_id: str, doc_id: str, session_result: dict) -> CompressedSession:
    """Build a minimal CompressedSession directly from the raw session result.

    This is the last-resort path used when the compression LLM fails, so it
    must tolerate arbitrarily shaped input without raising.
    """
    is_dict = isinstance(session_result, dict)
    final_answer = session_result.get("final_answer", {}) if is_dict else {}
    trace = session_result.get("trace_summary", {}) if is_dict else {}
    query = session_result.get("query", "") if is_dict else ""
    answer = final_answer.get("response", "") if isinstance(final_answer, dict) else ""
    return CompressedSession(
        session_id=session_id,
        doc_id=doc_id,
        timestamp=_utc_now(),
        queries_resolved=[
            {
                "query": query,
                "answer_summary": str(answer)[:500],  # cap to keep summaries small
                "chunks_used": trace.get("chunks_processed", 0) if isinstance(trace, dict) else 0,
                "confidence": session_result.get("confidence", 0.0) if is_dict else 0.0,
            }
        ],
        key_findings=[],
        # Coerce through _string_list (as _parse_compressed_session does):
        # raw "missing_info" entries may be non-strings, which would make the
        # list[str] field raise ValidationError in the one path that exists
        # precisely to never fail. _string_list(None) -> [].
        open_questions=_string_list(session_result.get("missing_info")) if is_dict else [],
        links_to_prior_sessions=[],
    )
|
|
|
|
def _store_postgres(compressed: CompressedSession, raw_path: str) -> None:
    """Upsert the compressed session into the session_memory table.

    Raises if PostgreSQL is unconfigured/unreachable; the caller falls back
    to local storage.
    """
    payload = json.dumps(compressed.model_dump(), ensure_ascii=False)
    params = (compressed.session_id, compressed.doc_id, payload, raw_path)
    conn = _get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO session_memory (session_id, doc_id, compressed_json, raw_path)
                VALUES (%s, %s, %s::jsonb, %s)
                ON CONFLICT (session_id) DO UPDATE SET
                    doc_id = EXCLUDED.doc_id,
                    compressed_json = EXCLUDED.compressed_json,
                    raw_path = EXCLUDED.raw_path
                """,
                params,
            )
        conn.commit()
    finally:
        # Always release the connection, committed or not.
        conn.close()
|
|
|
|
def _list_postgres(doc_id: str, limit: int) -> list[dict]:
    """Fetch up to *limit* compressed sessions for *doc_id*, newest first.

    Raises if PostgreSQL is unconfigured/unreachable.
    """
    conn = _get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT compressed_json
                FROM session_memory
                WHERE doc_id = %s
                ORDER BY created_at DESC
                LIMIT %s
                """,
                (doc_id, limit),
            )
            rows = cur.fetchall()
    finally:
        conn.close()

    # jsonb may come back already decoded (dict) or as text, depending on the
    # driver configuration — normalize to dicts either way.
    return [json.loads(cell) if isinstance(cell, str) else cell for (cell,) in rows]
|
|
|
|
def _store_local(corpus_dir: Path, compressed: CompressedSession) -> None:
    """Write the compressed session to a JSON file under the corpus directory."""
    target = corpus_dir / LOCAL_MEMORY_DIR / f"{compressed.session_id}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(compressed.model_dump(), ensure_ascii=False, indent=1)
    target.write_text(serialized, encoding="utf-8")
|
|
|
|
def _list_local(corpus_dir: Path, doc_id: str, limit: int) -> list[dict]:
    """Read compressed sessions for *doc_id* from local fallback JSON files.

    Returns up to *limit* entries sorted by timestamp, newest first.
    Unreadable or malformed files are skipped rather than raised.
    """
    memory_dir = corpus_dir / LOCAL_MEMORY_DIR
    if not memory_dir.exists():
        return []

    sessions = []
    for path in memory_dir.glob("*.json"):
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue
        # A file whose top-level JSON value is a list/str/number parses fine
        # but has no .get() — the except above does not cover that, so guard
        # explicitly instead of crashing the whole listing.
        if not isinstance(data, dict):
            continue
        if data.get("doc_id") == doc_id:
            sessions.append(data)

    sessions.sort(key=lambda item: item.get("timestamp", ""), reverse=True)
    return sessions[:limit]
|
|
|
|
def _write_raw_session(corpus_dir: Path, session_id: str, session_result: dict) -> str:
    """Archive the full, uncompressed session result and return its file path."""
    target = corpus_dir / RAW_ARCHIVE_DIR / f"{session_id}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(
        json.dumps(session_result, ensure_ascii=False, indent=1),
        encoding="utf-8",
    )
    return str(target)
|
|
|
|
| def _string_list(value: Any) -> list[str]: |
| if not isinstance(value, list): |
| return [] |
| return [str(item) for item in value if str(item).strip()] |
|
|
|
|
| def _utc_now() -> str: |
| return datetime.now(timezone.utc).isoformat() |
|
|