File size: 3,333 Bytes
5e4028d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | """Document type classification via Claude tool-use over the document's text.
One API call per document. The system prompt carries an ephemeral cache_control
breakpoint, so later calls in a batch run can read it from the prompt cache
instead of paying full input price for it every time. Returns
ClassifyResult(doc_type, confidence, reasoning).
`--no-api` returns doc_type="unknown" with a note. The original plan called
for a `facebook/bart-large-mnli` zero-shot fallback; dropped to avoid a
1.6 GB extra HF download for a feature that's only useful in offline mode.
"""
from __future__ import annotations
import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from src.postcorrect import _get_client
# Default model used by classify().
MODEL_ID = "claude-haiku-4-5-20251001"
# Closed label set; the very same list object is the tool schema's enum below.
DOC_TYPES = ["letter", "receipt", "ledger", "deed"]
# Character (not token) cap applied by _truncate before sending. Deliberately
# loose: the model's context (~200k tokens for haiku) is far larger, but there
# is no reason to ship a whole book just to pick a document type.
MAX_TEXT_CHARS = 16_000

# Tool the model is forced to call (see tool_choice in classify()); its input
# carries back doc_type / confidence / reasoning.
_CLASSIFY_TOOL: dict = dict(
    name="classify_document",
    description="Submit a document classification with reasoning.",
    input_schema=dict(
        type="object",
        properties=dict(
            doc_type=dict(
                type="string",
                enum=DOC_TYPES,
                description="Best-matching document type",
            ),
            confidence=dict(
                type="number",
                minimum=0.0,
                maximum=1.0,
            ),
            reasoning=dict(
                type="string",
                description="1-2 sentences citing structural cues",
            ),
        ),
        required=["doc_type", "confidence", "reasoning"],
    ),
)
@dataclass()
class ClassifyResult:
    """Outcome of a single classify() call.

    doc_type is either one of DOC_TYPES or the sentinel "unknown" that
    classify() uses for skipped or failed classifications.
    """

    # Predicted label.
    doc_type: str
    # Model-reported confidence; classify() reports 0.0 for "unknown" results.
    confidence: float
    # Short free-text justification, or a note explaining an "unknown" result.
    reasoning: str
@lru_cache(maxsize=1)
def _load_prompt() -> str:
    """Return the text of the v1 classification system prompt.

    The file is ``prompts/v1/classify.md`` under the directory two levels
    above this module, decoded as UTF-8. A successful read is memoized, so
    the file is hit once per process.
    """
    base_dir = Path(__file__).parent.parent
    prompt_file = base_dir.joinpath("prompts", "v1", "classify.md")
    return prompt_file.read_text(encoding="utf-8")
def _truncate(text: str) -> str:
    """Clip *text* to at most MAX_TEXT_CHARS characters before it is sent.

    Text at or under the cap comes back unchanged. Longer text is cut to
    exactly MAX_TEXT_CHARS characters, followed by a blank line and a
    ``[TRUNCATED]`` marker, so the clipping is visible in the prompt.
    """
    too_long = len(text) > MAX_TEXT_CHARS
    return f"{text[:MAX_TEXT_CHARS]}\n\n[TRUNCATED]" if too_long else text
def _result_from_tool_input(payload: object) -> ClassifyResult:
    """Validate a ``classify_document`` tool input and turn it into a ClassifyResult.

    The tool's JSON schema guides the model but is not relied on here: keys
    may be missing (e.g. if generation stopped early), ``doc_type`` may fall
    outside DOC_TYPES, and ``confidence`` may be non-numeric, NaN/inf, or out
    of [0, 1]. Bad values degrade to the module's "unknown"/0.0 convention
    instead of raising.
    """
    import math

    data = payload if isinstance(payload, dict) else {}

    reasoning_raw = data.get("reasoning")
    reasoning = "" if reasoning_raw is None else str(reasoning_raw)

    doc_type = data.get("doc_type")
    if doc_type not in DOC_TYPES:
        print(
            f"[classify] unexpected doc_type {doc_type!r}; returning unknown",
            file=sys.stderr,
        )
        # Keep the model's reasoning for debugging, but an off-schema label
        # gets the same zero confidence as every other "unknown" result.
        return ClassifyResult(doc_type="unknown", confidence=0.0, reasoning=reasoning)

    try:
        confidence = float(data.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    if not math.isfinite(confidence):
        confidence = 0.0
    # Schema minimum/maximum are advisory to the model; enforce them here.
    confidence = min(max(confidence, 0.0), 1.0)

    return ClassifyResult(doc_type=str(doc_type), confidence=confidence, reasoning=reasoning)


def classify(
    text: str,
    *,
    no_api: bool = False,
    model: str = MODEL_ID,
) -> ClassifyResult:
    """Classify *text* as one of DOC_TYPES with a single forced tool call.

    Args:
        text: Document text. It is truncated to MAX_TEXT_CHARS characters
            before being sent as the user message.
        no_api: If True, make no API call and return doc_type="unknown".
        model: Model ID passed to ``client.messages.create``.

    Returns:
        A ClassifyResult. doc_type is "unknown" with confidence 0.0 in
        these cases: no_api is set, text is blank, the response contains no
        tool_use block, or the tool input names a label outside DOC_TYPES.
        Otherwise confidence is clamped to [0.0, 1.0].

    Raises:
        Exceptions from ``_get_client()`` and ``client.messages.create``
        propagate unchanged; they are not caught here.
    """
    if no_api:
        return ClassifyResult(
            doc_type="unknown", confidence=0.0, reasoning="--no-api mode"
        )
    if not text.strip():
        return ClassifyResult(
            doc_type="unknown", confidence=0.0, reasoning="empty input text"
        )

    client = _get_client()
    response = client.messages.create(
        model=model,
        max_tokens=512,
        # Mark the (static) system prompt as a cache breakpoint so repeated
        # calls in a batch can reuse it from the prompt cache.
        system=[
            {
                "type": "text",
                "text": _load_prompt(),
                "cache_control": {"type": "ephemeral"},
            }
        ],
        tools=[_CLASSIFY_TOOL],
        # Force the model to answer through the tool rather than free text.
        tool_choice={"type": "tool", "name": "classify_document"},
        messages=[{"role": "user", "content": _truncate(text)}],
    )

    tool_block = next((b for b in response.content if b.type == "tool_use"), None)
    if tool_block is None:
        print("[classify] no tool_use in response; returning unknown", file=sys.stderr)
        return ClassifyResult(
            doc_type="unknown", confidence=0.0, reasoning="no tool response"
        )
    return _result_from_tool_input(tool_block.input)
|