# plutoV2_miniProject_3rd-yr / mp1 / pluto / session_memory.py
# Commit 23cdeed (ayushKishor): "Add Pluto memory layer and pipeline fixes"
# -*- coding: utf-8 -*-
"""
Session memory compression and retrieval.
PostgreSQL is initialized lazily. If it is not configured or unavailable, writes
fall back to local JSON files under the corpus directory.
"""
from __future__ import annotations
from datetime import datetime, timezone
import json
import logging
from pathlib import Path
from typing import Any
from pydantic import BaseModel, Field
from pluto.db import _get_connection
from pluto.utils import extract_json_from_response
logger = logging.getLogger("pluto")
LOCAL_MEMORY_DIR = ".session_memory"
RAW_ARCHIVE_DIR = ".session_archive"
class CompressedSession(BaseModel):
    """Compact, structured summary of a single QA session.

    Produced either by the compression LLM (_parse_compressed_session) or
    deterministically from the raw result (_fallback_compressed_session),
    then persisted to PostgreSQL or a local JSON fallback.
    """
    session_id: str  # unique id of the session being summarized
    doc_id: str  # id of the document the session operated on
    timestamp: str  # ISO-8601 UTC timestamp set at compression time (_utc_now)
    queries_resolved: list[dict] = Field(default_factory=list)  # per-query summaries (query/answer_summary/chunks_used/confidence)
    key_findings: list[str] = Field(default_factory=list)  # notable conclusions extracted from the session
    open_questions: list[str] = Field(default_factory=list)  # unresolved questions / missing info
    links_to_prior_sessions: list[str] = Field(default_factory=list)  # related session ids, if any
def compress_session(
    session_id: str,
    doc_id: str,
    session_result: dict,
    corpus_dir: str | Path,
) -> CompressedSession:
    """Compress and store a session result without raising on storage failure."""
    base_dir = Path(corpus_dir)

    # Always archive the raw result first so nothing is lost if later steps fail.
    archive_path = _write_raw_session(base_dir, session_id, session_result)

    # Summarize via the LLM; any failure degrades to a deterministic fallback.
    try:
        llm_output = _call_compression_llm(
            session_id=session_id, doc_id=doc_id, session_result=session_result
        )
        summary = _parse_compressed_session(session_id, doc_id, llm_output)
    except Exception as exc:
        logger.warning("Session compression LLM failed for %s: %s", session_id, exc)
        summary = _fallback_compressed_session(session_id, doc_id, session_result)

    # Persist to PostgreSQL; degrade to a local JSON file if it is unavailable.
    try:
        _store_postgres(summary, archive_path)
    except Exception as exc:
        logger.warning("PostgreSQL session memory unavailable; writing local fallback: %s", exc)
        _store_local(base_dir, summary)

    return summary
def list_session_context(
    doc_id: str,
    corpus_dir: str | Path,
    limit: int = 10,
) -> list[dict]:
    """Return compressed sessions for one document, newest first."""
    # Prefer PostgreSQL; fall through to local JSON files when it is unreachable.
    try:
        return _list_postgres(doc_id, limit)
    except Exception as exc:
        logger.warning("PostgreSQL session memory unavailable; reading local fallback: %s", exc)
    return _list_local(Path(corpus_dir), doc_id, limit)
def _call_compression_llm(session_id: str, doc_id: str, session_result: dict) -> str:
    """Ask the MODE_QUICK model to compress one session into the JSON schema below.

    Returns the raw model response string; JSON extraction and validation
    happen in _parse_compressed_session. Any exception propagates to the
    caller, which falls back to a deterministic summary.
    """
    # Imported lazily inside the function — NOTE(review): presumably to avoid a
    # circular import with pluto.dispatcher / pluto.modes; confirm.
    from pluto.dispatcher import dispatch
    from pluto.modes import get_mode

    # Return value is discarded; this looks like a pre-flight check that the
    # mode exists before dispatching — TODO confirm get_mode raises on unknown mode.
    get_mode("MODE_QUICK")
    prompt = f"""Compress this QA session as JSON only.
Schema:
{{
"queries_resolved": [
{{"query": "...", "answer_summary": "...", "chunks_used": 0, "confidence": 0.0}}
],
"key_findings": ["finding"],
"open_questions": ["question"],
"links_to_prior_sessions": []
}}
Session id: {session_id}
Document id: {doc_id}
Session result:
{json.dumps(session_result, ensure_ascii=False)[:14000]}
"""
    # The serialized session_result is truncated to 14,000 characters to bound prompt size.
    return dispatch("MODE_QUICK", prompt)
def _parse_compressed_session(session_id: str, doc_id: str, raw: str) -> CompressedSession:
    """Validate the LLM response and build a CompressedSession from it.

    Raises on malformed JSON or failed model validation; the caller then
    falls back to _fallback_compressed_session.
    """
    payload = json.loads(extract_json_from_response(raw))

    # queries_resolved must be a list; anything else is discarded.
    resolved = payload.get("queries_resolved")
    if not isinstance(resolved, list):
        resolved = []

    return CompressedSession(
        session_id=session_id,
        doc_id=doc_id,
        timestamp=_utc_now(),
        queries_resolved=resolved,
        key_findings=_string_list(payload.get("key_findings")),
        open_questions=_string_list(payload.get("open_questions")),
        links_to_prior_sessions=_string_list(payload.get("links_to_prior_sessions")),
    )
