File size: 4,499 Bytes
5e4028d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Named-entity extraction.

spaCy (`en_core_web_sm` by default) always runs and produces baseline
PERSON / DATE / GPE / LOC / ORG entities. Claude adds doc-type-aware custom
entities (sender, recipient, amount, signed_date, etc.) when the API is
available and `--no-api` isn't set. Each entity is tagged with `source` so
downstream consumers know whether to trust the label.
"""

from __future__ import annotations

import sys
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path

from src.postcorrect import _get_client

# Default Claude model id; used as the default for extract_entities(model=...).
MODEL_ID = "claude-haiku-4-5-20251001"
# Name of the spaCy pipeline that _load_spacy() loads.
SPACY_MODEL = "en_core_web_sm"
# spaCy entity labels kept by _extract_spacy(); entities with any other label
# are dropped.
SPACY_LABELS = {"PERSON", "DATE", "GPE", "LOC", "ORG"}
# Character budget applied by _truncate() to the text sent to Claude.
MAX_TEXT_CHARS = 16000

# Tool definition handed to the Claude messages call in _extract_claude().
# Its input schema asks for an "entities" array; each item must carry "text"
# and "label" strings and may carry a numeric "confidence" between 0.0 and 1.0.
_EXTRACT_TOOL: dict = {
    "name": "extract_entities",
    "description": "Extract structured entities from a transcribed document.",
    "input_schema": {
        "type": "object",
        "properties": {
            "entities": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string"},
                        "label": {"type": "string"},
                        "confidence": {
                            "type": "number",
                            "minimum": 0.0,
                            "maximum": 1.0,
                        },
                    },
                    # "confidence" is intentionally not required.
                    "required": ["text", "label"],
                },
            }
        },
        "required": ["entities"],
    },
}


@dataclass
class Entity:
    """One extracted entity together with the extractor that produced it.

    ``confidence`` stays ``None`` when the extractor did not supply a score.
    """

    text: str
    label: str
    source: str  # "spacy" | "claude"
    confidence: float | None = None

    def to_dict(self) -> dict:
        """Return the four fields as a plain dict, in declaration order."""
        field_names = ("text", "label", "source", "confidence")
        return {name: getattr(self, name) for name in field_names}


@lru_cache(maxsize=1)
def _load_spacy():
    """Load the spaCy pipeline named by ``SPACY_MODEL`` (cached after first call).

    spaCy is imported lazily inside the function. If ``spacy.load`` raises
    ``OSError``, the process exits via ``sys.exit`` with install instructions.
    """
    import spacy

    try:
        pipeline = spacy.load(SPACY_MODEL)
    except OSError as exc:
        # Exit with an actionable message instead of a traceback.
        install_hint = (
            f"spaCy model {SPACY_MODEL!r} not found. Install with:\n"
            f"  python -m spacy download {SPACY_MODEL}\n\nDetails: {exc}"
        )
        sys.exit(install_hint)
    return pipeline


@lru_cache(maxsize=1)
def _load_prompt() -> str:
    """Return the text of ``prompts/v1/extract.md`` (cached after first read).

    The path is resolved relative to the directory two levels above this
    file, and the file is decoded as UTF-8.
    """
    base_dir = Path(__file__).parent.parent
    prompt_path = base_dir.joinpath("prompts", "v1", "extract.md")
    return prompt_path.read_text(encoding="utf-8")


def _truncate(text: str) -> str:
    if len(text) <= MAX_TEXT_CHARS:
        return text
    return text[:MAX_TEXT_CHARS] + "\n\n[TRUNCATED]"


def _extract_spacy(text: str) -> list[Entity]:
    """Run the spaCy pipeline over *text* and keep the whitelisted entities.

    Only entities whose label is in ``SPACY_LABELS`` are returned, each tagged
    with ``source="spacy"`` and no confidence.
    """
    parsed = _load_spacy()(text)
    kept: list[Entity] = []
    for span in parsed.ents:
        if span.label_ not in SPACY_LABELS:
            continue
        kept.append(Entity(text=span.text, label=span.label_, source="spacy"))
    return kept


def _extract_claude(text: str, doc_type: str, model: str) -> list[Entity]:
    """Ask Claude for entities through the ``extract_entities`` tool.

    The document text is truncated with ``_truncate`` before sending. The
    system prompt comes from ``_load_prompt`` and is sent with an ephemeral
    ``cache_control`` marker. ``tool_choice`` names the ``extract_entities``
    tool, and the first ``tool_use`` block of the response is parsed.

    Args:
        text: Document text to analyse.
        doc_type: Document type label included in the user message.
        model: Model id passed to the messages API.

    Returns:
        Entities tagged ``source="claude"``. Empty when the response has no
        ``tool_use`` block or no usable entity list. Malformed items are
        skipped and reported on stderr rather than failing the whole batch.

    Raises:
        Whatever ``_get_client`` or ``client.messages.create`` raise; those
        errors are not handled here.
    """
    client = _get_client()
    user_msg = f"Document type: {doc_type}\n\nText:\n{_truncate(text)}"

    response = client.messages.create(
        model=model,
        max_tokens=2048,
        system=[
            {
                "type": "text",
                "text": _load_prompt(),
                "cache_control": {"type": "ephemeral"},
            }
        ],
        tools=[_EXTRACT_TOOL],
        tool_choice={"type": "tool", "name": "extract_entities"},
        messages=[{"role": "user", "content": user_msg}],
    )

    tool_block = next((b for b in response.content if b.type == "tool_use"), None)
    if tool_block is None:
        print("[ner] no tool_use in response; returning empty", file=sys.stderr)
        return []

    payload = tool_block.input
    raw_items = payload.get("entities") if isinstance(payload, dict) else None
    if not isinstance(raw_items, list):
        # Missing or non-list "entities": nothing to parse.
        return []

    entities: list[Entity] = []
    skipped = 0
    for item in raw_items:
        try:
            # "confidence" is optional in the schema; a null value is treated
            # the same as an absent key instead of raising in float().
            raw_conf = item.get("confidence")
            entities.append(
                Entity(
                    text=str(item["text"]),
                    label=str(item["label"]),
                    source="claude",
                    confidence=None if raw_conf is None else float(raw_conf),
                )
            )
        except (AttributeError, KeyError, TypeError, ValueError):
            # One bad item must not discard the well-formed ones.
            skipped += 1

    if skipped:
        print(
            f"[ner] skipped {skipped} malformed entity item(s) from Claude",
            file=sys.stderr,
        )
    return entities


def extract_entities(
    text: str,
    *,
    doc_type: str = "unknown",
    no_api: bool = False,
    model: str = MODEL_ID,
) -> list[Entity]:
    """Extract entities from *text* with spaCy and, optionally, Claude.

    Blank or whitespace-only text returns an empty list without running any
    extractor. Otherwise spaCy runs first. Unless ``no_api`` is set, Claude's
    entities are appended after the spaCy ones. If the Claude step raises, the
    error is reported on stderr and the spaCy results are returned alone.

    Args:
        text: Document text.
        doc_type: Document type hint forwarded to the Claude step.
        no_api: When true, skip the Claude step entirely.
        model: Model id forwarded to the Claude step.

    Returns:
        The combined entity list, each entry tagged by its ``source``.
    """
    if text.strip() == "":
        return []

    results = _extract_spacy(text)
    if no_api:
        return results

    try:
        from_claude = _extract_claude(text, doc_type, model)
    except Exception as exc:
        # Best-effort: the Claude step is an enrichment, so fall back to the
        # spaCy baseline rather than failing the call.
        print(
            f"[ner] Claude extraction failed ({exc!r}); spaCy-only output",
            file=sys.stderr,
        )
    else:
        results.extend(from_claude)
    return results