| |
| """ |
| pluto/pipeline.py - Orchestrator for document understanding and query answering. |
| |
| Phase A: document understanding, cached per document. |
| Phase B: query answering via S0 -> S1 -> S2 -> S3. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from pluto.bus import MessageBus |
| from pluto.extraction_cache import ExtractionCache |
| from pluto.models import ( |
| ClaimStatus, |
| FinalAnswer, |
| FinalEvidence, |
| FinalOutput, |
| Section, |
| TraceSummary, |
| ) |
| from pluto.modes import is_real_switching |
| from pluto.stages.extract import run_extract |
| from pluto.stages.merge import run_merge |
| from pluto.stages.route import run_route |
| from pluto.stages.evidence_check import run_evidence_check |
| from pluto.tools import CorpusTools |
| from pluto.tracer import Tracer |
|
|
|
|
class PipelineRunner:
    """Two-phase pipeline: understand documents, then answer queries.

    Phase A (``_ensure_docs_understood``) runs document understanding for every
    in-scope document the index reports as unprocessed. Phase B (``run``)
    answers a query by chaining route -> extract -> merge -> evidence check and
    assembling a ``FinalOutput`` from the stage results.
    """

    def __init__(
        self,
        corpus_dir: str,
        output_dir: str = "./output",
        doc_index=None,
        prior_session_context: list[dict] | None = None,
    ) -> None:
        """Wire up tracing, corpus tools, the extraction cache and the message bus.

        Args:
            corpus_dir: Corpus directory, handed to ``CorpusTools`` and
                ``ExtractionCache``.
            output_dir: Directory handed to ``CorpusTools`` (presumably where
                results are written -- confirm in ``pluto.tools``).
            doc_index: Optional document index. When falsy, Phase A and the
                per-document overview lookup are skipped.
            prior_session_context: Optional list of prior-session dicts whose
                ``key_findings`` / ``open_questions`` are prepended to the
                routing query.
        """
        self.tracer = Tracer()
        self.doc_index = doc_index
        self.prior_session_context = prior_session_context or []
        self.tools = CorpusTools(corpus_dir, output_dir, self.tracer, doc_index=doc_index)
        self.cache = ExtractionCache(corpus_dir)
        # Set through on_progress(); while None, progress events are dropped.
        self._progress_callback: Any = None
        self.bus = MessageBus()
        # Every bus message is forwarded to the progress callback as a "bus" event.
        self.bus.subscribe(self._handle_bus_message)

    def _handle_bus_message(self, sender: str, msg_type: str, payload: dict) -> None:
        """Re-emit a bus message as a ``"bus"`` progress event."""
        self._emit("bus", {"sender": sender, "type": msg_type, "payload": payload})

    def on_progress(self, callback) -> None:
        """Register a callback(stage, data) for live progress updates.

        Only one callback is kept; registering again replaces the previous one.
        """
        self._progress_callback = callback

    def _emit(self, stage: str, data: dict) -> None:
        """Call the registered progress callback with ``(stage, data)``, if any."""
        if self._progress_callback:
            self._progress_callback(stage, data)

    def run(
        self,
        query: str,
        selected_doc_ids: list[str] | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Execute the full pipeline for *query*.

        Args:
            query: The user question.
            selected_doc_ids: Optional doc ids restricting the scope; blank
                and duplicate ids are dropped.
            detail_level: ``"detailed"`` (case-insensitive, surrounding
                whitespace ignored); any other value means ``"standard"``.

        Returns:
            The assembled ``FinalOutput``; its ``model_dump()`` is also passed
            to ``self.tools.finish``.
        """
        selected_doc_ids = _normalize_selected_doc_ids(selected_doc_ids)
        detail_level = _normalize_detail_level(detail_level)

        self.tracer.log(
            "pipeline_start",
            {
                "query": query,
                "selected_doc_ids": selected_doc_ids,
                "detail_level": detail_level,
            },
        )

        # Phase A: make sure every in-scope document has been understood.
        self._ensure_docs_understood(selected_doc_ids=selected_doc_ids)

        # Only routing sees prior-session context; later stages get the raw query.
        route_query = _prepend_prior_session_context(query, self.prior_session_context)

        # Route: choose the documents and chunks to work on.
        self._emit("route", {"status": "running", "query": query})
        route_out = run_route(
            route_query,
            self.tools,
            self.tracer,
            bus=self.bus,
            selected_doc_ids=selected_doc_ids,
            detail_level=detail_level,
        )
        self._emit(
            "route",
            {
                "status": "complete",
                "docs": len(route_out.doc_scope),
                "chunks": len(route_out.chunk_plan),
            },
        )

        # Extract: pull claims out of each planned chunk (cache-backed).
        self._emit("extract", {"status": "running", "total_chunks": len(route_out.chunk_plan)})
        extractions = run_extract(
            route_out.chunk_plan,
            self.tools,
            self.tracer,
            query=query,
            cache=self.cache,
        )
        # NOTE(review): if ExtractionCache.stats() is cumulative over the
        # runner's lifetime, these hit/miss counts include earlier runs -- confirm.
        cache_stats = self.cache.stats()
        self._emit(
            "extract",
            {
                "status": "complete",
                "extractions": len(extractions),
                "total_claims": sum(len(ext.extracted.claims) for ext in extractions),
                "cache_hits": cache_stats["hits"],
                "cache_misses": cache_stats["misses"],
            },
        )

        overview = self._collect_overview(route_out.doc_scope)

        # Merge: synthesize extractions into an answer outline and key claims.
        self._emit("merge", {"status": "running"})
        merge_out = run_merge(
            query,
            extractions,
            self.tracer,
            bus=self.bus,
            overview=overview,
            detail_level=detail_level,
        )
        self._emit(
            "merge",
            {
                "status": "complete",
                "sections": len(merge_out.synthesis.answer_outline),
                "key_claims": len(merge_out.synthesis.key_claims),
            },
        )

        # Evidence check: grade the merged claims against the extractions.
        self._emit("evidence_check", {"status": "running"})
        evidence_check_out = run_evidence_check(merge_out, extractions, self.tracer, bus=self.bus)
        self._emit(
            "evidence_check",
            {
                "status": "complete",
                "checked": len(evidence_check_out.evidence_check.checked_claims),
                "unsupported": len(evidence_check_out.evidence_check.unsupported_claims),
                "gaps": len(evidence_check_out.evidence_check.required_followups),
            },
        )

        final = self._build_final(
            query,
            merge_out,
            evidence_check_out,
            extractions,
            overview=overview,
            bus=self.bus,
            detail_level=detail_level,
        )

        self.tools.finish(final.model_dump())
        self._emit("finish", {"status": "complete", "confidence": final.confidence})
        # NOTE(review): the tracer is created once per runner, so elapsed() is
        # presumably measured from construction, not from this run -- confirm.
        self.tracer.log("pipeline_complete", {"elapsed_s": self.tracer.elapsed()})
        return final

    def _collect_overview(self, doc_scope) -> str:
        """Join the non-empty index overviews of the routed documents.

        Returns an empty string when no document index is configured.
        """
        if not self.doc_index:
            return ""
        overviews = []
        for scoped_doc in doc_scope:
            doc_overview = self.doc_index.get_overview(scoped_doc.doc_id)
            if doc_overview:
                overviews.append(doc_overview)
        return "\n\n".join(overviews)

    def _ensure_docs_understood(self, selected_doc_ids: list[str] | None = None) -> None:
        """Run Phase A for unprocessed docs in scope.

        With no document index this is a no-op. With an empty selection, every
        indexed document is in scope; otherwise only the selected ids are.
        Documents whose ``is_processed`` flag is set are skipped.
        """
        if not self.doc_index:
            return

        # Imported lazily; Phase A is only needed when a doc index is present.
        from pluto.stages.understand import run_understand

        selected_doc_set = set(_normalize_selected_doc_ids(selected_doc_ids))
        for doc_info in self.doc_index.list_docs():
            doc_id = doc_info["doc_id"]
            if selected_doc_set and doc_id not in selected_doc_set:
                continue
            if doc_info["is_processed"]:
                continue

            self._emit("understand", {"status": "running", "doc_id": doc_id})
            run_understand(doc_id, self.doc_index, self.tracer)
            self._emit("understand", {"status": "complete", "doc_id": doc_id})

    def _build_final(
        self,
        query,
        merge_out,
        evidence_check_out,
        extractions,
        overview="",
        bus: MessageBus | None = None,
        detail_level: str = "standard",
    ) -> FinalOutput:
        """Assemble the FinalOutput from stage results.

        ``query`` and ``overview`` are accepted for interface compatibility but
        are not used. ``bus``, when given, has its messages dumped into the
        output.
        """
        del query, overview
        detail_level = _normalize_detail_level(detail_level)

        sections = self._build_sections(merge_out.synthesis.answer_outline)

        checked_claims = evidence_check_out.evidence_check.checked_claims
        # Classify checked claims once: the supported list feeds both the
        # fallback response text and the confidence score.
        supported_claims = [
            checked for checked in checked_claims if checked.status == ClaimStatus.SUPPORTED
        ]
        uncertain_count = sum(
            1 for checked in checked_claims if checked.status == ClaimStatus.UNCERTAIN
        )

        response = self._compose_response(sections, supported_claims)
        evidence = self._collect_evidence(extractions)
        confidence = self._score_confidence(
            supported=len(supported_claims),
            uncertain=uncertain_count,
            total=len(checked_claims),
            has_sections=bool(sections),
            has_evidence=bool(evidence),
        )

        trace = TraceSummary(
            real_switching=is_real_switching(),
            modes_used_counts=dict(self.tracer.modes_used),
            models_used=sorted(self.tracer.models_used),
            docs_opened=sorted(self.tracer.docs_opened),
            chunks_processed=self.tracer.chunks_processed,
            search_queries=self.tracer.search_queries,
            budget_notes=f"Within limits ({detail_level} mode)",
        )

        bus_messages = bus.dump() if bus else []
        return FinalOutput(
            final_answer=FinalAnswer(response=response, sections=sections),
            evidence=evidence,
            trace_summary=trace,
            confidence=confidence,
            missing_info=merge_out.synthesis.open_gaps + evidence_check_out.evidence_check.required_followups,
            next_actions=evidence_check_out.evidence_check.required_followups,
            bus_messages=bus_messages,
        )

    @staticmethod
    def _build_sections(answer_outline) -> list[Section]:
        """Turn each outline entry into a Section whose content is a bullet list of its points."""
        sections: list[Section] = []
        for outline_entry in answer_outline:
            points = outline_entry.points
            # A falsy ``points`` (empty or None) yields empty content.
            content = "\n".join(f"• {point}" for point in points) if points else ""
            sections.append(Section(title=outline_entry.section, content=content))
        return sections

    @staticmethod
    def _compose_response(sections, supported_claims) -> str:
        """Build the answer text: bold-titled sections, else supported claims, else a fallback."""
        section_parts = [f"**{section.title}**\n{section.content}" for section in sections if section.content]
        if section_parts:
            return "\n\n".join(section_parts)
        claim_parts = [checked.claim for checked in supported_claims]
        if claim_parts:
            return " ".join(claim_parts)
        return "No answer could be generated from the provided documents."

    @staticmethod
    def _collect_evidence(extractions) -> list[FinalEvidence]:
        """Collect a FinalEvidence entry for every extracted claim that carries evidence."""
        evidence: list[FinalEvidence] = []
        for extraction in extractions:
            for claim in extraction.extracted.claims:
                source = claim.evidence
                if not source:
                    continue
                evidence.append(
                    FinalEvidence(
                        doc_id=source.doc_id,
                        chunk_id=source.chunk_id,
                        where=source.where,
                        supports=claim.text,
                        quote=source.quote,
                    )
                )
        return evidence

    @staticmethod
    def _score_confidence(
        *,
        supported: int,
        uncertain: int,
        total: int,
        has_sections: bool,
        has_evidence: bool,
    ) -> float:
        """Score answer confidence from evidence-check tallies.

        Supported claims count fully and uncertain ones count half, relative to
        all checked claims (rounded to 2 places). A zero score is raised to
        0.35 when sections and evidence exist anyway. With no checked claims
        the score is 0.6 if there are sections, else 0.0.
        """
        if total > 0:
            score = round((supported + (0.5 * uncertain)) / total, 2)
            if score == 0.0 and has_sections and has_evidence:
                return 0.35
            return score
        return 0.6 if has_sections else 0.0