def _fallback_compressed_session(session_id: str, doc_id: str, session_result: dict) -> CompressedSession:
    """Build a deterministic CompressedSession directly from the raw result.

    Used when the compression LLM fails, so it runs inside compress_session's
    error path and must not raise itself.
    """
    # Normalize defensively: callers may hand us something that is not a dict.
    if not isinstance(session_result, dict):
        session_result = {}
    final_answer = session_result.get("final_answer", {})
    answer = final_answer.get("response", "") if isinstance(final_answer, dict) else ""
    trace = session_result.get("trace_summary", {})
    return CompressedSession(
        session_id=session_id,
        doc_id=doc_id,
        timestamp=_utc_now(),
        queries_resolved=[
            {
                "query": session_result.get("query", ""),
                "answer_summary": str(answer)[:500],  # cap summary length
                "chunks_used": trace.get("chunks_processed", 0) if isinstance(trace, dict) else 0,
                "confidence": session_result.get("confidence", 0.0),
            }
        ],
        key_findings=[],
        # BUGFIX: coerce missing_info through _string_list (as the parser does for
        # the same field). Passing it raw meant a non-list or non-string entries
        # failed model validation here, making compress_session raise from its
        # own fallback path despite its "without raising" contract.
        open_questions=_string_list(session_result.get("missing_info", [])),
        links_to_prior_sessions=[],
    )
def _store_postgres(compressed: CompressedSession, raw_path: str) -> None:
    """Upsert one compressed session into the session_memory table.

    raw_path is the filesystem path of the raw session archive written by
    _write_raw_session. Connection/SQL errors propagate to the caller, which
    falls back to local storage.
    """
    conn = _get_connection()
    try:
        with conn.cursor() as cur:
            # Upsert keyed on session_id so re-compressing a session overwrites
            # its previous row instead of failing on the unique constraint.
            cur.execute(
                """
INSERT INTO session_memory (session_id, doc_id, compressed_json, raw_path)
VALUES (%s, %s, %s::jsonb, %s)
ON CONFLICT (session_id) DO UPDATE SET
doc_id = EXCLUDED.doc_id,
compressed_json = EXCLUDED.compressed_json,
raw_path = EXCLUDED.raw_path
""",
                (
                    compressed.session_id,
                    compressed.doc_id,
                    # Serialized here and cast to jsonb in SQL.
                    json.dumps(compressed.model_dump(), ensure_ascii=False),
                    raw_path,
                ),
            )
        conn.commit()
    finally:
        # Always release the connection, even if execute/commit raised.
        conn.close()
def _list_postgres(doc_id: str, limit: int) -> list[dict]:
    """Fetch up to `limit` compressed sessions for one document, newest first."""
    conn = _get_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
SELECT compressed_json
FROM session_memory
WHERE doc_id = %s
ORDER BY created_at DESC
LIMIT %s
""",
                (doc_id, limit),
            )
            rows = cur.fetchall()
    finally:
        conn.close()

    # Depending on the driver, a jsonb column may arrive decoded (dict) or as
    # raw JSON text; normalize both shapes to dicts.
    return [
        json.loads(row[0]) if isinstance(row[0], str) else row[0]
        for row in rows
    ]
def _store_local(corpus_dir: Path, compressed: CompressedSession) -> None:
    """Write one compressed session as JSON under the local memory directory."""
    target = corpus_dir / LOCAL_MEMORY_DIR / f"{compressed.session_id}.json"
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(compressed.model_dump(), ensure_ascii=False, indent=1)
    target.write_text(serialized, encoding="utf-8")
def _list_local(corpus_dir: Path, doc_id: str, limit: int) -> list[dict]:
    """Read compressed sessions for one document from local JSON files, newest first.

    Best-effort fallback used when PostgreSQL is unavailable: unreadable or
    malformed files are skipped rather than raised on.
    """
    memory_dir = corpus_dir / LOCAL_MEMORY_DIR
    if not memory_dir.exists():
        return []
    sessions = []
    for path in memory_dir.glob("*.json"):
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue  # skip corrupt/partial files; this path is best-effort
        # BUGFIX: guard against valid JSON that is not an object (e.g. a bare
        # list or string) — previously data.get() raised AttributeError outside
        # the try block, crashing the whole fallback listing.
        if isinstance(data, dict) and data.get("doc_id") == doc_id:
            sessions.append(data)
    # ISO-8601 UTC timestamps sort correctly as strings; newest first.
    sessions.sort(key=lambda item: item.get("timestamp", ""), reverse=True)
    return sessions[:limit]
def _write_raw_session(corpus_dir: Path, session_id: str, session_result: dict) -> str:
    """Archive the uncompressed session result as JSON; return the file path."""
    archive_dir = corpus_dir / RAW_ARCHIVE_DIR
    archive_dir.mkdir(parents=True, exist_ok=True)
    target = archive_dir / f"{session_id}.json"
    payload = json.dumps(session_result, ensure_ascii=False, indent=1)
    target.write_text(payload, encoding="utf-8")
    return str(target)
def _string_list(value: Any) -> list[str]:
if not isinstance(value, list):
return []
return [str(item) for item in value if str(item).strip()]
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat()