| | """ |
| | DocumentAgent for SPARKNET |
| | |
| | A ReAct-style agent for document intelligence tasks: |
| | - Document parsing and extraction |
| | - Field extraction with grounding |
| | - Table and chart analysis |
| | - Document classification |
| | - Question answering over documents |
| | """ |
| |
|
| | from typing import List, Dict, Any, Optional, Tuple |
| | from dataclasses import dataclass |
| | from enum import Enum |
| | import json |
| | import time |
| | from loguru import logger |
| |
|
| | from .base_agent import BaseAgent, Task, Message |
| | from ..llm.langchain_ollama_client import LangChainOllamaClient |
| | from ..document.schemas.core import ( |
| | ProcessedDocument, |
| | DocumentChunk, |
| | EvidenceRef, |
| | ExtractionResult, |
| | ) |
| | from ..document.schemas.extraction import ExtractionSchema, ExtractedField |
| | from ..document.schemas.classification import DocumentClassification, DocumentType |
| |
|
| |
|
class AgentAction(str, Enum):
    """Actions the DocumentAgent can take at each ReAct step."""
    THINK = "think"        # reason without invoking a tool
    USE_TOOL = "use_tool"  # invoke one of DocumentAgent.TOOLS
    ANSWER = "answer"      # emit the final answer and stop the loop
    ABSTAIN = "abstain"    # decline to answer (insufficient evidence)
| |
|
| |
|
@dataclass
class ThoughtAction:
    """A single thought-action pair in the ReAct loop."""
    thought: str                                   # the agent's reasoning text for this step
    action: AgentAction                            # what the agent decided to do
    tool_name: Optional[str] = None                # set when action == USE_TOOL
    tool_args: Optional[Dict[str, Any]] = None     # parsed JSON args, or {"raw": <text>} fallback
    observation: Optional[str] = None              # tool output, filled in after execution
    evidence: Optional[List[EvidenceRef]] = None   # grounding references collected by the tool
| |
|
| |
|
@dataclass
class AgentTrace:
    """Full trace of agent execution for inspection."""
    task: str                         # the natural-language task that was run
    steps: List[ThoughtAction]        # every thought/action/observation, in order
    final_answer: Optional[Any] = None  # parsed answer, abstain dict, or None if loop exhausted
    confidence: float = 0.0           # mean evidence confidence (0.0 on abstain)
    total_time_ms: float = 0.0        # wall-clock duration of the run
    success: bool = True              # False only when an exception aborted the run
    error: Optional[str] = None       # stringified exception when success is False
