File size: 1,407 Bytes
cbb1b1a
 
 
4c85df9
 
 
cbb1b1a
 
4c85df9
cbb1b1a
4c85df9
 
cbb1b1a
4c85df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
EvidenceNER inference helper.

Exposes extract_entities() used by the CMA tool `extract_entities` and by
DocumentProcessor after OCR.  The EvidenceNER instance is cached at module
level after the first call so the checkpoint is only loaded once per process.
"""

from __future__ import annotations

import logging
from typing import Optional

from src.ner.model import Entity, EvidenceNER

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Checkpoint directory used when a caller does not pass model_dir explicitly.
_DEFAULT_MODEL_DIR = "models/evidence_ner"
# Process-wide cached EvidenceNER instance; stays None until init_ner() runs.
_ner: Optional[EvidenceNER] = None


def init_ner(model_dir: str = _DEFAULT_MODEL_DIR) -> EvidenceNER:
    """
    Build an EvidenceNER from *model_dir* and install it as the module cache.

    Any previously cached instance is replaced, so this doubles as a reload.
    Intended to be invoked once during server startup so the load cost is
    paid at a predictable time.  Returns the newly created instance.
    """
    global _ner
    logger.info("Loading EvidenceNER from %s …", model_dir)
    # Construct first, then publish: the global is only rebound once the
    # constructor has returned successfully.
    instance = EvidenceNER(model_dir)
    _ner = instance
    return instance


def extract_entities(
    text: str, model_dir: str = _DEFAULT_MODEL_DIR
) -> list[Entity]:
    """
    Extract named entities from *text* and return a list of Entity spans.

    The module-level EvidenceNER instance is created lazily from *model_dir*
    on the first non-empty call and reused afterwards.  Once an instance is
    cached, *model_dir* is ignored; call init_ner() to switch checkpoints.

    Returns [] for empty input without loading or invoking the model.

    This function adds no error handling of its own: exceptions raised while
    loading the checkpoint or by EvidenceNER.extract() propagate to the
    caller.  Call init_ner() at startup to surface load errors early.
    """
    # Short-circuit before the lazy load so empty input never pays the
    # checkpoint-loading cost and the documented [] result is guaranteed here.
    if not text:
        return []
    # Take the instance from init_ner()'s return value instead of re-reading
    # the global, so the object used below can never be None.
    ner = _ner if _ner is not None else init_ner(model_dir)
    return ner.extract(text)