ayushKishor's picture
Add Pluto memory layer and pipeline fixes
23cdeed
# -*- coding: utf-8 -*-
"""
pluto/models.py β€” Pydantic schemas for all 4 pipeline stages + final output.
Every model matches the spec JSON schemas exactly (Β§6-Β§7).
"""
from __future__ import annotations
import hashlib
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
# ── Enums ──────────────────────────────────────────────────────────────────────
class ChunkType(str, Enum):
TEXT = "text"
MATH = "math"
TABLE = "table"
FIGURE = "figure"
CODE = "code"
REFERENCES = "references"
NOISE = "noise"
class ModeName(str, Enum):
MODE_QUICK = "MODE_QUICK"
MODE_REASONING = "MODE_REASONING"
MODE_VISION = "MODE_VISION"
class SupportType(str, Enum):
EXPLICIT = "explicit"
IMPLICIT = "implicit"
INFERRED = "inferred"
class ClaimStatus(str, Enum):
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
UNCERTAIN = "uncertain"
class Importance(str, Enum):
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
class Priority(str, Enum):
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
# ── Evidence ───────────────────────────────────────────────────────────────────
class Evidence(BaseModel):
doc_id: str
chunk_id: str
where: str = ""
quote: str = Field(default="", max_length=200)
# ── S0 ROUTE ───────────────────────────────────────────────────────────────────
class DocScope(BaseModel):
doc_id: str
reason: str
class ChunkPlan(BaseModel):
doc_id: str
chunk_id: str
where: str = ""
chunk_type: ChunkType
mode: ModeName
priority: Priority = Priority.MEDIUM
task: str = ""
class Budgets(BaseModel):
max_chunks_to_read: int = 200
max_extractions: int = 25
class RouteOutput(BaseModel):
stage: str = "route"
user_query: str
doc_scope: list[DocScope] = Field(default_factory=list)
chunk_plan: list[ChunkPlan] = Field(default_factory=list)
budgets: Budgets = Field(default_factory=Budgets)
# ── S1 EXTRACT ─────────────────────────────────────────────────────────────────
class Claim(BaseModel):
claim_id: str
text: str
importance: Importance = Importance.MEDIUM
support_type: SupportType = SupportType.EXPLICIT
numbers: list[str] = Field(default_factory=list)
entities: list[str] = Field(default_factory=list)
dependencies: list[str] = Field(default_factory=list)
evidence: Evidence | None = None
class MathItem(BaseModel):
expression: str
interpretation: str = ""
evidence: Evidence | None = None
class TableItem(BaseModel):
caption: str = ""
headers: list[str] = Field(default_factory=list)
rows: list[list[str]] = Field(default_factory=list)
evidence: Evidence | None = None
class FigureItem(BaseModel):
caption: str = ""
description: str = ""
evidence: Evidence | None = None
class CodeItem(BaseModel):
language: str = ""
snippet: str = ""
description: str = ""
evidence: Evidence | None = None
class ExtractedContent(BaseModel):
claims: list[Claim] = Field(default_factory=list)
definitions: list[dict] = Field(default_factory=list)
math: list[MathItem] = Field(default_factory=list)
table: list[TableItem] = Field(default_factory=list)
figure: list[FigureItem] = Field(default_factory=list)
code: list[CodeItem] = Field(default_factory=list)
chunk_summary: str = ""
class ExtractOutput(BaseModel):
stage: str = "extract"
doc_id: str
chunk_id: str
chunk_hash: str = ""
chunk_type: ChunkType
mode_used: ModeName
model_id: str = ""
extracted: ExtractedContent = Field(default_factory=ExtractedContent)
# ── S2 MERGE ───────────────────────────────────────────────────────────────────
class SectionPoint(BaseModel):
section: str
points: list[str] = Field(default_factory=list)
class KeyClaim(BaseModel):
claim: str
support: ClaimStatus = ClaimStatus.SUPPORTED
evidence_refs: list[Evidence] = Field(default_factory=list)
class Synthesis(BaseModel):
answer_outline: list[SectionPoint] = Field(default_factory=list)
key_claims: list[KeyClaim] = Field(default_factory=list)
open_gaps: list[str] = Field(default_factory=list)
class MergeOutput(BaseModel):
stage: str = "merge"
synthesis: Synthesis = Field(default_factory=Synthesis)
# ── S3 EvidenceCheck ──────────────────────────────────────────────────────────────────
class CheckedClaim(BaseModel):
claim: str
status: ClaimStatus
evidence: list[Evidence] = Field(default_factory=list)
class EvidenceCheck(BaseModel):
checked_claims: list[CheckedClaim] = Field(default_factory=list)
unsupported_claims: list[str] = Field(default_factory=list)
required_followups: list[str] = Field(default_factory=list)
class EvidenceCheckOutput(BaseModel):
stage: str = "evidence_check"
evidence_check: EvidenceCheck = Field(default_factory=EvidenceCheck)
# ── FINAL OUTPUT ───────────────────────────────────────────────────────────────
class Section(BaseModel):
title: str
content: str
class FinalAnswer(BaseModel):
response: str
sections: list[Section] = Field(default_factory=list)
class FinalEvidence(BaseModel):
doc_id: str
chunk_id: str
where: str = ""
supports: str = ""
quote: str = Field(default="", max_length=200)
class TraceSummary(BaseModel):
real_switching: bool = False
modes_used_counts: dict[str, int] = Field(default_factory=dict)
models_used: list[str] = Field(default_factory=list)
docs_opened: list[str] = Field(default_factory=list)
chunks_processed: int = 0
search_queries: list[str] = Field(default_factory=list)
budget_notes: str = ""
class FinalOutput(BaseModel):
final_answer: FinalAnswer = Field(default_factory=FinalAnswer)
evidence: list[FinalEvidence] = Field(default_factory=list)
trace_summary: TraceSummary = Field(default_factory=TraceSummary)
confidence: float = 0.0
missing_info: list[str] = Field(default_factory=list)
next_actions: list[str] = Field(default_factory=list)
bus_messages: list[dict] = Field(default_factory=list)
# ── Helpers ────────────────────────────────────────────────────────────────────
CHUNK_TYPE_TO_MODE: dict[ChunkType, ModeName] = {
ChunkType.TEXT: ModeName.MODE_REASONING,
ChunkType.MATH: ModeName.MODE_REASONING,
ChunkType.TABLE: ModeName.MODE_REASONING,
ChunkType.FIGURE: ModeName.MODE_VISION,
ChunkType.CODE: ModeName.MODE_REASONING,
ChunkType.REFERENCES: ModeName.MODE_QUICK,
ChunkType.NOISE: ModeName.MODE_QUICK,
}
def compute_chunk_hash(content: str) -> str:
"""SHA-256 hash of chunk content."""
return hashlib.sha256(content.encode("utf-8")).hexdigest()