# -*- coding: utf-8 -*-
"""
pluto/pipeline.py - Orchestrator for document understanding and query answering.

Phase A: document understanding, cached per document.
Phase B: query answering via S0 -> S1 -> S2 -> S3.
"""

from __future__ import annotations

from typing import Any

from pluto.bus import MessageBus
from pluto.extraction_cache import ExtractionCache
from pluto.models import (
    ClaimStatus,
    FinalAnswer,
    FinalEvidence,
    FinalOutput,
    Section,
    TraceSummary,
)
from pluto.modes import is_real_switching
from pluto.stages.extract import run_extract
from pluto.stages.merge import run_merge
from pluto.stages.route import run_route
from pluto.stages.evidence_check import run_evidence_check
from pluto.tools import CorpusTools
from pluto.tracer import Tracer

# Upper bound on how many prior-session findings / open questions are
# prepended to the routing query.
_MAX_PRIOR_SESSION_ITEMS = 10


class PipelineRunner:
    """Two-phase pipeline: understand documents, then answer queries."""

    def __init__(
        self,
        corpus_dir: str,
        output_dir: str = "./output",
        doc_index=None,
        prior_session_context: list[dict] | None = None,
    ) -> None:
        """Wire up the tracer, corpus tools, extraction cache and message bus.

        Args:
            corpus_dir: Passed to ``CorpusTools`` and ``ExtractionCache``.
            output_dir: Passed to ``CorpusTools``.
            doc_index: Optional document index. When falsy, Phase A and the
                per-document overview lookup are skipped.
            prior_session_context: Optional list of session dicts; their
                ``key_findings`` / ``open_questions`` entries are prepended
                to the routing query. ``None`` is treated as an empty list.
        """
        self.tracer = Tracer()
        self.doc_index = doc_index
        self.prior_session_context = prior_session_context or []
        self.tools = CorpusTools(corpus_dir, output_dir, self.tracer, doc_index=doc_index)
        self.cache = ExtractionCache(corpus_dir)
        self._progress_callback: Any = None
        self.bus = MessageBus()
        # Forward every bus message to the progress callback as a "bus" event.
        self.bus.subscribe(self._handle_bus_message)

    def _handle_bus_message(self, sender: str, msg_type: str, payload: dict) -> None:
        """Re-emit a bus message as a ``"bus"`` progress event."""
        self._emit("bus", {"sender": sender, "type": msg_type, "payload": payload})

    def on_progress(self, callback) -> None:
        """Register a callback(stage, data) for live progress updates.

        Only one callback is kept; registering again replaces the previous one.
        """
        self._progress_callback = callback

    def _emit(self, stage: str, data: dict) -> None:
        """Invoke the progress callback, if one is registered."""
        if self._progress_callback:
            self._progress_callback(stage, data)

    def run(
        self,
        query: str,
        selected_doc_ids: list[str] | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Execute the full pipeline for *query*.

        Runs document understanding for the documents in scope, then the
        route, extract, merge and evidence-check stages in order, emitting
        ``running`` / ``complete`` progress events around each one.

        Args:
            query: The user question.
            selected_doc_ids: Optional document IDs to restrict the run to;
                normalized with ``_normalize_selected_doc_ids``.
            detail_level: ``"detailed"`` or ``"standard"``; anything else is
                normalized to ``"standard"``.

        Returns:
            The assembled ``FinalOutput``, which is also handed to
            ``self.tools.finish`` as a dumped model.
        """
        selected_doc_ids = _normalize_selected_doc_ids(selected_doc_ids)
        detail_level = _normalize_detail_level(detail_level)

        self.tracer.log(
            "pipeline_start",
            {
                "query": query,
                "selected_doc_ids": selected_doc_ids,
                "detail_level": detail_level,
            },
        )

        self._ensure_docs_understood(selected_doc_ids=selected_doc_ids)

        # Only the routing stage sees the prior-session context; the later
        # stages receive the raw query.
        route_query = _prepend_prior_session_context(query, self.prior_session_context)

        self._emit("route", {"status": "running", "query": query})
        route_out = run_route(
            route_query,
            self.tools,
            self.tracer,
            bus=self.bus,
            selected_doc_ids=selected_doc_ids,
            detail_level=detail_level,
        )
        self._emit(
            "route",
            {
                "status": "complete",
                "docs": len(route_out.doc_scope),
                "chunks": len(route_out.chunk_plan),
            },
        )

        self._emit("extract", {"status": "running", "total_chunks": len(route_out.chunk_plan)})
        extractions = run_extract(
            route_out.chunk_plan,
            self.tools,
            self.tracer,
            query=query,
            cache=self.cache,
        )
        cache_stats = self.cache.stats()
        self._emit(
            "extract",
            {
                "status": "complete",
                "extractions": len(extractions),
                "total_claims": sum(len(ext.extracted.claims) for ext in extractions),
                "cache_hits": cache_stats["hits"],
                "cache_misses": cache_stats["misses"],
            },
        )

        # Concatenate the overviews of every routed document (when an index
        # is available) to give the merge stage document-level context.
        overview = ""
        if self.doc_index:
            overviews = []
            for doc_scope in route_out.doc_scope:
                doc_overview = self.doc_index.get_overview(doc_scope.doc_id)
                if doc_overview:
                    overviews.append(doc_overview)
            overview = "\n\n".join(overviews)

        self._emit("merge", {"status": "running"})
        merge_out = run_merge(
            query,
            extractions,
            self.tracer,
            bus=self.bus,
            overview=overview,
            detail_level=detail_level,
        )
        self._emit(
            "merge",
            {
                "status": "complete",
                "sections": len(merge_out.synthesis.answer_outline),
                "key_claims": len(merge_out.synthesis.key_claims),
            },
        )

        self._emit("evidence_check", {"status": "running"})
        evidence_check_out = run_evidence_check(merge_out, extractions, self.tracer, bus=self.bus)
        self._emit(
            "evidence_check",
            {
                "status": "complete",
                "checked": len(evidence_check_out.evidence_check.checked_claims),
                "unsupported": len(evidence_check_out.evidence_check.unsupported_claims),
                "gaps": len(evidence_check_out.evidence_check.required_followups),
            },
        )

        final = self._build_final(
            query,
            merge_out,
            evidence_check_out,
            extractions,
            overview=overview,
            bus=self.bus,
            detail_level=detail_level,
        )

        self.tools.finish(final.model_dump())
        self._emit("finish", {"status": "complete", "confidence": final.confidence})
        self.tracer.log("pipeline_complete", {"elapsed_s": self.tracer.elapsed()})
        return final

    def _ensure_docs_understood(self, selected_doc_ids: list[str] | None = None) -> None:
        """Run Phase A for unprocessed docs in scope.

        Does nothing when no document index is configured. When
        *selected_doc_ids* is empty or ``None``, every unprocessed document
        in the index is in scope.
        """
        if not self.doc_index:
            return

        # Imported here rather than at module level, as in the original.
        from pluto.stages.understand import run_understand

        selected_doc_set = set(_normalize_selected_doc_ids(selected_doc_ids))
        for doc_info in self.doc_index.list_docs():
            doc_id = doc_info["doc_id"]
            if selected_doc_set and doc_id not in selected_doc_set:
                continue
            if doc_info["is_processed"]:
                continue

            self._emit("understand", {"status": "running", "doc_id": doc_id})
            run_understand(doc_id, self.doc_index, self.tracer)
            self._emit("understand", {"status": "complete", "doc_id": doc_id})

    def _build_final(
        self,
        query: str,
        merge_out,
        evidence_check_out,
        extractions,
        overview: str = "",
        bus: MessageBus | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Assemble the FinalOutput from stage results.

        The response text is the rendered outline sections when any have
        content, otherwise the supported claims joined by spaces, otherwise
        a fixed "no answer" message.

        Confidence is ``(supported + 0.5 * uncertain) / total`` over the
        checked claims, rounded to two decimals. A computed ``0.0`` is
        lifted to ``0.35`` when both sections and evidence exist. With no
        checked claims it is ``0.6`` if there are sections, else ``0.0``.
        """
        # Accepted for call-site compatibility; neither value is used here.
        del query, overview
        detail_level = _normalize_detail_level(detail_level)

        sections: list[Section] = []
        for section_point in merge_out.synthesis.answer_outline:
            content = (
                "\n".join(f"• {point}" for point in section_point.points)
                if section_point.points
                else ""
            )
            sections.append(Section(title=section_point.section, content=content))

        section_parts = [
            f"**{section.title}**\n{section.content}" for section in sections if section.content
        ]

        checked_claims = evidence_check_out.evidence_check.checked_claims
        supported_checked_claims = [
            checked for checked in checked_claims if checked.status == ClaimStatus.SUPPORTED
        ]
        claim_parts = [checked.claim for checked in supported_checked_claims]

        if section_parts:
            response = "\n\n".join(section_parts)
        elif claim_parts:
            response = " ".join(claim_parts)
        else:
            response = "No answer could be generated from the provided documents."

        # Every extracted claim that carries evidence is reported, whatever
        # status the evidence check assigned to it.
        evidence: list[FinalEvidence] = []
        for extraction in extractions:
            for claim in extraction.extracted.claims:
                if not claim.evidence:
                    continue
                evidence.append(
                    FinalEvidence(
                        doc_id=claim.evidence.doc_id,
                        chunk_id=claim.evidence.chunk_id,
                        where=claim.evidence.where,
                        supports=claim.text,
                        quote=claim.evidence.quote,
                    )
                )

        total = len(checked_claims)
        # Reuse the list built above instead of scanning for SUPPORTED again.
        supported = len(supported_checked_claims)
        uncertain = sum(1 for checked in checked_claims if checked.status == ClaimStatus.UNCERTAIN)

        if total > 0:
            confidence = round((supported + (0.5 * uncertain)) / total, 2)
            if confidence == 0.0 and sections and evidence:
                confidence = 0.35
        elif sections:
            confidence = 0.6
        else:
            confidence = 0.0

        trace = TraceSummary(
            real_switching=is_real_switching(),
            modes_used_counts=dict(self.tracer.modes_used),
            models_used=sorted(self.tracer.models_used),
            docs_opened=sorted(self.tracer.docs_opened),
            chunks_processed=self.tracer.chunks_processed,
            search_queries=self.tracer.search_queries,
            budget_notes=f"Within limits ({detail_level} mode)",
        )

        bus_messages = bus.dump() if bus else []

        return FinalOutput(
            final_answer=FinalAnswer(response=response, sections=sections),
            evidence=evidence,
            trace_summary=trace,
            confidence=confidence,
            missing_info=merge_out.synthesis.open_gaps
            + evidence_check_out.evidence_check.required_followups,
            next_actions=evidence_check_out.evidence_check.required_followups,
            bus_messages=bus_messages,
        )


def _normalize_selected_doc_ids(selected_doc_ids: list[str] | None) -> list[str]:
    """Return the stripped, de-duplicated document IDs in first-seen order.

    ``None`` and empty input yield ``[]``; falsy or blank entries are dropped.
    A bare string is treated as a single ID rather than being iterated
    character by character.
    """
    if isinstance(selected_doc_ids, str):
        selected_doc_ids = [selected_doc_ids]

    seen: set[str] = set()
    normalized: list[str] = []
    for raw_doc_id in selected_doc_ids or []:
        doc_id = str(raw_doc_id or "").strip()
        if not doc_id or doc_id in seen:
            continue
        seen.add(doc_id)
        normalized.append(doc_id)
    return normalized


def _normalize_detail_level(detail_level: str | None) -> str:
    """Return ``"detailed"`` for that value (case/whitespace-insensitive), else ``"standard"``."""
    return "detailed" if str(detail_level or "").strip().lower() == "detailed" else "standard"


def _collect_session_items(prior_session_context: list[dict] | None, key: str) -> list[str]:
    """Gather the non-blank string entries stored under *key* across all sessions."""
    items: list[str] = []
    for session in prior_session_context or []:
        # ``dict.get(key, [])`` still returns None when the key is present
        # with a None value, so fall back explicitly.
        for item in session.get(key) or []:
            if item is None:
                continue
            text = str(item)
            if text.strip():
                items.append(text)
    return items


def _prepend_prior_session_context(query: str, prior_session_context: list[dict]) -> str:
    """Prefix *query* with findings and open questions from prior sessions.

    Returns *query* unchanged when the sessions contribute neither findings
    nor open questions. At most ``_MAX_PRIOR_SESSION_ITEMS`` entries of each
    kind are included. ``None`` values and ``None`` list entries are ignored.
    """
    key_findings = _collect_session_items(prior_session_context, "key_findings")
    open_questions = _collect_session_items(prior_session_context, "open_questions")

    if not key_findings and not open_questions:
        return query

    findings_block = "\n".join(
        f"- {finding}" for finding in key_findings[:_MAX_PRIOR_SESSION_ITEMS]
    )
    questions_block = "\n".join(
        f"- {question}" for question in open_questions[:_MAX_PRIOR_SESSION_ITEMS]
    )
    return (
        "[Prior session findings for this document:\n"
        f"{findings_block}\n"
        "Open questions from prior sessions:\n"
        f"{questions_block}]\n\n"
        f"{query}"
    )