| |
|
| |
|
class DocumentAgent:
    """
    ReAct-style agent for document intelligence tasks.

    Implements the Think -> Tool -> Observe -> Refine loop
    with inspectable traces and grounded outputs.

    A document must be attached via :meth:`set_document` before any
    task is run; each run produces an :class:`AgentTrace` so every
    answer can be audited step by step.
    """

    # System prompt template; {tool_descriptions} is filled in per call
    # by _generate_step from the TOOLS registry below.
    SYSTEM_PROMPT = """You are a document intelligence agent that analyzes documents
and extracts information with evidence.

You operate in a Think-Act-Observe loop:
1. THINK: Analyze what you need to do and what information you have
2. ACT: Choose a tool to use or provide an answer
3. OBSERVE: Review the tool output and update your understanding

Available tools:
{tool_descriptions}

CRITICAL RULES:
- Every extraction MUST include evidence (page, bbox, text snippet)
- If you cannot find evidence for a value, ABSTAIN rather than guess
- Always cite the source of information with page numbers
- For tables, analyze structure before extracting data
- For charts, describe what you see before extracting values

Output format for each step:
THOUGHT: <your reasoning>
ACTION: <tool_name or ANSWER or ABSTAIN>
ACTION_INPUT: <JSON arguments for tool, or final answer>
"""

    # Tool registry: name -> description and expected argument names.
    # _parse_step treats any ACTION matching a key here (or any unknown
    # name) as a tool call; _execute_tool reports unknown names safely.
    TOOLS = {
        "extract_text": {
            "description": "Extract text from specific pages or regions",
            "args": ["page_numbers", "region_bbox"],
        },
        "analyze_table": {
            "description": "Analyze and extract structured data from a table region",
            "args": ["page", "bbox", "expected_columns"],
        },
        "analyze_chart": {
            "description": "Analyze a chart/graph and extract insights",
            "args": ["page", "bbox"],
        },
        "extract_fields": {
            "description": "Extract specific fields using a schema",
            "args": ["schema", "context_chunks"],
        },
        "classify_document": {
            "description": "Classify the document type",
            "args": ["first_page_chunks"],
        },
        "search_text": {
            "description": "Search for text patterns in the document",
            "args": ["query", "page_range"],
        },
    }

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        memory_agent: Optional[Any] = None,
        max_iterations: int = 10,
        temperature: float = 0.3,
    ):
        """
        Initialize DocumentAgent.

        Args:
            llm_client: LangChain Ollama client
            memory_agent: Optional memory agent for context retrieval
            max_iterations: Maximum ReAct iterations
            temperature: LLM temperature for reasoning
        """
        self.llm_client = llm_client
        self.memory_agent = memory_agent
        self.max_iterations = max_iterations
        self.temperature = temperature

        # Per-document state, populated by set_document().
        self._current_document: Optional[ProcessedDocument] = None
        self._page_images: Dict[int, Any] = {}

        logger.info(f"Initialized DocumentAgent (max_iterations={max_iterations})")

    def set_document(
        self,
        document: ProcessedDocument,
        page_images: Optional[Dict[int, Any]] = None,
    ) -> None:
        """
        Set the current document context.

        Args:
            document: Processed document
            page_images: Optional dict of page number -> image array
        """
        self._current_document = document
        self._page_images = page_images or {}
        logger.info(f"Set document context: {document.metadata.document_id}")

    async def run(
        self,
        task_description: str,
        extraction_schema: Optional[ExtractionSchema] = None,
    ) -> Tuple[Any, AgentTrace]:
        """
        Run the agent on a task.

        Args:
            task_description: Natural language task description
            extraction_schema: Optional schema for structured extraction

        Returns:
            Tuple of (result, trace). ``result`` is None when the loop
            exhausts max_iterations without an ANSWER or ABSTAIN.

        Raises:
            ValueError: If no document has been set via set_document().
        """
        start_time = time.time()

        if not self._current_document:
            raise ValueError("No document set. Call set_document() first.")

        trace = AgentTrace(task=task_description, steps=[])

        try:
            context = self._build_context(extraction_schema)

            for iteration in range(self.max_iterations):
                logger.debug(f"ReAct iteration {iteration + 1}")

                step = await self._generate_step(task_description, context, trace.steps)
                trace.steps.append(step)

                if step.action == AgentAction.ANSWER:
                    trace.final_answer = self._parse_answer(step.tool_args)
                    trace.confidence = self._calculate_confidence(trace.steps)
                    break

                elif step.action == AgentAction.ABSTAIN:
                    trace.final_answer = {
                        "abstained": True,
                        "reason": step.thought,
                    }
                    trace.confidence = 0.0
                    break

                elif step.action == AgentAction.USE_TOOL:
                    observation, evidence = await self._execute_tool(
                        step.tool_name, step.tool_args
                    )
                    step.observation = observation
                    step.evidence = evidence
                    # Feed the observation back so the next step can see it.
                    context += f"\n\nObservation from {step.tool_name}:\n{observation}"
                # AgentAction.THINK: no tool call; loop again with the
                # same context and let the model refine its reasoning.

        except Exception as e:
            logger.error(f"Agent execution failed: {e}")
            trace.success = False
            trace.error = str(e)

        trace.total_time_ms = (time.time() - start_time) * 1000
        return trace.final_answer, trace

    async def extract_fields(
        self,
        schema: ExtractionSchema,
    ) -> ExtractionResult:
        """
        Extract fields from the document using a schema.

        Args:
            schema: Extraction schema defining fields

        Returns:
            ExtractionResult with extracted data and evidence. Required
            fields the agent could not ground are listed in
            ``abstained_fields`` with a matching warning.
        """
        task = f"Extract the following fields from this document: {', '.join(f.name for f in schema.fields)}"
        result, trace = await self.run(task, schema)

        data: Dict[str, Any] = {}
        evidence: List[EvidenceRef] = []
        warnings: List[str] = []
        abstained: List[str] = []

        if isinstance(result, dict):
            # Prefer a nested "data" payload; fall back to the whole dict.
            data = result.get("data", result)

        # Collect every evidence reference gathered by tool calls.
        for step in trace.steps:
            if step.evidence:
                evidence.extend(step.evidence)

        # Flag required fields that were never extracted.
        for field in schema.fields:
            if field.name not in data and field.required:
                abstained.append(field.name)
                warnings.append(
                    f"Required field '{field.name}' not found with sufficient confidence"
                )

        return ExtractionResult(
            data=data,
            evidence=evidence,
            warnings=warnings,
            confidence=trace.confidence,
            abstained_fields=abstained,
        )

    async def classify(self) -> DocumentClassification:
        """
        Classify the document type.

        Returns:
            DocumentClassification with type and confidence. Unmappable
            type strings fall back to DocumentType.OTHER; a non-dict
            result yields DocumentType.UNKNOWN with zero confidence.
        """
        task = "Classify this document into one of the standard document types (contract, invoice, patent, research_paper, report, letter, form, etc.)"
        result, trace = await self.run(task)

        doc_type = DocumentType.UNKNOWN
        confidence = 0.0

        if isinstance(result, dict):
            type_str = result.get("document_type", "unknown")
            try:
                # str() guards against the LLM returning a non-string value.
                doc_type = DocumentType(str(type_str).lower())
            except ValueError:
                doc_type = DocumentType.OTHER

            confidence = result.get("confidence", trace.confidence)

        return DocumentClassification(
            document_id=self._current_document.metadata.document_id,
            primary_type=doc_type,
            primary_confidence=confidence,
            evidence=[e for step in trace.steps if step.evidence for e in step.evidence],
            method="llm",
            is_confident=confidence >= 0.7,
        )

    async def answer_question(self, question: str) -> Tuple[str, List[EvidenceRef]]:
        """
        Answer a question about the document.

        Args:
            question: Natural language question

        Returns:
            Tuple of (answer, evidence)
        """
        task = f"Answer this question about the document: {question}"
        result, trace = await self.run(task)

        answer = ""
        evidence: List[EvidenceRef] = []

        if isinstance(result, dict):
            answer = result.get("answer", str(result))
        elif isinstance(result, str):
            answer = result

        # Collect all evidence the tool calls produced along the way.
        for step in trace.steps:
            if step.evidence:
                evidence.extend(step.evidence)

        return answer, evidence

    def _build_context(self, schema: Optional[ExtractionSchema] = None) -> str:
        """Build the initial LLM context: metadata, a preview of the
        first 10 chunks (200 chars each), and the schema if given."""
        doc = self._current_document
        context_parts = [
            f"Document: {doc.metadata.filename}",
            f"Type: {doc.metadata.file_type}",
            f"Pages: {doc.metadata.num_pages}",
            f"Chunks: {len(doc.chunks)}",
            "",
            "Document content summary:",
        ]

        for chunk in doc.chunks[:10]:
            context_parts.append(
                f"[Page {chunk.page + 1}, {chunk.chunk_type.value}]: {chunk.text[:200]}..."
            )

        if schema:
            context_parts.append("")
            context_parts.append("Extraction schema:")
            for field in schema.fields:
                req = "required" if field.required else "optional"
                context_parts.append(f"- {field.name} ({field.type.value}, {req}): {field.description}")

        return "\n".join(context_parts)

    async def _generate_step(
        self,
        task: str,
        context: str,
        previous_steps: List[ThoughtAction],
    ) -> ThoughtAction:
        """Ask the LLM for the next thought-action step and parse it."""
        tool_descriptions = "\n".join(
            f"- {name}: {info['description']}"
            for name, info in self.TOOLS.items()
        )

        system_prompt = self.SYSTEM_PROMPT.format(tool_descriptions=tool_descriptions)

        user_content = f"TASK: {task}\n\nCONTEXT:\n{context}"

        # Replay prior steps so the model sees its own trajectory;
        # observations are truncated to keep the prompt bounded.
        if previous_steps:
            user_content += "\n\nPREVIOUS STEPS:"
            for i, step in enumerate(previous_steps, 1):
                user_content += f"\n\nStep {i}:"
                user_content += f"\nTHOUGHT: {step.thought}"
                user_content += f"\nACTION: {step.action.value}"
                if step.tool_name:
                    user_content += f"\nTOOL: {step.tool_name}"
                if step.observation:
                    user_content += f"\nOBSERVATION: {step.observation[:500]}..."

        user_content += "\n\nNow generate your next step:"

        # Reasoning steps use the "complex" model tier at the agent's
        # configured temperature.
        llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)

        from langchain_core.messages import HumanMessage, SystemMessage
        lc_messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_content),
        ]

        response = await llm.ainvoke(lc_messages)
        response_text = response.content

        return self._parse_step(response_text)

    def _parse_step(self, response: str) -> ThoughtAction:
        """
        Parse an LLM response into a ThoughtAction.

        Expects the THOUGHT / ACTION / ACTION_INPUT format defined in
        SYSTEM_PROMPT. Unrecognized actions are treated as tool names so
        _execute_tool can report them as unknown. ACTION_INPUT may span
        multiple lines; the accumulated text is parsed as JSON, falling
        back to {"raw": <text>} when it is not valid JSON.
        """
        thought = ""
        action = AgentAction.THINK
        tool_name: Optional[str] = None
        input_lines: List[str] = []
        current_section: Optional[str] = None

        for raw_line in response.strip().split("\n"):
            line = raw_line.strip()

            if line.startswith("THOUGHT:"):
                current_section = "thought"
                thought = line[len("THOUGHT:"):].strip()
            elif line.startswith("ACTION:"):
                current_section = "action"
                action_str = line[len("ACTION:"):].strip().lower()
                if action_str == "answer":
                    action = AgentAction.ANSWER
                elif action_str == "abstain":
                    action = AgentAction.ABSTAIN
                else:
                    # Known or unknown tool name: _execute_tool handles both.
                    action = AgentAction.USE_TOOL
                    tool_name = action_str
            elif line.startswith("ACTION_INPUT:"):
                current_section = "input"
                input_lines = [line[len("ACTION_INPUT:"):].strip()]
            elif current_section == "thought":
                # Continuation line of a multi-line thought.
                thought += " " + line
            elif current_section == "input":
                # Continuation line of a multi-line ACTION_INPUT.
                input_lines.append(line)

        tool_args: Optional[Dict[str, Any]] = None
        if input_lines:
            input_str = "\n".join(input_lines).strip()
            try:
                tool_args = json.loads(input_str)
            except json.JSONDecodeError:
                tool_args = {"raw": input_str}

        return ThoughtAction(
            thought=thought,
            action=action,
            tool_name=tool_name,
            tool_args=tool_args,
        )

    async def _execute_tool(
        self,
        tool_name: str,
        tool_args: Optional[Dict[str, Any]],
    ) -> Tuple[str, List[EvidenceRef]]:
        """
        Dispatch a tool call and return (observation, evidence).

        Tool failures are caught and reported as observations so the
        ReAct loop can recover instead of aborting the whole run.
        """
        args = tool_args or {}

        try:
            if tool_name == "extract_text":
                return self._tool_extract_text(args)

            elif tool_name == "analyze_table":
                return await self._tool_analyze_table(args)

            elif tool_name == "analyze_chart":
                return await self._tool_analyze_chart(args)

            elif tool_name == "extract_fields":
                return await self._tool_extract_fields(args)

            elif tool_name == "classify_document":
                return self._tool_classify_document(args)

            elif tool_name == "search_text":
                return self._tool_search_text(args)

            else:
                return f"Unknown tool: {tool_name}", []

        except Exception as e:
            logger.error(f"Tool {tool_name} failed: {e}")
            return f"Error executing {tool_name}: {e}", []

    def _tool_extract_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Extract text from the requested pages (default: all pages).

        Output is capped at 20 text snippets and 10 evidence refs to
        keep the observation prompt-sized.
        """
        doc = self._current_document
        page_numbers = args.get("page_numbers", list(range(doc.metadata.num_pages)))

        # The LLM may pass a single page number instead of a list.
        if isinstance(page_numbers, int):
            page_numbers = [page_numbers]

        texts: List[str] = []
        evidence: List[EvidenceRef] = []

        for page in page_numbers:
            page_chunks = doc.get_page_chunks(page)
            for chunk in page_chunks:
                texts.append(f"[Page {page + 1}]: {chunk.text}")
                evidence.append(EvidenceRef(
                    chunk_id=chunk.chunk_id,
                    page=chunk.page,
                    bbox=chunk.bbox,
                    source_type="text",
                    snippet=chunk.text[:100],
                    confidence=chunk.confidence,
                ))

        return "\n".join(texts[:20]), evidence[:10]

    async def _tool_analyze_table(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Analyze the first table chunk on the given page with the LLM."""
        page = args.get("page", 0)
        doc = self._current_document

        table_chunks = [c for c in doc.chunks if c.chunk_type.value == "table" and c.page == page]

        if not table_chunks:
            return "No table found on this page", []

        # Only the first table on the page is analyzed.
        table_text = table_chunks[0].text
        llm = self.llm_client.get_llm(complexity="standard")

        from langchain_core.messages import HumanMessage
        prompt = f"Analyze this table and extract structured data as JSON:\n\n{table_text}"
        response = await llm.ainvoke([HumanMessage(content=prompt)])

        evidence = [EvidenceRef(
            chunk_id=table_chunks[0].chunk_id,
            page=page,
            bbox=table_chunks[0].bbox,
            source_type="table",
            snippet=table_text[:200],
            confidence=table_chunks[0].confidence,
        )]

        return response.content, evidence

    async def _tool_analyze_chart(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Report the first chart/figure chunk found on the given page."""
        page = args.get("page", 0)
        doc = self._current_document

        chart_chunks = [
            c for c in doc.chunks
            if c.chunk_type.value in ("chart", "figure") and c.page == page
        ]

        if not chart_chunks:
            return "No chart/figure found on this page", []

        # TODO: run vision analysis on self._page_images[page] when a
        # vision model is wired in; currently only the caption is used.
        return f"Chart found on page {page + 1}: {chart_chunks[0].caption or 'No caption'}", []

    async def _tool_extract_fields(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Ask the LLM to extract schema fields from the first 20 chunks.

        Returns the raw LLM output with no evidence refs; grounding for
        this tool is left to the surrounding ReAct steps.
        """
        schema_dict = args.get("schema", {})
        doc = self._current_document

        # Context window: first 20 chunks of the document.
        context = "\n".join(c.text for c in doc.chunks[:20])

        llm = self.llm_client.get_llm(complexity="complex")

        from langchain_core.messages import HumanMessage, SystemMessage
        system = "Extract the requested fields from the document. Output JSON with field names as keys."
        user = f"Fields to extract: {json.dumps(schema_dict)}\n\nDocument content:\n{context}"

        response = await llm.ainvoke([
            SystemMessage(content=system),
            HumanMessage(content=user),
        ])

        return response.content, []

    def _tool_classify_document(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Return first-page text (up to 5 chunks, 500 chars) so the
        agent can classify the document in a later step."""
        doc = self._current_document
        first_page_chunks = doc.get_page_chunks(0)
        text = " ".join(c.text for c in first_page_chunks[:5])

        return f"First page content for classification:\n{text[:500]}", []

    def _tool_search_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Case-insensitive substring search over all chunks; returns up
        to 10 matches with matching evidence refs."""
        query = args.get("query", "").lower()
        doc = self._current_document

        # An empty query would "match" every chunk ('' is a substring of
        # everything), so reject it explicitly.
        if not query:
            return "Empty search query", []

        matches: List[str] = []
        evidence: List[EvidenceRef] = []

        for chunk in doc.chunks:
            if query in chunk.text.lower():
                matches.append(f"[Page {chunk.page + 1}]: ...{chunk.text}...")
                evidence.append(EvidenceRef(
                    chunk_id=chunk.chunk_id,
                    page=chunk.page,
                    bbox=chunk.bbox,
                    source_type="text",
                    snippet=chunk.text[:100],
                    confidence=chunk.confidence,
                ))

        if not matches:
            return f"No matches found for '{query}'", []

        return f"Found {len(matches)} matches:\n" + "\n".join(matches[:10]), evidence[:10]

    def _parse_answer(self, answer_input: Optional[Any]) -> Any:
        """Normalize the final ANSWER payload to a dict (or None).

        ``answer_input`` comes from json.loads and so may be any JSON
        value, not just a dict; non-dict values are wrapped.
        """
        if not answer_input:
            return None

        if isinstance(answer_input, dict):
            return answer_input

        return {"answer": answer_input}

    def _calculate_confidence(self, steps: List[ThoughtAction]) -> float:
        """Mean confidence over all collected evidence; 0.5 when steps
        exist but produced no evidence, 0.0 with no steps at all."""
        if not steps:
            return 0.0

        all_evidence = [e for s in steps if s.evidence for e in s.evidence]
        if all_evidence:
            return sum(e.confidence for e in all_evidence) / len(all_evidence)

        return 0.5
|