|
|
|
|
| def _normalize_selected_doc_ids(selected_doc_ids: list[str] | None) -> list[str]: |
| seen: set[str] = set() |
| normalized: list[str] = [] |
| for raw_doc_id in selected_doc_ids or []: |
| doc_id = str(raw_doc_id or "").strip() |
| if not doc_id or doc_id in seen: |
| continue |
| seen.add(doc_id) |
| normalized.append(doc_id) |
| return normalized |
|
|
|
|
| def _normalize_detail_level(detail_level: str | None) -> str: |
| return "detailed" if str(detail_level or "").strip().lower() == "detailed" else "standard" |
|
|
|
|
| def _prepend_prior_session_context(query: str, prior_session_context: list[dict]) -> str: |
| key_findings: list[str] = [] |
| open_questions: list[str] = [] |
| for session in prior_session_context or []: |
| key_findings.extend(str(item) for item in session.get("key_findings", []) if str(item).strip()) |
| open_questions.extend(str(item) for item in session.get("open_questions", []) if str(item).strip()) |
|
|
| if not key_findings and not open_questions: |
| return query |
|
|
| findings_block = "\n".join(f"- {finding}" for finding in key_findings[:10]) |
| questions_block = "\n".join(f"- {question}" for question in open_questions[:10]) |
| return ( |
| "[Prior session findings for this document:\n" |
| f"{findings_block}\n" |
| "Open questions from prior sessions:\n" |
| f"{questions_block}]\n\n" |
| f"{query}" |
| ) |
|
|