# ayushKishor's picture
# Add Pluto memory layer and pipeline fixes
# 23cdeed
# -*- coding: utf-8 -*-
"""
pluto/pipeline.py - Orchestrator for document understanding and query answering.
Phase A: document understanding, cached per document.
Phase B: query answering via S0 -> S1 -> S2 -> S3.
"""
from __future__ import annotations
from typing import Any
from pluto.bus import MessageBus
from pluto.extraction_cache import ExtractionCache
from pluto.models import (
ClaimStatus,
FinalAnswer,
FinalEvidence,
FinalOutput,
Section,
TraceSummary,
)
from pluto.modes import is_real_switching
from pluto.stages.extract import run_extract
from pluto.stages.merge import run_merge
from pluto.stages.route import run_route
from pluto.stages.evidence_check import run_evidence_check
from pluto.tools import CorpusTools
from pluto.tracer import Tracer
class PipelineRunner:
    """Two-phase pipeline: understand documents, then answer queries.

    Phase A (``_ensure_docs_understood``) calls ``run_understand`` for each
    in-scope document that the index reports as not yet processed.  Phase B
    (``run``) chains route -> extract -> merge -> evidence check and
    assembles a ``FinalOutput``.  Stage progress and message-bus traffic are
    forwarded to an optional callback registered through ``on_progress``.
    """

    def __init__(
        self,
        corpus_dir: str,
        output_dir: str = "./output",
        doc_index=None,
        prior_session_context: list[dict] | None = None,
    ) -> None:
        """Create the tracer, corpus tools, extraction cache and message bus.

        Args:
            corpus_dir: Forwarded to ``CorpusTools`` and ``ExtractionCache``.
            output_dir: Forwarded to ``CorpusTools``.
            doc_index: Optional document index.  When falsy, Phase A and the
                overview lookup are skipped.
            prior_session_context: Optional prior-session dicts that are
                prepended to the routing query; ``None`` is stored as ``[]``.
        """
        self.tracer = Tracer()
        self.doc_index = doc_index
        self.prior_session_context = prior_session_context or []
        self.tools = CorpusTools(corpus_dir, output_dir, self.tracer, doc_index=doc_index)
        self.cache = ExtractionCache(corpus_dir)
        self._progress_callback: Any = None
        self.bus = MessageBus()
        # Every bus message is mirrored to the progress callback as a "bus" event.
        self.bus.subscribe(self._handle_bus_message)

    def _handle_bus_message(self, sender: str, msg_type: str, payload: dict) -> None:
        """Forward one bus message to the progress callback as a ``"bus"`` event."""
        self._emit("bus", {"sender": sender, "type": msg_type, "payload": payload})

    def on_progress(self, callback) -> None:
        """Register a callback(stage, data) for live progress updates."""
        self._progress_callback = callback

    def _emit(self, stage: str, data: dict) -> None:
        """Invoke the progress callback with ``(stage, data)`` if one is set."""
        if self._progress_callback:
            self._progress_callback(stage, data)

    def run(
        self,
        query: str,
        selected_doc_ids: list[str] | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Execute the full pipeline for *query*.

        Args:
            query: The user question.  The routing stage receives it with the
                prior-session context prepended; extraction and merging
                receive it unchanged.
            selected_doc_ids: Optional document IDs limiting the scope; the
                list is normalised before use.
            detail_level: Normalised to ``"detailed"`` or ``"standard"``.

        Returns:
            The assembled ``FinalOutput``.  Its ``model_dump()`` is also
            handed to ``self.tools.finish`` before returning.
        """
        selected_doc_ids = _normalize_selected_doc_ids(selected_doc_ids)
        detail_level = _normalize_detail_level(detail_level)
        self.tracer.log(
            "pipeline_start",
            {
                "query": query,
                "selected_doc_ids": selected_doc_ids,
                "detail_level": detail_level,
            },
        )

        # Phase A: make sure every in-scope document has been understood.
        self._ensure_docs_understood(selected_doc_ids=selected_doc_ids)

        # Routing sees the query enriched with prior-session context; the
        # progress event still reports the raw query.
        route_query = _prepend_prior_session_context(query, self.prior_session_context)
        self._emit("route", {"status": "running", "query": query})
        route_out = run_route(
            route_query,
            self.tools,
            self.tracer,
            bus=self.bus,
            selected_doc_ids=selected_doc_ids,
            detail_level=detail_level,
        )
        self._emit(
            "route",
            {
                "status": "complete",
                "docs": len(route_out.doc_scope),
                "chunks": len(route_out.chunk_plan),
            },
        )

        self._emit("extract", {"status": "running", "total_chunks": len(route_out.chunk_plan)})
        extractions = run_extract(
            route_out.chunk_plan,
            self.tools,
            self.tracer,
            query=query,
            cache=self.cache,
        )
        cache_stats = self.cache.stats()
        self._emit(
            "extract",
            {
                "status": "complete",
                "extractions": len(extractions),
                "total_claims": sum(len(ext.extracted.claims) for ext in extractions),
                "cache_hits": cache_stats["hits"],
                "cache_misses": cache_stats["misses"],
            },
        )

        overview = self._collect_overview(route_out.doc_scope)

        self._emit("merge", {"status": "running"})
        merge_out = run_merge(
            query,
            extractions,
            self.tracer,
            bus=self.bus,
            overview=overview,
            detail_level=detail_level,
        )
        self._emit(
            "merge",
            {
                "status": "complete",
                "sections": len(merge_out.synthesis.answer_outline),
                "key_claims": len(merge_out.synthesis.key_claims),
            },
        )

        self._emit("evidence_check", {"status": "running"})
        evidence_check_out = run_evidence_check(merge_out, extractions, self.tracer, bus=self.bus)
        self._emit(
            "evidence_check",
            {
                "status": "complete",
                "checked": len(evidence_check_out.evidence_check.checked_claims),
                "unsupported": len(evidence_check_out.evidence_check.unsupported_claims),
                "gaps": len(evidence_check_out.evidence_check.required_followups),
            },
        )

        final = self._build_final(
            query,
            merge_out,
            evidence_check_out,
            extractions,
            overview=overview,
            bus=self.bus,
            detail_level=detail_level,
        )
        self.tools.finish(final.model_dump())
        self._emit("finish", {"status": "complete", "confidence": final.confidence})
        self.tracer.log("pipeline_complete", {"elapsed_s": self.tracer.elapsed()})
        return final

    def _collect_overview(self, doc_scope) -> str:
        """Join the non-empty index overviews of the routed documents.

        Returns ``""`` when no document index is configured; otherwise the
        overviews are joined with a blank line, in ``doc_scope`` order.
        """
        if not self.doc_index:
            return ""
        overviews = (self.doc_index.get_overview(scope.doc_id) for scope in doc_scope)
        return "\n\n".join(text for text in overviews if text)

    def _ensure_docs_understood(self, selected_doc_ids: list[str] | None = None) -> None:
        """Run Phase A for unprocessed docs in scope.

        Does nothing without a document index.  An empty selection means
        every document listed by the index is in scope.
        """
        if not self.doc_index:
            return
        # NOTE(review): imported lazily here rather than at module top --
        # presumably to avoid an import cycle or start-up cost; confirm.
        from pluto.stages.understand import run_understand

        selected_doc_set = set(_normalize_selected_doc_ids(selected_doc_ids))
        for doc_info in self.doc_index.list_docs():
            doc_id = doc_info["doc_id"]
            if selected_doc_set and doc_id not in selected_doc_set:
                continue
            if doc_info["is_processed"]:
                continue
            self._emit("understand", {"status": "running", "doc_id": doc_id})
            run_understand(doc_id, self.doc_index, self.tracer)
            self._emit("understand", {"status": "complete", "doc_id": doc_id})

    @staticmethod
    def _score_confidence(
        supported: int,
        uncertain: int,
        total: int,
        has_sections: bool,
        has_evidence: bool,
    ) -> float:
        """Turn evidence-check tallies into a confidence score.

        With checked claims, supported ones weigh 1 and uncertain ones 0.5;
        the sum is divided by *total* and rounded to two decimals.  A score
        of exactly 0 is lifted to 0.35 when the answer still has sections
        and evidence.  Without checked claims the score is 0.6 if there are
        sections and 0.0 otherwise.
        """
        if total > 0:
            confidence = round((supported + (0.5 * uncertain)) / total, 2)
            if confidence == 0.0 and has_sections and has_evidence:
                confidence = 0.35
            return confidence
        return 0.6 if has_sections else 0.0

    def _build_final(
        self,
        query,
        merge_out,
        evidence_check_out,
        extractions,
        overview="",
        bus: MessageBus | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Assemble the FinalOutput from stage results.

        Args:
            query: Unused; accepted so existing call sites keep working.
            merge_out: Merge-stage result; supplies the answer outline and
                the open gaps.
            evidence_check_out: Evidence-check result; supplies the checked
                claims and the required follow-ups.
            extractions: Extraction results; every claim carrying evidence
                becomes one ``FinalEvidence`` entry.
            overview: Unused; accepted so existing call sites keep working.
            bus: When truthy, its ``dump()`` is attached as ``bus_messages``.
            detail_level: Normalised and echoed in the trace's budget note.

        Returns:
            The ``FinalOutput``.  The response text is the non-empty sections
            rendered as ``**title**`` plus bullet list; failing that, the
            supported claims joined by spaces; failing that, a fixed
            "no answer" message.
        """
        del query, overview  # intentionally unused, kept for the signature
        detail_level = _normalize_detail_level(detail_level)
        check = evidence_check_out.evidence_check
        checked_claims = check.checked_claims

        sections: list[Section] = []
        for section_point in merge_out.synthesis.answer_outline:
            # The guard also covers a falsy ``points`` value such as None.
            content = (
                "\n".join(f"• {point}" for point in section_point.points)
                if section_point.points
                else ""
            )
            sections.append(Section(title=section_point.section, content=content))
        section_parts = [
            f"**{section.title}**\n{section.content}" for section in sections if section.content
        ]

        supported_checked_claims = [
            checked for checked in checked_claims if checked.status == ClaimStatus.SUPPORTED
        ]
        if section_parts:
            response = "\n\n".join(section_parts)
        elif supported_checked_claims:
            response = " ".join(checked.claim for checked in supported_checked_claims)
        else:
            response = "No answer could be generated from the provided documents."

        evidence: list[FinalEvidence] = [
            FinalEvidence(
                doc_id=claim.evidence.doc_id,
                chunk_id=claim.evidence.chunk_id,
                where=claim.evidence.where,
                supports=claim.text,
                quote=claim.evidence.quote,
            )
            for extraction in extractions
            for claim in extraction.extracted.claims
            if claim.evidence
        ]

        # The supported tally reuses the list built above instead of walking
        # the checked claims a second time.
        uncertain = sum(1 for checked in checked_claims if checked.status == ClaimStatus.UNCERTAIN)
        confidence = self._score_confidence(
            supported=len(supported_checked_claims),
            uncertain=uncertain,
            total=len(checked_claims),
            has_sections=bool(sections),
            has_evidence=bool(evidence),
        )

        trace = TraceSummary(
            real_switching=is_real_switching(),
            modes_used_counts=dict(self.tracer.modes_used),
            models_used=sorted(self.tracer.models_used),
            docs_opened=sorted(self.tracer.docs_opened),
            chunks_processed=self.tracer.chunks_processed,
            search_queries=self.tracer.search_queries,
            budget_notes=f"Within limits ({detail_level} mode)",
        )
        bus_messages = bus.dump() if bus else []
        return FinalOutput(
            final_answer=FinalAnswer(response=response, sections=sections),
            evidence=evidence,
            trace_summary=trace,
            confidence=confidence,
            missing_info=merge_out.synthesis.open_gaps + check.required_followups,
            next_actions=check.required_followups,
            bus_messages=bus_messages,
        )
def _normalize_selected_doc_ids(selected_doc_ids: list[str] | None) -> list[str]:
seen: set[str] = set()
normalized: list[str] = []
for raw_doc_id in selected_doc_ids or []:
doc_id = str(raw_doc_id or "").strip()
if not doc_id or doc_id in seen:
continue
seen.add(doc_id)
normalized.append(doc_id)
return normalized
def _normalize_detail_level(detail_level: str | None) -> str:
return "detailed" if str(detail_level or "").strip().lower() == "detailed" else "standard"
def _prepend_prior_session_context(query: str, prior_session_context: list[dict]) -> str:
key_findings: list[str] = []
open_questions: list[str] = []
for session in prior_session_context or []:
key_findings.extend(str(item) for item in session.get("key_findings", []) if str(item).strip())
open_questions.extend(str(item) for item in session.get("open_questions", []) if str(item).strip())
if not key_findings and not open_questions:
return query
findings_block = "\n".join(f"- {finding}" for finding in key_findings[:10])
questions_block = "\n".join(f"- {question}" for question in open_questions[:10])
return (
"[Prior session findings for this document:\n"
f"{findings_block}\n"
"Open questions from prior sessions:\n"
f"{questions_block}]\n\n"
f"{query}"
)