cmboulanger committed
Commit 37eaffd · 1 Parent(s): 790b4e5

full implementation
README.md CHANGED
@@ -0,0 +1,135 @@
+ # tei-annotator
+
+ A Python library for annotating text with [TEI XML](https://tei-c.org/) tags using a two-stage LLM pipeline.
+
+ The pipeline:
+
+ 1. **(Optional) GLiNER pre-detection** — fast CPU-based span labelling generates candidates for the LLM to verify and extend.
+ 2. **LLM annotation** — a prompted language model identifies entities and returns structured spans (element + verbatim text + surrounding context + attributes).
+ 3. **Deterministic post-processing** — spans are resolved to character offsets, validated against the schema, and injected as XML tags. The source text is **never modified** by any model call.
+
+ Works with any inference endpoint through an injected `call_fn: (str) -> str` — Anthropic, OpenAI, Gemini, a local Ollama instance, or a constrained-decoding API.
+
+ ---
+
+ ## Installation
+
+ Requires Python ≥ 3.12 and [uv](https://docs.astral.sh/uv/).
+
+ ```bash
+ git clone <repo>
+ cd tei-annotator
+ uv sync                 # installs runtime deps: jinja2, lxml, rapidfuzz
+ uv sync --extra gliner  # also installs gliner for the optional pre-detection pass
+ ```
+
+ API keys for real LLM endpoints go in `.env` in the project root (the smoke-test script reads `GEMINI_API_KEY` and `KISSKI_API_KEY` from it).
+
+ ---
+
+ ## Quick example
+
+ ```python
+ from tei_annotator import (
+     annotate,
+     TEISchema, TEIElement, TEIAttribute,
+     EndpointConfig, EndpointCapability,
+ )
+
+ # 1. Describe the elements you want to annotate
+ schema = TEISchema(elements=[
+     TEIElement(
+         tag="persName",
+         description="a person's name",
+         attributes=[TEIAttribute(name="ref", description="authority URI")],
+     ),
+     TEIElement(
+         tag="placeName",
+         description="a geographical place name",
+         attributes=[],
+     ),
+ ])
+
+ # 2. Wrap your inference endpoint
+ def my_call_fn(prompt: str) -> str:
+     # replace with any LLM call — Anthropic, OpenAI, Gemini, Ollama, …
+     ...
+
+ endpoint = EndpointConfig(
+     capability=EndpointCapability.TEXT_GENERATION,
+     call_fn=my_call_fn,
+ )
+
+ # 3. Annotate
+ result = annotate(
+     text="Marie Curie was born in Warsaw and later worked in Paris.",
+     schema=schema,
+     endpoint=endpoint,
+     gliner_model=None,  # set to e.g. "numind/NuNER_Zero" to enable pre-detection
+ )
+
+ print(result.xml)
+ # <persName>Marie Curie</persName> was born in <placeName>Warsaw</placeName>
+ # and later worked in <placeName>Paris</placeName>.
+
+ if result.fuzzy_spans:
+     print("Review these spans — context was matched approximately:")
+     for span in result.fuzzy_spans:
+         print(f"  <{span.element}>{span.text}</{span.element}>")
+ ```
+
+ The input text may already contain XML markup; existing tags are stripped before the LLM sees the text and restored in the final output.
+
+ ### Real-endpoint smoke test
+
+ `scripts/smoke_test_llm.py` runs the full pipeline against **Gemini 2.0 Flash** and **KISSKI `llama-3.3-70b-instruct`** using API keys from `.env`:
+
+ ```bash
+ uv run scripts/smoke_test_llm.py
+ ```
+
+ ---
+
+ ## `annotate()` parameters
+
+ | Parameter | Default | Description |
+ | --- | --- | --- |
+ | `text` | — | Input text; may contain existing XML tags |
+ | `schema` | — | `TEISchema` describing elements and attributes in scope |
+ | `endpoint` | — | `EndpointConfig` wrapping any `call_fn: (str) -> str` |
+ | `gliner_model` | `"numind/NuNER_Zero"` | HuggingFace model for optional pre-detection; `None` to disable |
+ | `chunk_size` | `1500` | Maximum characters per LLM prompt chunk |
+ | `chunk_overlap` | `200` | Character overlap between consecutive chunks |
+
+ ### `EndpointCapability` values
+
+ | Value | When to use |
+ | --- | --- |
+ | `TEXT_GENERATION` | Plain LLM — JSON requested via prompt, with one automatic retry on parse failure |
+ | `JSON_ENFORCED` | Constrained-decoding endpoint that guarantees valid JSON output |
+ | `EXTRACTION` | Native extraction model (GLiNER2 / NuExtract-style); raw text is passed directly |
+
+ ---
+
+ ## Testing
+
+ ```bash
+ # Unit tests (fully mocked, < 0.1 s)
+ uv run pytest
+
+ # Integration tests — complex pipeline scenarios, no model download needed
+ uv run pytest --override-ini="addopts=" -m integration \
+     tests/integration/test_pipeline_e2e.py -k "not real_gliner"
+
+ # Integration tests — real GLiNER model (downloads ~400 MB on first run)
+ uv run pytest --override-ini="addopts=" -m integration \
+     tests/integration/test_gliner_detector.py \
+     tests/integration/test_pipeline_e2e.py::test_pipeline_with_real_gliner
+ ```
+
+ Integration tests are excluded from the default `pytest` run via `pyproject.toml`:
+
+ ```toml
+ [tool.pytest.ini_options]
+ addopts = "-m 'not integration'"
+ ```
implementation-plan.md CHANGED
@@ -340,3 +340,55 @@ def test_annotate_smoke():
  assert "John Smith" in result.xml
  assert result.xml.count("John Smith") == 1  # text not duplicated
  ```
+
344
+ ---
345
+
346
+ ## Implementation Status
347
+
348
+ **Completed 2026-02-28** — full implementation per the plan above.
349
+
350
+ ### What was built
351
+
352
+ All modules in the package structure were implemented:
353
+
354
+ | File | Notes |
355
+ | --- | --- |
356
+ | `tei_annotator/models/schema.py` | `TEIAttribute`, `TEIElement`, `TEISchema` dataclasses |
357
+ | `tei_annotator/models/spans.py` | `SpanDescriptor`, `ResolvedSpan` dataclasses |
358
+ | `tei_annotator/inference/endpoint.py` | `EndpointCapability` enum, `EndpointConfig` dataclass |
359
+ | `tei_annotator/chunking/chunker.py` | `chunk_text()` — overlap chunker, XML-safe boundaries |
360
+ | `tei_annotator/detection/gliner_detector.py` | `detect_spans()` — optional, raises `ImportError` if `[gliner]` extra not installed |
361
+ | `tei_annotator/prompting/builder.py` | `build_prompt()` + `make_correction_prompt()` |
362
+ | `tei_annotator/prompting/templates/text_gen.jinja2` | Verbose prompt with JSON example, "output only JSON" instruction |
363
+ | `tei_annotator/prompting/templates/json_enforced.jinja2` | Minimal prompt for constrained-decoding endpoints |
364
+ | `tei_annotator/postprocessing/parser.py` | `parse_response()` — fence stripping, one-shot self-correction retry |
365
+ | `tei_annotator/postprocessing/resolver.py` | `resolve_spans()` — context-anchor → char offset, rapidfuzz fuzzy fallback at threshold 0.92 |
366
+ | `tei_annotator/postprocessing/validator.py` | `validate_spans()` — element, attribute name, allowed-value checks |
367
+ | `tei_annotator/postprocessing/injector.py` | `inject_xml()` — stack-based nesting tree, recursive tag insertion |
368
+ | `tei_annotator/pipeline.py` | `annotate()` — full orchestration, tag strip/restore, deduplication across chunks, lxml final validation |
369
+
370
+ ### Dependencies added
371
+
372
+ Runtime: `jinja2`, `lxml`, `rapidfuzz`. Optional extra `[gliner]` for GLiNER support. Dev: `pytest`, `pytest-cov`.
373
+
374
+ ### Tests
375
+
376
+ - **63 unit tests** (Layer 1) — fully mocked, run in < 0.1 s via `uv run pytest`
377
+ - **9 integration tests** (Layer 2, no GLiNER) — complex resolver/injector/pipeline scenarios, run via `uv run pytest --override-ini="addopts=" -m integration tests/integration/test_pipeline_e2e.py -k "not real_gliner"`
378
+ - **1 GLiNER integration test** — requires `[gliner]` extra and HuggingFace model download
379
+
380
+ ### Smoke script
381
+
382
+ `scripts/smoke_test_llm.py` — end-to-end test with real LLM calls (no GLiNER). Verified against:
383
+
384
+ - **Google Gemini 2.0 Flash** (`GEMINI_API_KEY` from `.env`)
385
+ - **KISSKI `llama-3.3-70b-instruct`** (`KISSKI_API_KEY` from `.env`, OpenAI-compatible API at `https://chat-ai.academiccloud.de/v1`)
386
+
387
+ Run with `uv run scripts/smoke_test_llm.py`.
388
+
389
+ ### Key implementation notes
390
+
391
+ - The `_strip_existing_tags` / `_restore_existing_tags` pair in `pipeline.py` preserves original markup by tracking plain-text offsets of each stripped tag and re-inserting them after annotation.
392
+ - `_build_nesting_tree` in `injector.py` uses a sort-by-(start-asc, length-desc) + stack algorithm; partial overlaps are dropped with a `warnings.warn`.
393
+ - The resolver does an exact `str.find` first; fuzzy search (sliding-window rapidfuzz) is only attempted if exact fails and rapidfuzz is installed.
394
+ - `parse_response` passes `call_fn` and `make_correction_prompt` only for `TEXT_GENERATION` endpoints; `JSON_ENFORCED` and `EXTRACTION` never retry.
pyproject.toml CHANGED
@@ -1,7 +1,30 @@
  [project]
  name = "tei-annotator"
  version = "0.1.0"
- description = "Add your description here"
+ description = "TEI XML annotation library using LLM pipelines"
  readme = "README.md"
- requires-python = "==3.12"
+ requires-python = ">=3.12"
- dependencies = []
+ dependencies = [
+     "jinja2>=3.1",
+     "lxml>=5.0",
+     "rapidfuzz>=3.0",
+ ]
+
+ [project.optional-dependencies]
+ gliner = ["gliner>=0.2"]
+
+ [tool.pytest.ini_options]
+ addopts = "-m 'not integration'"
+ markers = [
+     "integration: marks tests as integration tests (require GLiNER model download)",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "pytest>=8.0",
+     "pytest-cov>=5.0",
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
scripts/smoke_test_llm.py ADDED
@@ -0,0 +1,235 @@
+ #!/usr/bin/env python
+ """
+ End-to-end smoke test: tei-annotator pipeline with real LLM endpoints.
+
+ Providers tested:
+   • Google Gemini 2.0 Flash
+   • KISSKI (OpenAI-compatible API, llama-3.3-70b-instruct)
+
+ Reads API keys from .env in the project root.
+
+ Usage:
+     uv run scripts/smoke_test_llm.py
+     python scripts/smoke_test_llm.py   # if venv is already activated
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import sys
+ import textwrap
+ import urllib.error
+ import urllib.request
+ from pathlib import Path
+ from typing import Callable
+
+
+ # ---------------------------------------------------------------------------
+ # .env loader (stdlib-only, no python-dotenv needed)
+ # ---------------------------------------------------------------------------
+
+ def _load_env(path: str | Path = ".env") -> None:
+     try:
+         with open(path) as fh:
+             for line in fh:
+                 line = line.strip()
+                 if not line or line.startswith("#") or "=" not in line:
+                     continue
+                 key, _, value = line.partition("=")
+                 value = value.strip().strip('"').strip("'")
+                 os.environ.setdefault(key.strip(), value)
+     except FileNotFoundError:
+         pass
+
+
+ _load_env(Path(__file__).parent.parent / ".env")
+
+
+ # ---------------------------------------------------------------------------
+ # HTTP helper (stdlib urllib)
+ # ---------------------------------------------------------------------------
+
+ def _post_json(url: str, payload: dict, headers: dict) -> dict:
+     body = json.dumps(payload).encode()
+     req = urllib.request.Request(url, data=body, headers=headers, method="POST")
+     try:
+         with urllib.request.urlopen(req, timeout=60) as resp:
+             return json.loads(resp.read())
+     except urllib.error.HTTPError as exc:
+         detail = exc.read().decode(errors="replace")
+         raise RuntimeError(f"HTTP {exc.code} from {url}: {detail}") from exc
+
+
+ # ---------------------------------------------------------------------------
+ # call_fn factories
+ # ---------------------------------------------------------------------------
+
+ def make_gemini_call_fn(api_key: str, model: str = "gemini-2.0-flash") -> Callable[[str], str]:
+     """Return a call_fn that sends a prompt to Gemini and returns the text reply."""
+     url = (
+         f"https://generativelanguage.googleapis.com/v1beta/models"
+         f"/{model}:generateContent?key={api_key}"
+     )
+
+     def call_fn(prompt: str) -> str:
+         payload = {
+             "contents": [{"parts": [{"text": prompt}]}],
+             "generationConfig": {"temperature": 0.1},
+         }
+         result = _post_json(url, payload, {"Content-Type": "application/json"})
+         return result["candidates"][0]["content"]["parts"][0]["text"]
+
+     call_fn.__name__ = f"gemini/{model}"
+     return call_fn
+
+
+ def make_kisski_call_fn(
+     api_key: str,
+     base_url: str = "https://chat-ai.academiccloud.de/v1",
+     model: str = "llama-3.3-70b-instruct",
+ ) -> Callable[[str], str]:
+     """Return a call_fn that sends a prompt to a KISSKI-hosted OpenAI-compatible model."""
+     url = f"{base_url}/chat/completions"
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f"Bearer {api_key}",
+     }
+
+     def call_fn(prompt: str) -> str:
+         payload = {
+             "model": model,
+             "messages": [{"role": "user", "content": prompt}],
+             "temperature": 0.1,
+         }
+         result = _post_json(url, payload, headers)
+         return result["choices"][0]["message"]["content"]
+
+     call_fn.__name__ = f"kisski/{model}"
+     return call_fn
+
+
+ # ---------------------------------------------------------------------------
+ # Test scenario
+ # ---------------------------------------------------------------------------
+
+ TEST_TEXT = (
+     "Marie Curie was born in Warsaw, Poland, and later conducted her research "
+     "in Paris, France. Together with her husband Pierre Curie, she discovered "
+     "polonium and radium."
+ )
+
+ # We just check that the pipeline runs and produces *some* annotation.
+ # Whether the LLM chose the right entities is not asserted here.
+ EXPECTED_TAGS = ["persName", "placeName"]
+
+
+ def _build_schema():
+     from tei_annotator.models.schema import TEIAttribute, TEIElement, TEISchema
+
+     return TEISchema(
+         elements=[
+             TEIElement(
+                 tag="persName",
+                 description="a person's name",
+                 attributes=[TEIAttribute(name="ref", description="authority URI")],
+             ),
+             TEIElement(
+                 tag="placeName",
+                 description="a geographical place name",
+                 attributes=[TEIAttribute(name="ref", description="authority URI")],
+             ),
+         ]
+     )
+
+
+ def run_smoke_test(provider_name: str, call_fn: Callable[[str], str]) -> bool:
+     """
+     Run the full annotate() pipeline with *call_fn* and print results.
+     Returns True on success, False on failure.
+     """
+     import re
+
+     from tei_annotator.inference.endpoint import EndpointCapability, EndpointConfig
+     from tei_annotator.pipeline import annotate
+
+     print(f"\n{'─' * 60}")
+     print(f"  Provider : {provider_name}")
+     print(f"  Input    : {TEST_TEXT[:80]}…")
+     print(f"{'─' * 60}")
+
+     try:
+         result = annotate(
+             text=TEST_TEXT,
+             schema=_build_schema(),
+             endpoint=EndpointConfig(
+                 capability=EndpointCapability.TEXT_GENERATION,
+                 call_fn=call_fn,
+             ),
+             gliner_model=None,  # skip GLiNER for speed
+         )
+     except Exception as exc:
+         print(f"  ✗ FAILED — exception during annotate(): {exc}")
+         return False
+
+     # Verify plain text is unmodified
+     plain = re.sub(r"<[^>]+>", "", result.xml)
+     if plain != TEST_TEXT:
+         print("  ✗ FAILED — plain text was modified by the pipeline")
+         print(f"    Expected : {TEST_TEXT!r}")
+         print(f"    Got      : {plain!r}")
+         return False
+
+     # Verify at least one annotation was injected (LLM must have found something)
+     has_any_tag = any(f"<{t}>" in result.xml for t in EXPECTED_TAGS)
+     if not has_any_tag:
+         print("  ✗ FAILED — no annotation tags found in output")
+         print(f"    Output XML: {result.xml}")
+         return False
+
+     # Pretty-print the result
+     tags_found = [t for t in EXPECTED_TAGS if f"<{t}>" in result.xml]
+     print("  ✓ PASSED")
+     print(f"    Tags found : {', '.join(tags_found)}")
+     if result.fuzzy_spans:
+         print(f"    Fuzzy spans: {[s.text for s in result.fuzzy_spans]}")
+     print("    Output XML :")
+     for line in textwrap.wrap(result.xml, width=72, subsequent_indent="      "):
+         print(f"      {line}")
+     return True
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ def main() -> int:
+     gemini_key = os.environ.get("GEMINI_API_KEY", "")
+     kisski_key = os.environ.get("KISSKI_API_KEY", "")
+
+     if not gemini_key:
+         print("ERROR: GEMINI_API_KEY not set (check .env)", file=sys.stderr)
+         return 1
+     if not kisski_key:
+         print("ERROR: KISSKI_API_KEY not set (check .env)", file=sys.stderr)
+         return 1
+
+     providers: list[tuple[str, Callable[[str], str]]] = [
+         ("Gemini 2.0 Flash", make_gemini_call_fn(gemini_key)),
+         ("KISSKI / llama-3.3-70b-instruct", make_kisski_call_fn(kisski_key)),
+     ]
+
+     results: list[bool] = []
+     for name, fn in providers:
+         results.append(run_smoke_test(name, fn))
+
+     print(f"\n{'═' * 60}")
+     passed = sum(results)
+     total = len(results)
+     print(f"  Result: {passed}/{total} providers passed")
+     print(f"{'═' * 60}\n")
+
+     return 0 if all(results) else 1
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
tei_annotator/__init__.py ADDED
@@ -0,0 +1,20 @@
+ """
+ tei-annotator: TEI XML annotation library using a two-stage LLM pipeline.
+ """
+
+ from .inference.endpoint import EndpointCapability, EndpointConfig
+ from .models.schema import TEIAttribute, TEIElement, TEISchema
+ from .models.spans import ResolvedSpan, SpanDescriptor
+ from .pipeline import AnnotationResult, annotate
+
+ __all__ = [
+     "annotate",
+     "AnnotationResult",
+     "TEISchema",
+     "TEIElement",
+     "TEIAttribute",
+     "SpanDescriptor",
+     "ResolvedSpan",
+     "EndpointConfig",
+     "EndpointCapability",
+ ]
tei_annotator/chunking/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .chunker import Chunk, chunk_text
+
+ __all__ = ["Chunk", "chunk_text"]
tei_annotator/chunking/chunker.py ADDED
@@ -0,0 +1,63 @@
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class Chunk:
+     text: str
+     start_offset: int  # position of text[0] in the original source
+
+
+ def chunk_text(
+     text: str,
+     chunk_size: int = 1500,
+     overlap: int = 200,
+ ) -> list[Chunk]:
+     """
+     Split text into overlapping chunks, never splitting inside an XML tag.
+
+     Each chunk's start_offset satisfies:
+         original_text[chunk.start_offset : chunk.start_offset + len(chunk.text)] == chunk.text
+     """
+     if len(text) <= chunk_size:
+         return [Chunk(text=text, start_offset=0)]
+
+     # Build a set of character positions that are inside XML tags (inclusive).
+     tag_positions: set[int] = set()
+     for m in re.finditer(r"<[^>]*>", text):
+         tag_positions.update(range(m.start(), m.end()))
+
+     chunks: list[Chunk] = []
+     start = 0
+
+     while start < len(text):
+         end = min(start + chunk_size, len(text))
+
+         if end < len(text):
+             # Step back out of any XML tag
+             candidate = end
+             while candidate > start and candidate in tag_positions:
+                 candidate -= 1
+
+             # Try to break at a whitespace boundary near the target
+             break_pos = candidate
+             for i in range(candidate, max(start, candidate - 100), -1):
+                 if i not in tag_positions and text[i].isspace():
+                     break_pos = i + 1
+                     break
+
+             end = max(start + 1, break_pos)  # guarantee forward progress
+
+         chunks.append(Chunk(text=text[start:end], start_offset=start))
+
+         if end >= len(text):
+             break
+
+         next_start = end - overlap
+         if next_start <= start:
+             next_start = end
+         start = next_start
+
+     return chunks
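Ignoring the tag-safe and whitespace-boundary adjustments, the loop above reduces to fixed overlapping windows. The hypothetical helper below (not part of the package) shows how `chunk_size` and `overlap` interact, i.e. why consecutive chunks share 200 characters by default:

```python
def window_offsets(n: int, chunk_size: int = 1500, overlap: int = 200) -> list[tuple[int, int]]:
    """Start/end offsets of overlapping windows over a text of length n,
    mirroring chunk_text's stride: next_start = end - overlap."""
    if n <= chunk_size:
        return [(0, n)]
    offsets: list[tuple[int, int]] = []
    start = 0
    while start < n:
        end = min(start + chunk_size, n)
        offsets.append((start, end))
        if end >= n:
            break
        start = end - overlap  # consecutive windows share `overlap` chars
    return offsets


print(window_offsets(3000))
# [(0, 1500), (1300, 2800), (2600, 3000)]
```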
tei_annotator/detection/__init__.py ADDED
File without changes
tei_annotator/detection/gliner_detector.py ADDED
@@ -0,0 +1,56 @@
+ from __future__ import annotations
+
+ from ..models.schema import TEISchema
+ from ..models.spans import SpanDescriptor
+
+ try:
+     from gliner import GLiNER as _GLiNER
+ except ImportError as _e:
+     raise ImportError(
+         "The 'gliner' package is required for GLiNER detection. "
+         "Install it with: pip install tei-annotator[gliner]"
+     ) from _e
+
+
+ def detect_spans(
+     text: str,
+     schema: TEISchema,
+     model_id: str = "numind/NuNER_Zero",
+ ) -> list[SpanDescriptor]:
+     """
+     Detect entity spans in *text* using a GLiNER model.
+
+     Model weights are fetched from HuggingFace Hub on first use and cached in
+     ~/.cache/huggingface/. All listed models run on CPU; no GPU required.
+
+     Recommended model_id values:
+       - "numind/NuNER_Zero"                        (MIT, default)
+       - "urchade/gliner_medium-v2.1"               (Apache-2.0, balanced)
+       - "knowledgator/gliner-multitask-large-v0.5" (adds relation extraction)
+     """
+     model = _GLiNER.from_pretrained(model_id)
+
+     # Map TEI element descriptions to their tags
+     labels = [elem.description for elem in schema.elements]
+     tag_for_label = {elem.description: elem.tag for elem in schema.elements}
+
+     entities = model.predict_entities(text, labels)
+
+     spans: list[SpanDescriptor] = []
+     for entity in entities:
+         ctx_start = max(0, entity["start"] - 60)
+         ctx_end = min(len(text), entity["end"] + 60)
+         context = text[ctx_start:ctx_end]
+
+         tag = tag_for_label.get(entity["label"], entity["label"])
+         spans.append(
+             SpanDescriptor(
+                 element=tag,
+                 text=entity["text"],
+                 context=context,
+                 attrs={},
+                 confidence=entity.get("score"),
+             )
+         )
+
+     return spans
tei_annotator/inference/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .endpoint import EndpointCapability, EndpointConfig
+
+ __all__ = ["EndpointCapability", "EndpointConfig"]
tei_annotator/inference/endpoint.py ADDED
@@ -0,0 +1,19 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Callable
+
+
+ class EndpointCapability(Enum):
+     TEXT_GENERATION = "text_generation"  # plain LLM, JSON via prompt only
+     JSON_ENFORCED = "json_enforced"      # constrained decoding guaranteed
+     EXTRACTION = "extraction"            # GLiNER2/NuExtract-style native
+
+
+ @dataclass
+ class EndpointConfig:
+     capability: EndpointCapability
+     call_fn: Callable[[str], str]
+     # call_fn signature: takes a prompt string, returns a response string.
+     # Caller is responsible for auth, model selection, and retries.
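Because `EndpointConfig` carries nothing but a plain `Callable[[str], str]`, a test double needs no mocking framework. A sketch of a canned-response endpoint for offline runs follows; note that the `{"spans": [...]}` shape is an assumption here, since the prompt templates define the real response contract:

```python
import json
from typing import Callable


def make_mock_call_fn(spans: list[dict]) -> Callable[[str], str]:
    """Return a call_fn that ignores its prompt and answers with canned spans,
    so a pipeline can be exercised without any network access."""
    response = json.dumps({"spans": spans})

    def call_fn(prompt: str) -> str:
        return response

    return call_fn


fn = make_mock_call_fn([{"element": "persName", "text": "Ada", "context": "Ada wrote"}])
print(json.loads(fn("any prompt"))["spans"][0]["element"])  # persName
```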
tei_annotator/models/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .schema import TEIAttribute, TEIElement, TEISchema
+ from .spans import SpanDescriptor, ResolvedSpan
+
+ __all__ = [
+     "TEIAttribute",
+     "TEIElement",
+     "TEISchema",
+     "SpanDescriptor",
+     "ResolvedSpan",
+ ]
tei_annotator/models/schema.py ADDED
@@ -0,0 +1,30 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class TEIAttribute:
+     name: str
+     description: str
+     required: bool = False
+     allowed_values: list[str] | None = None
+
+
+ @dataclass
+ class TEIElement:
+     tag: str
+     description: str
+     allowed_children: list[str] = field(default_factory=list)
+     attributes: list[TEIAttribute] = field(default_factory=list)
+
+
+ @dataclass
+ class TEISchema:
+     elements: list[TEIElement] = field(default_factory=list)
+
+     def get(self, tag: str) -> TEIElement | None:
+         for elem in self.elements:
+             if elem.tag == tag:
+                 return elem
+         return None
tei_annotator/models/spans.py ADDED
@@ -0,0 +1,24 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class SpanDescriptor:
+     """Flat span emitted by the LLM or GLiNER — always context-anchored, never nested."""
+     element: str
+     text: str
+     context: str  # must contain text as a substring
+     attrs: dict[str, str] = field(default_factory=dict)
+     confidence: float | None = None  # passed through from GLiNER
+
+
+ @dataclass
+ class ResolvedSpan:
+     """Span resolved to absolute char offsets in the source text."""
+     element: str
+     start: int
+     end: int
+     attrs: dict[str, str] = field(default_factory=dict)
+     children: list[ResolvedSpan] = field(default_factory=list)
+     fuzzy_match: bool = False  # flagged for human review
tei_annotator/pipeline.py ADDED
@@ -0,0 +1,278 @@
+ from __future__ import annotations
+
+ import warnings
+ from dataclasses import dataclass, field
+
+ from .chunking.chunker import chunk_text
+ from .inference.endpoint import EndpointCapability, EndpointConfig
+ from .models.schema import TEISchema
+ from .models.spans import ResolvedSpan, SpanDescriptor
+ from .postprocessing.injector import inject_xml
+ from .postprocessing.parser import parse_response
+ from .postprocessing.resolver import resolve_spans
+ from .postprocessing.validator import validate_spans
+ from .prompting.builder import build_prompt, make_correction_prompt
+
+
+ @dataclass
+ class AnnotationResult:
+     xml: str
+     fuzzy_spans: list[ResolvedSpan] = field(default_factory=list)
+
+
+ # ---------------------------------------------------------------------------
+ # Internal helpers
+ # ---------------------------------------------------------------------------
+
+ @dataclass
+ class _TagEntry:
+     plain_offset: int  # position in plain text before which this tag should be re-inserted
+     tag: str
+
+
+ def _strip_existing_tags(text: str) -> tuple[str, list[_TagEntry]]:
+     """
+     Remove XML tags from *text*.
+
+     Returns (plain_text, restore_map) where restore_map records each stripped
+     tag and the plain-text offset at which it should be re-inserted.
+     """
+     plain: list[str] = []
+     restore: list[_TagEntry] = []
+     i = 0
+     while i < len(text):
+         if text[i] == "<":
+             j = text.find(">", i)
+             if j != -1:
+                 restore.append(_TagEntry(plain_offset=len(plain), tag=text[i : j + 1]))
+                 i = j + 1
+             else:
+                 plain.append(text[i])
+                 i += 1
+         else:
+             plain.append(text[i])
+             i += 1
+     return "".join(plain), restore
+
+
+ def _restore_existing_tags(annotated_xml: str, restore_map: list[_TagEntry]) -> str:
+     """
+     Re-insert original XML tags into *annotated_xml*.
+
+     The tags are keyed by their position in the *plain text* (before annotation),
+     so we walk the annotated XML tracking plain-text position (i.e. advancing only
+     on non-tag characters).
+     """
+     if not restore_map:
+         return annotated_xml
+
+     inserts: dict[int, list[str]] = {}
+     for entry in restore_map:
+         inserts.setdefault(entry.plain_offset, []).append(entry.tag)
+
+     result: list[str] = []
+     plain_pos = 0
+     i = 0
+
+     while i < len(annotated_xml):
+         # Flush any original tags due at the current plain position
+         for tag in inserts.pop(plain_pos, []):
+             result.append(tag)
+
+         if annotated_xml[i] == "<":
+             # Existing (newly injected) tag — copy verbatim, don't advance plain_pos
+             j = annotated_xml.find(">", i)
+             if j != -1:
+                 result.append(annotated_xml[i : j + 1])
+                 i = j + 1
+             else:
+                 result.append(annotated_xml[i])
+                 plain_pos += 1
+                 i += 1
+         else:
+             result.append(annotated_xml[i])
+             plain_pos += 1
+             i += 1
+
+     # Flush any remaining original tags (e.g. trailing tags in the original)
+     for pos in sorted(inserts.keys()):
+         for tag in inserts[pos]:
+             result.append(tag)
+
+     return "".join(result)
+
+
+ def _run_gliner(
+     text: str,
+     schema: TEISchema,
+     model_id: str,
+ ) -> list[SpanDescriptor]:
+     """Run GLiNER detection; returns [] if the optional dependency is missing."""
+     try:
+         from .detection.gliner_detector import detect_spans
+
+         return detect_spans(text, schema, model_id)
+     except ImportError:
+         warnings.warn(
+             "gliner is not installed; skipping GLiNER pre-detection pass. "
+             "Install it with: pip install tei-annotator[gliner]",
+             stacklevel=3,
+         )
+         return []
+
+
+ # ---------------------------------------------------------------------------
+ # Public API
+ # ---------------------------------------------------------------------------
+
+
+ def annotate(
+     text: str,
+     schema: TEISchema,
+     endpoint: EndpointConfig,
+     gliner_model: str | None = "numind/NuNER_Zero",
+     chunk_size: int = 1500,
+     chunk_overlap: int = 200,
+ ) -> AnnotationResult:
+     """
+     Annotate *text* with TEI XML tags using a two-stage LLM pipeline.
+
+     The source text is **never modified** — models only contribute tag positions
+     and attribute values. All text in the output comes from the original input.
+
+     Parameters
+     ----------
+     text:
+         Input text, which may already contain partial XML markup.
+     schema:
+         A TEISchema describing the elements (and their attributes) in scope.
+     endpoint:
+         Injected inference dependency (wraps any call_fn: str → str).
+     gliner_model:
+         HuggingFace model ID for the optional GLiNER pre-detection pass.
+         Pass None to disable.
+     chunk_size:
+         Maximum characters per chunk sent to the LLM.
+     chunk_overlap:
+         Characters of overlap between consecutive chunks.
+     """
+     # ------------------------------------------------------------------ #
+     # STEP 1   Strip existing XML tags; save restoration map             #
+     # ------------------------------------------------------------------ #
+     plain_text, restore_map = _strip_existing_tags(text)
+
+     # ------------------------------------------------------------------ #
+     # STEP 2   Optional GLiNER pre-detection pass                        #
+     # ------------------------------------------------------------------ #
+     gliner_candidates: list[SpanDescriptor] = []
+     if (
+         gliner_model is not None
+         and endpoint.capability != EndpointCapability.EXTRACTION
+         and len(plain_text) > 200
+     ):
+         gliner_candidates = _run_gliner(plain_text, schema, gliner_model)
+
+     # ------------------------------------------------------------------ #
+     # STEPS 3–5   Chunk → prompt → infer → postprocess                   #
+     # ------------------------------------------------------------------ #
+     chunks = chunk_text(plain_text, chunk_size=chunk_size, overlap=chunk_overlap)
+     all_resolved: list[ResolvedSpan] = []
+
+     for chunk in chunks:
+         # Narrow GLiNER candidates to those plausibly within this chunk
+         chunk_candidates: list[SpanDescriptor] | None = None
+         if gliner_candidates:
+             chunk_candidates = [
+                 c
+                 for c in gliner_candidates
+                 if c.context and chunk.text.find(c.context[:30]) != -1
+             ] or None
+
+         # 3. Build prompt / raw request
+         if endpoint.capability == EndpointCapability.EXTRACTION:
+             raw_response = endpoint.call_fn(chunk.text)
+         else:
+             prompt = build_prompt(
+                 source_text=chunk.text,
+                 schema=schema,
+                 capability=endpoint.capability,
+                 candidates=chunk_candidates,
+             )
+             raw_response = endpoint.call_fn(prompt)
+
+         # 4. Parse response → SpanDescriptors
+         retry_fn = (
+             endpoint.call_fn
+             if endpoint.capability == EndpointCapability.TEXT_GENERATION
207
+ else None
208
+ )
209
+ correction_fn = (
210
+ make_correction_prompt
211
+ if endpoint.capability == EndpointCapability.TEXT_GENERATION
212
+ else None
213
+ )
214
+ try:
215
+ span_descs = parse_response(
216
+ raw_response,
217
+ call_fn=retry_fn,
218
+ make_correction_prompt=correction_fn,
219
+ )
220
+ except ValueError:
221
+ warnings.warn(
222
+ f"Could not parse LLM response for chunk at offset "
223
+ f"{chunk.start_offset}; skipping chunk.",
224
+ stacklevel=2,
225
+ )
226
+ continue
227
+
228
+ # 5a. Resolve within chunk text → positions relative to chunk
229
+ chunk_resolved = resolve_spans(chunk.text, span_descs)
230
+
231
+ # 5b. Shift to global (plain_text) offsets
232
+ for span in chunk_resolved:
233
+ span.start += chunk.start_offset
234
+ span.end += chunk.start_offset
235
+
236
+ # 5c. Validate against schema
237
+ chunk_resolved = validate_spans(chunk_resolved, schema, plain_text)
238
+
239
+ all_resolved.extend(chunk_resolved)
240
+
241
+ # ------------------------------------------------------------------ #
242
+ # Deduplicate spans that appeared in overlapping chunks #
243
+ # ------------------------------------------------------------------ #
244
+ seen: set[tuple[str, int, int]] = set()
245
+ deduped: list[ResolvedSpan] = []
246
+ for span in all_resolved:
247
+ key = (span.element, span.start, span.end)
248
+ if key not in seen:
249
+ seen.add(key)
250
+ deduped.append(span)
251
+
252
+ # ------------------------------------------------------------------ #
253
+ # STEP 5d Inject XML tags into the plain text #
254
+ # ------------------------------------------------------------------ #
255
+ annotated_text = inject_xml(plain_text, deduped)
256
+
257
+ # ------------------------------------------------------------------ #
258
+ # STEP 5d (cont.) Restore original XML tags #
259
+ # ------------------------------------------------------------------ #
260
+ final_xml = _restore_existing_tags(annotated_text, restore_map)
261
+
262
+ # ------------------------------------------------------------------ #
263
+ # STEP 5e Final XML validation (best-effort) #
264
+ # ------------------------------------------------------------------ #
265
+ try:
266
+ from lxml import etree
267
+
268
+ try:
269
+ etree.fromstring(f"<_root>{final_xml}</_root>".encode())
270
+ except etree.XMLSyntaxError as exc:
271
+ warnings.warn(f"Output XML validation failed: {exc}", stacklevel=2)
272
+ except ImportError:
273
+ pass
274
+
275
+ return AnnotationResult(
276
+ xml=final_xml,
277
+ fuzzy_spans=[s for s in deduped if s.fuzzy_match],
278
+ )
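The deduplication step above (spans found twice because of chunk overlap collapse to one entry keyed by `(element, start, end)`) can be sketched standalone — the `Span` tuple here is a hypothetical stand-in for the library's `ResolvedSpan`, not its real definition:

```python
# Minimal sketch of the overlap-dedup step: first occurrence wins,
# keyed by (element, start, end).
from typing import NamedTuple

class Span(NamedTuple):  # hypothetical stand-in for ResolvedSpan
    element: str
    start: int
    end: int

def dedupe(spans: list[Span]) -> list[Span]:
    seen: set[tuple[str, int, int]] = set()
    out: list[Span] = []
    for s in spans:
        key = (s.element, s.start, s.end)
        if key not in seen:
            seen.add(key)
            out.append(s)
    return out

spans = [Span("persName", 10, 20), Span("persName", 10, 20), Span("placeName", 30, 35)]
deduped = dedupe(spans)
```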
tei_annotator/postprocessing/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .injector import inject_xml
+from .parser import parse_response
+from .resolver import resolve_spans
+from .validator import validate_spans
+
+__all__ = ["inject_xml", "parse_response", "resolve_spans", "validate_spans"]
tei_annotator/postprocessing/injector.py ADDED
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import warnings
+
+from ..models.spans import ResolvedSpan
+
+
+def _build_nesting_tree(flat_spans: list[ResolvedSpan]) -> list[ResolvedSpan]:
+    """
+    Populate ResolvedSpan.children based on offset containment and return root spans.
+
+    Spans are sorted so that outer (longer) spans are processed before inner ones.
+    Overlapping (non-nesting) spans are skipped with a warning.
+    """
+    # Sort: start asc, then end desc so outer spans come before inner at same start
+    spans = sorted(flat_spans, key=lambda s: (s.start, -(s.end - s.start)))
+
+    # Clear any children left from a previous call
+    for s in spans:
+        s.children = []
+
+    roots: list[ResolvedSpan] = []
+    stack: list[ResolvedSpan] = []
+
+    for span in spans:
+        rejected = False
+
+        # Pop stack entries that are fully before (or incompatibly overlap) this span
+        while stack:
+            top = stack[-1]
+            if top.start <= span.start and span.end <= top.end:
+                break  # top properly contains span → it's the parent
+            elif span.start >= top.end:
+                stack.pop()  # span comes after top → pop and continue
+            else:
+                # Partial overlap (neither contained nor after) → reject span
+                warnings.warn(
+                    f"Overlapping spans [{top.start},{top.end}] and "
+                    f"[{span.start},{span.end}] cannot be nested; "
+                    f"skipping <{span.element}> span.",
+                    stacklevel=3,
+                )
+                rejected = True
+                break
+
+        if rejected:
+            continue
+
+        if stack:
+            stack[-1].children.append(span)
+        else:
+            roots.append(span)
+
+        stack.append(span)
+
+    return roots
+
+
+def _inject_recursive(
+    text: str,
+    spans: list[ResolvedSpan],
+    offset: int,
+) -> str:
+    """
+    Insert XML open/close tags for *spans* into *text*.
+
+    *offset* is the absolute position of text[0] in the original source, used
+    to translate span.start/end (absolute) to positions within *text*.
+    """
+    if not spans:
+        return text
+
+    result: list[str] = []
+    cursor = 0  # relative position within text
+
+    for span in sorted(spans, key=lambda s: s.start):
+        rel_start = span.start - offset
+        rel_end = span.end - offset
+
+        # Text before this span
+        result.append(text[cursor:rel_start])
+
+        # Build tag strings
+        attrs_str = " ".join(f'{k}="{v}"' for k, v in span.attrs.items())
+        open_tag = f"<{span.element}" + (f" {attrs_str}" if attrs_str else "") + ">"
+        close_tag = f"</{span.element}>"
+
+        # Recursively inject children inside this span's content
+        inner = text[rel_start:rel_end]
+        if span.children:
+            inner = _inject_recursive(inner, span.children, offset=span.start)
+
+        result.append(open_tag)
+        result.append(inner)
+        result.append(close_tag)
+
+        cursor = rel_end
+
+    result.append(text[cursor:])
+    return "".join(result)
+
+
+def inject_xml(source: str, spans: list[ResolvedSpan]) -> str:
+    """
+    Insert XML tags into *source* at the positions defined by *spans*.
+
+    Nesting is inferred from offset containment via _build_nesting_tree.
+    """
+    if not spans:
+        return source
+    root_spans = _build_nesting_tree(spans)
+    return _inject_recursive(source, root_spans, offset=0)
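The containment rule in `_build_nesting_tree` — sort by `(start, -(length))` so outer spans precede inner ones, then walk a stack, rejecting partial overlaps — can be exercised with bare `(start, end, tag)` tuples. This is a standalone sketch of the same stack logic, not an import of the module:

```python
# Standalone check of the sort-and-stack containment rule: contained spans
# become children, partially overlapping spans are dropped.
def nest(spans):
    # sort: start ascending, longer (outer) spans first at equal start
    spans = sorted(spans, key=lambda s: (s[0], -(s[1] - s[0])))
    roots, stack = [], []
    children = {s[2]: [] for s in spans}
    for span in spans:
        rejected = False
        while stack:
            top = stack[-1]
            if top[0] <= span[0] and span[1] <= top[1]:
                break                 # top contains span → it is the parent
            if span[0] >= top[1]:
                stack.pop()           # span lies entirely after top
            else:
                rejected = True       # partial overlap → cannot nest
                break
        if rejected:
            continue
        (children[stack[-1][2]] if stack else roots).append(span)
        stack.append(span)
    return roots, children

roots, children = nest([
    (3, 13, "persName"),
    (3, 7, "forename"),
    (8, 13, "surname"),
    (10, 16, "bad"),      # overlaps surname without containing it → rejected
])
```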
tei_annotator/postprocessing/parser.py ADDED
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Callable
+
+from ..models.spans import SpanDescriptor
+
+
+def _strip_fences(text: str) -> str:
+    """Remove markdown code fences, even if preceded by explanatory text."""
+    text = text.strip()
+    m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL)
+    if m:
+        return m.group(1).strip()
+    return text
+
+
+def _parse_json_list(text: str) -> list[dict] | None:
+    """Parse text as a JSON list; return None on failure."""
+    try:
+        result = json.loads(text)
+        return result if isinstance(result, list) else None
+    except json.JSONDecodeError:
+        return None
+
+
+def _dicts_to_spans(raw: list[dict]) -> list[SpanDescriptor]:
+    spans: list[SpanDescriptor] = []
+    for item in raw:
+        if not isinstance(item, dict):
+            continue
+        element = item.get("element", "")
+        text = item.get("text", "")
+        context = item.get("context", "")
+        attrs = item.get("attrs", {})
+        if not (element and text and context):
+            continue
+        spans.append(
+            SpanDescriptor(
+                element=element,
+                text=text,
+                context=context,
+                attrs=attrs if isinstance(attrs, dict) else {},
+            )
+        )
+    return spans
+
+
+def parse_response(
+    response: str,
+    call_fn: Callable[[str], str] | None = None,
+    make_correction_prompt: Callable[[str, str], str] | None = None,
+) -> list[SpanDescriptor]:
+    """
+    Parse an LLM response string into a list of SpanDescriptors.
+
+    - Strips markdown code fences automatically.
+    - If parsing fails and *call_fn* + *make_correction_prompt* are provided,
+      retries once with a self-correction prompt that includes the bad response.
+    - Raises ValueError if parsing fails after the retry (or if no retry is configured).
+    """
+    cleaned = _strip_fences(response)
+    raw = _parse_json_list(cleaned)
+    if raw is not None:
+        return _dicts_to_spans(raw)
+
+    if call_fn is None or make_correction_prompt is None:
+        raise ValueError(f"Failed to parse JSON from response: {response[:300]!r}")
+
+    error_msg = "Response is not valid JSON"
+    correction_prompt = make_correction_prompt(response, error_msg)
+    retry_response = call_fn(correction_prompt)
+    retry_cleaned = _strip_fences(retry_response)
+    raw = _parse_json_list(retry_cleaned)
+    if raw is None:
+        raise ValueError(f"Failed to parse JSON after retry: {retry_response[:300]!r}")
+    return _dicts_to_spans(raw)
tei_annotator/postprocessing/resolver.py ADDED
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from ..models.spans import ResolvedSpan, SpanDescriptor
+
+try:
+    from rapidfuzz import fuzz as _fuzz
+
+    _HAS_RAPIDFUZZ = True
+except ImportError:
+    _HAS_RAPIDFUZZ = False
+
+
+def _find_context(
+    source: str,
+    context: str,
+    threshold: float,
+) -> tuple[int, bool] | None:
+    """
+    Locate *context* in *source*.
+
+    Returns (start_pos, is_fuzzy):
+    - (pos, False) on exact match
+    - (pos, True) on fuzzy match with score >= threshold
+    - None if not found or below threshold
+    """
+    pos = source.find(context)
+    if pos != -1:
+        return pos, False
+
+    if not _HAS_RAPIDFUZZ or not context:
+        return None
+
+    win = len(context)
+    if win > len(source):
+        return None
+
+    best_score = 0.0
+    best_pos = -1
+    for i in range(len(source) - win + 1):
+        score = _fuzz.ratio(context, source[i : i + win]) / 100.0
+        if score > best_score:
+            best_score = score
+            best_pos = i
+
+    if best_score >= threshold:
+        return best_pos, True
+    return None
+
+
+def resolve_spans(
+    source: str,
+    spans: list[SpanDescriptor],
+    fuzzy_threshold: float = 0.92,
+) -> list[ResolvedSpan]:
+    """
+    Convert context-anchored SpanDescriptors to char-offset ResolvedSpans.
+
+    Rejects spans whose text cannot be reliably located in *source*.
+    Spans that required fuzzy context matching are flagged with fuzzy_match=True.
+    """
+    resolved: list[ResolvedSpan] = []
+
+    for span in spans:
+        result = _find_context(source, span.context, fuzzy_threshold)
+        if result is None:
+            continue  # context not found → reject
+
+        ctx_start, context_is_fuzzy = result
+
+        # Find span.text within the located context window
+        window = source[ctx_start : ctx_start + len(span.context)]
+        text_pos = window.find(span.text)
+        if text_pos == -1:
+            continue  # text not in context window → reject
+
+        abs_start = ctx_start + text_pos
+        abs_end = abs_start + len(span.text)
+
+        # Verify verbatim match (should always hold after exact context find,
+        # but important guard after fuzzy context find)
+        if source[abs_start:abs_end] != span.text:
+            continue
+
+        resolved.append(
+            ResolvedSpan(
+                element=span.element,
+                start=abs_start,
+                end=abs_end,
+                attrs=span.attrs.copy(),
+                children=[],
+                fuzzy_match=context_is_fuzzy,
+            )
+        )
+
+    return resolved
tei_annotator/postprocessing/validator.py ADDED
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from ..models.schema import TEISchema
+from ..models.spans import ResolvedSpan
+
+
+def validate_spans(
+    spans: list[ResolvedSpan],
+    schema: TEISchema,
+    source: str,
+) -> list[ResolvedSpan]:
+    """
+    Filter out spans that fail schema validation.
+
+    Rejected when:
+    - element is not in the schema
+    - an attribute name is not listed for that element
+    - an attribute value is not in the element's allowed_values (when constrained)
+    - span bounds are out of range
+    """
+    valid: list[ResolvedSpan] = []
+
+    for span in spans:
+        # Bounds sanity check
+        if span.start < 0 or span.end > len(source) or span.start >= span.end:
+            continue
+
+        elem = schema.get(span.element)
+        if elem is None:
+            continue  # element not in schema
+
+        allowed_names = {a.name for a in elem.attributes}
+        attr_ok = True
+        for attr_name, attr_value in span.attrs.items():
+            if attr_name not in allowed_names:
+                attr_ok = False
+                break
+            attr_def = next((a for a in elem.attributes if a.name == attr_name), None)
+            if attr_def and attr_def.allowed_values is not None:
+                if attr_value not in attr_def.allowed_values:
+                    attr_ok = False
+                    break
+
+        if not attr_ok:
+            continue
+
+        valid.append(span)
+
+    return valid
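The attribute checks above reduce to a small decision table. A standalone sketch with the schema flattened to a plain dict (`tag → {attr_name: allowed_values_or_None}` — an illustrative shape, not the library's `TEISchema`):

```python
# Standalone sketch of the attribute validation rules in validate_spans.
def attrs_ok(element: str, attrs: dict, schema: dict) -> bool:
    allowed = schema.get(element)
    if allowed is None:
        return False                       # element not in schema
    for name, value in attrs.items():
        if name not in allowed:
            return False                   # unknown attribute name
        values = allowed[name]
        if values is not None and value not in values:
            return False                   # value outside allowed_values
    return True

schema = {"persName": {"ref": None, "cert": ["high", "low"]}}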
tei_annotator/prompting/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .builder import build_prompt, make_correction_prompt
+
+__all__ = ["build_prompt", "make_correction_prompt"]
tei_annotator/prompting/builder.py ADDED
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from ..inference.endpoint import EndpointCapability
+from ..models.schema import TEISchema
+from ..models.spans import SpanDescriptor
+
+try:
+    from jinja2 import Environment, FileSystemLoader
+
+    _HAS_JINJA = True
+except ImportError:
+    _HAS_JINJA = False
+
+_TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+
+def _get_env() -> "Environment":
+    if not _HAS_JINJA:
+        raise ImportError(
+            "jinja2 is required for prompt building. Install it with: pip install jinja2"
+        )
+    env = Environment(loader=FileSystemLoader(str(_TEMPLATE_DIR)), keep_trailing_newline=True)
+    env.filters["tojson"] = lambda x, **kw: json.dumps(x, ensure_ascii=False, **kw)
+    return env
+
+
+def build_prompt(
+    source_text: str,
+    schema: TEISchema,
+    capability: EndpointCapability,
+    candidates: list[SpanDescriptor] | None = None,
+) -> str:
+    """
+    Build an LLM prompt for the given endpoint capability.
+
+    Raises ValueError for EXTRACTION endpoints (they don't use text prompts).
+    """
+    if capability == EndpointCapability.EXTRACTION:
+        raise ValueError(
+            "EXTRACTION endpoints use their own native format; no text prompt needed."
+        )
+
+    env = _get_env()
+    template_name = (
+        "text_gen.jinja2"
+        if capability == EndpointCapability.TEXT_GENERATION
+        else "json_enforced.jinja2"
+    )
+    template = env.get_template(template_name)
+
+    candidate_dicts: list[dict] | None = None
+    if candidates:
+        candidate_dicts = [
+            {
+                "element": c.element,
+                "text": c.text,
+                "context": c.context,
+                "attrs": c.attrs,
+                **({"confidence": c.confidence} if c.confidence is not None else {}),
+            }
+            for c in candidates
+        ]
+
+    return template.render(
+        schema=schema,
+        source_text=source_text,
+        candidates=candidate_dicts,
+    )
+
+
+def make_correction_prompt(original_response: str, error_message: str) -> str:
+    """Build a self-correction retry prompt that includes the bad response and the error."""
+    return (
+        "Your previous response could not be parsed as JSON.\n"
+        f"Error: {error_message}\n\n"
+        f"Your previous response was:\n{original_response}\n\n"
+        "Please fix the JSON and return only a valid JSON array of span objects. "
+        "Do not include any markdown formatting or explanation."
+    )
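To see roughly what the rendered schema section looks like without loading the Jinja templates, here is a jinja-free approximation using plain f-strings — the exact wording is illustrative, not the templates' output verbatim:

```python
# Jinja-free approximation of the "## TEI Schema" section the builder renders.
def schema_section(elements: list[dict]) -> str:
    lines = ["## TEI Schema"]
    for elem in elements:
        attrs = ", ".join(f"`{a}`" for a in elem.get("attributes", []))
        suffix = f" (attributes: {attrs})" if attrs else ""
        lines.append(f"- `{elem['tag']}`: {elem['description']}{suffix}")
    return "\n".join(lines)

section = schema_section([
    {"tag": "persName", "description": "a person's name", "attributes": ["ref", "cert"]},
    {"tag": "placeName", "description": "a place name"},
])
```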
tei_annotator/prompting/templates/json_enforced.jinja2 ADDED
@@ -0,0 +1,21 @@
+You are a TEI XML annotation assistant.
+
+## TEI Schema
+{% for elem in schema.elements %}
+- `{{ elem.tag }}`: {{ elem.description }}{% if elem.attributes %} (attributes: {% for attr in elem.attributes %}`{{ attr.name }}`{% if not loop.last %}, {% endif %}{% endfor %}){% endif %}
+{% endfor %}
+
+## Source Text
+
+```
+{{ source_text }}
+```
+{% if candidates %}
+
+## Pre-detected Candidates (verify and extend)
+
+{{ candidates | tojson }}
+{% endif %}
+
+Return a JSON array. Each item must have: `element`, `text`, `context`, `attrs`.
+One entry per occurrence. `text` and `context` must be exact substrings of the source text.
tei_annotator/prompting/templates/text_gen.jinja2 ADDED
@@ -0,0 +1,53 @@
+You are a TEI XML annotation assistant. Your task is to identify named entities and spans in the source text and annotate them with TEI XML tags.
+
+## TEI Schema
+
+The following TEI elements are in scope:
+{% for elem in schema.elements %}
+### `{{ elem.tag }}`
+{{ elem.description }}
+{% if elem.attributes %}
+Attributes:
+{% for attr in elem.attributes %}
+- `{{ attr.name }}`{% if attr.required %} *(required)*{% endif %}: {{ attr.description }}{% if attr.allowed_values %} — allowed values: `{{ attr.allowed_values | join("`, `") }}`{% endif %}
+{% endfor %}
+{% endif %}
+{% endfor %}
+
+## Source Text
+
+```
+{{ source_text }}
+```
+{% if candidates %}
+
+## Pre-detected Candidates
+
+The following spans were pre-detected by a fast model. Use them as hints — verify each one, correct any errors, and add any entities the detector missed:
+
+{{ candidates | tojson }}
+{% endif %}
+
+## Instructions
+
+Identify all occurrences of entities described in the schema above in the source text.
+
+Return a **JSON array** where each item is an object with:
+- `"element"`: the TEI tag name (e.g. `"persName"`)
+- `"text"`: the exact text span to annotate — must appear verbatim in the source text
+- `"context"`: a substring of the source text (50–150 characters) that contains `"text"` as a substring
+- `"attrs"`: a JSON object with attribute name → value pairs (use `{}` if no attributes needed)
+
+Rules:
+- Emit one entry per **occurrence**, not per unique entity
+- `"text"` must be an exact substring of the source text
+- `"context"` must be an exact substring of the source text and must contain `"text"`
+- Do not modify the source text in any way
+
+Output **only** the JSON array. Do not include markdown fences, explanations, or any other text.
+
+Example output:
+[
+  {"element": "persName", "text": "John Smith", "context": "He said John Smith yesterday.", "attrs": {}},
+  {"element": "placeName", "text": "Paris", "context": "traveled to Paris in 1920", "attrs": {"ref": "https://www.wikidata.org/wiki/Q90"}}
+]
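The contract the template states — `text` and `context` must be exact substrings of the source, and `context` must contain `text` — is cheap to check on the consumer side. A minimal sketch (the helper name is illustrative):

```python
# Quick check of the span contract the prompt asks the model to honor.
def span_well_formed(span: dict, source: str) -> bool:
    text, context = span.get("text", ""), span.get("context", "")
    return bool(text) and text in context and context in source

source = "He said John Smith yesterday."
good = {"element": "persName", "text": "John Smith", "context": "said John Smith yesterday."}
bad = {"element": "persName", "text": "Jon Smith", "context": "said John Smith yesterday."}
```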
tests/__init__.py ADDED
File without changes
tests/integration/__init__.py ADDED
File without changes
tests/integration/test_gliner_detector.py ADDED
@@ -0,0 +1,59 @@
+"""
+Integration tests for GLiNER detection.
+
+These tests download a real HuggingFace model (~400 MB) on first run.
+Run with: pytest -m integration
+"""
+
+import pytest
+
+pytestmark = pytest.mark.integration
+
+
+def test_gliner_detects_person_name():
+    from tei_annotator.detection.gliner_detector import detect_spans
+    from tei_annotator.models.schema import TEIElement, TEISchema
+
+    schema = TEISchema(
+        elements=[
+            TEIElement(tag="persName", description="a person's name", attributes=[]),
+        ]
+    )
+    text = "Albert Einstein was born in Ulm in 1879."
+    spans = detect_spans(text, schema, model_id="numind/NuNER_Zero")
+    assert any(s.element == "persName" and "Einstein" in s.text for s in spans), (
+        f"Expected a persName span containing 'Einstein'; got: {spans}"
+    )
+
+
+def test_gliner_confidence_scores_present():
+    from tei_annotator.detection.gliner_detector import detect_spans
+    from tei_annotator.models.schema import TEIElement, TEISchema
+
+    schema = TEISchema(
+        elements=[
+            TEIElement(tag="persName", description="a person's name", attributes=[]),
+        ]
+    )
+    text = "Marie Curie discovered polonium."
+    spans = detect_spans(text, schema, model_id="numind/NuNER_Zero")
+    for span in spans:
+        if span.confidence is not None:
+            assert 0.0 <= span.confidence <= 1.0
+
+
+def test_gliner_context_contains_text():
+    from tei_annotator.detection.gliner_detector import detect_spans
+    from tei_annotator.models.schema import TEIElement, TEISchema
+
+    schema = TEISchema(
+        elements=[
+            TEIElement(tag="persName", description="a person's name", attributes=[]),
+        ]
+    )
+    text = "Charles Darwin published On the Origin of Species."
+    spans = detect_spans(text, schema, model_id="numind/NuNER_Zero")
+    for span in spans:
+        assert span.text in span.context, (
+            f"span.text {span.text!r} not found in context {span.context!r}"
+        )
tests/integration/test_pipeline_e2e.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-end integration tests: real GLiNER model + mocked call_fn.
3
+
4
+ Tests that only use mocked call_fn (gliner_model=None) are also here because
5
+ they exercise the full pipeline with non-trivial context resolution scenarios.
6
+
7
+ Run with: pytest -m integration
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import re
14
+
15
+ import pytest
16
+
17
+ pytestmark = pytest.mark.integration
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Helpers
21
+ # ---------------------------------------------------------------------------
22
+
23
+
24
+ def _strip_tags(xml: str) -> str:
25
+ return re.sub(r"<[^>]+>", "", xml)
26
+
27
+
28
+ def _schema(*tags: tuple[str, str]):
29
+ """Build a TEISchema from (tag, description) pairs."""
30
+ from tei_annotator.models.schema import TEIAttribute, TEIElement, TEISchema
31
+
32
+ elements = []
33
+ for tag, desc in tags:
34
+ if tag == "persName":
35
+ elements.append(
36
+ TEIElement(
37
+ tag="persName",
38
+ description=desc,
39
+ attributes=[
40
+ TEIAttribute(name="ref", description="URI reference"),
41
+ TEIAttribute(name="cert", description="certainty", allowed_values=["high", "low"]),
42
+ ],
43
+ )
44
+ )
45
+ else:
46
+ elements.append(TEIElement(tag=tag, description=desc, attributes=[]))
47
+ return TEISchema(elements=elements)
48
+
49
+
50
+ def _endpoint(call_fn, capability="json_enforced"):
51
+ from tei_annotator.inference.endpoint import EndpointCapability, EndpointConfig
52
+
53
+ cap = {
54
+ "json_enforced": EndpointCapability.JSON_ENFORCED,
55
+ "text_generation": EndpointCapability.TEXT_GENERATION,
56
+ }[capability]
57
+ return EndpointConfig(capability=cap, call_fn=call_fn)
58
+
59
+
60
+ def _annotate(text, schema, call_fn, capability="json_enforced", gliner_model=None, **kw):
61
+ from tei_annotator.pipeline import annotate
62
+
63
+ return annotate(
64
+ text=text,
65
+ schema=schema,
66
+ endpoint=_endpoint(call_fn, capability),
67
+ gliner_model=gliner_model,
68
+ **kw,
69
+ )
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # 1. Exact context longer than span text
74
+ # ---------------------------------------------------------------------------
75
+
76
+
77
+ def test_context_longer_than_span_text():
78
+ """Resolver must locate span.text inside a longer context window."""
79
+ source = "The treaty was signed by Cardinal Richelieu in Paris."
80
+ schema = _schema(("persName", "a person's name"), ("placeName", "a place name"))
81
+
82
+ def call_fn(_):
83
+ return json.dumps([
84
+ {
85
+ "element": "persName",
86
+ "text": "Cardinal Richelieu",
87
+ "context": "was signed by Cardinal Richelieu in Paris",
88
+ "attrs": {},
89
+ },
90
+ {
91
+ "element": "placeName",
92
+ "text": "Paris",
93
+ "context": "Cardinal Richelieu in Paris.",
94
+ "attrs": {},
95
+ },
96
+ ])
97
+
98
+ result = _annotate(source, schema, call_fn)
99
+ assert "<persName>Cardinal Richelieu</persName>" in result.xml
100
+ assert "<placeName>Paris</placeName>" in result.xml
101
+ assert _strip_tags(result.xml) == source
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # 2. Same span text appears twice — context disambiguates
106
+ # ---------------------------------------------------------------------------
107
+
108
+
109
+ def test_multiple_occurrences_disambiguated_by_context():
110
+ """
111
+ 'John Smith' appears twice. LLM returns two spans with distinct contexts
112
+ pointing at each occurrence. Both must be annotated at the correct offset.
113
+ """
114
+ source = "John Smith arrived early. Later, John Smith left."
115
+ schema = _schema(("persName", "a person's name"))
116
+
117
+ def call_fn(_):
118
+ return json.dumps([
119
+ {
120
+ "element": "persName",
121
+ "text": "John Smith",
122
+ "context": "John Smith arrived early.",
123
+ "attrs": {},
124
+ },
125
+ {
126
+ "element": "persName",
127
+ "text": "John Smith",
128
+ "context": "Later, John Smith left.",
129
+ "attrs": {},
130
+ },
131
+ ])
132
+
133
+ result = _annotate(source, schema, call_fn)
134
+ assert result.xml.count("<persName>") == 2
135
+ assert result.xml.count("John Smith") == 2
136
+ assert _strip_tags(result.xml) == source
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # 3. Long text requiring chunking — global offset calculation
141
+ # ---------------------------------------------------------------------------
142
+
143
+
144
+ def test_long_text_entity_in_second_chunk():
145
+ """
146
+ Entity is far into a long text; its LLM context is relative to a later chunk.
147
+ Offset must be shifted by chunk.start_offset to land at the correct global position.
148
+ """
149
+ # Build a ~2500-char text; entity sits well past the first 1500-char chunk
150
+ filler = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 25 # ~1425 chars
+ target_sentence = "Napoleon Bonaparte was exiled to Saint Helena."
+ source = filler + target_sentence
+
+ schema = _schema(
+ ("persName", "a person's name"),
+ ("placeName", "a place name"),
+ )
+
+ def call_fn(prompt):
+ # The LLM sees either a filler chunk (returns []) or the chunk containing
+ # the target sentence and returns the two spans.
+ if "Napoleon" not in prompt:
+ return "[]"
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "Napoleon Bonaparte",
+ "context": "Napoleon Bonaparte was exiled to Saint Helena.",
+ "attrs": {},
+ },
+ {
+ "element": "placeName",
+ "text": "Saint Helena",
+ "context": "was exiled to Saint Helena.",
+ "attrs": {},
+ },
+ ])
+
+ result = _annotate(source, schema, call_fn, chunk_size=1500, chunk_overlap=200)
+
+ assert "<persName>Napoleon Bonaparte</persName>" in result.xml
+ assert "<placeName>Saint Helena</placeName>" in result.xml
+ assert _strip_tags(result.xml) == source
+
+ # Verify the annotated positions are truly within the target sentence
+ napoleon_start = result.xml.index("Napoleon Bonaparte")
+ assert napoleon_start > 1400, (
+ f"Napoleon offset {napoleon_start} is too early — chunk offset was not applied"
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # 4. Nested spans resolved end-to-end
+ # ---------------------------------------------------------------------------
+
+
+ def test_nested_spans_end_to_end():
+ """
+ LLM emits an outer persName and inner forename / surname spans.
+ Both are resolved separately and then nested by the injector.
+ """
+ source = "He met John Smith today."
+ schema = _schema(
+ ("persName", "a person's full name"),
+ ("forename", "a forename"),
+ ("surname", "a surname"),
+ )
+
+ def call_fn(_):
+ return json.dumps([
+ {"element": "persName", "text": "John Smith", "context": "met John Smith today.", "attrs": {}},
+ {"element": "forename", "text": "John", "context": "met John Smith today.", "attrs": {}},
+ {"element": "surname", "text": "Smith", "context": "John Smith today.", "attrs": {}},
+ ])
+
+ result = _annotate(source, schema, call_fn)
+
+ assert "<persName>" in result.xml
+ assert "<forename>" in result.xml
+ assert "<surname>" in result.xml
+ # forename and surname must be inside persName
+ p_open = result.xml.index("<persName>")
+ p_close = result.xml.index("</persName>")
+ fn_open = result.xml.index("<forename>")
+ sn_close = result.xml.index("</surname>")
+ assert p_open < fn_open < sn_close < p_close
+ assert _strip_tags(result.xml) == source
+
+
+ # ---------------------------------------------------------------------------
+ # 5. Pre-existing XML preserved after annotation
+ # ---------------------------------------------------------------------------
+
+
+ def test_preexisting_xml_preserved():
+ """
+ Source already has markup (<note> tags). After annotation the original
+ markup must still be present alongside the new TEI annotations.
+ """
+ source = "He met <note>allegedly</note> John Smith yesterday."
+ schema = _schema(("persName", "a person's name"))
+
+ def call_fn(_):
+ # The LLM sees stripped plain text: "He met allegedly John Smith yesterday."
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "John Smith",
+ "context": "allegedly John Smith yesterday.",
+ "attrs": {},
+ }
+ ])
+
+ result = _annotate(source, schema, call_fn)
+
+ assert "<note>" in result.xml
+ assert "</note>" in result.xml
+ assert "<persName>John Smith</persName>" in result.xml
+ # Plain text must be unchanged
+ assert _strip_tags(result.xml) == _strip_tags(source)
+
+
+ # ---------------------------------------------------------------------------
+ # 6. Attributes preserved end-to-end
+ # ---------------------------------------------------------------------------
+
+
+ def test_attributes_preserved_end_to_end():
+ """Attribute values returned by the LLM must appear verbatim in the output tag."""
+ source = "The emperor Napoleon was defeated at Waterloo."
+ schema = _schema(("persName", "a person's name"))
+
+ def call_fn(_):
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "Napoleon",
+ "context": "emperor Napoleon was defeated",
+ "attrs": {"ref": "http://viaf.org/viaf/106964661", "cert": "high"},
+ }
+ ])
+
+ result = _annotate(source, schema, call_fn)
+ assert 'ref="http://viaf.org/viaf/106964661"' in result.xml
+ assert 'cert="high"' in result.xml
+ assert _strip_tags(result.xml) == source
+
+
+ # ---------------------------------------------------------------------------
+ # 7. Hallucinated context → span silently rejected
+ # ---------------------------------------------------------------------------
+
+
+ def test_hallucinated_context_span_rejected():
+ """
+ LLM returns a plausible-looking but non-existent context.
+ The resolver must reject the span; the source text is returned unmodified.
+ """
+ source = "Marie Curie discovered polonium."
+ schema = _schema(("persName", "a person's name"))
+
+ def call_fn(_):
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "Marie Curie",
+ "context": "Dr. Marie Curie discovered polonium", # "Dr. " not in source
+ "attrs": {},
+ }
+ ])
+
+ result = _annotate(source, schema, call_fn)
+ assert "<persName>" not in result.xml
+ assert result.xml == source
+
+
+ # ---------------------------------------------------------------------------
+ # 8. Fuzzy context match → span annotated and flagged
+ # ---------------------------------------------------------------------------
+
+
+ def test_fuzzy_context_match_flags_span():
+ """
+ A context with a single-character typo should still resolve via fuzzy
+ matching (score > 0.92) and be included with fuzzy_match=True.
+ """
+ source = "Galileo Galilei observed the moons of Jupiter."
+ schema = _schema(("persName", "a person's name"))
+
+ def call_fn(_):
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "Galileo Galilei",
+ # One character different from the source — should trigger fuzzy
+ "context": "Galileo Galilei observd the moons of Jupiter.",
+ "attrs": {},
+ }
+ ])
+
+ result = _annotate(source, schema, call_fn)
+ # The span should still be annotated
+ assert "<persName>Galileo Galilei</persName>" in result.xml
+ # And flagged as fuzzy
+ assert len(result.fuzzy_spans) == 1
+ assert result.fuzzy_spans[0].element == "persName"
+ assert _strip_tags(result.xml) == source
+
+
+ # ---------------------------------------------------------------------------
+ # 9. Source text never modified (plain-text invariant)
+ # ---------------------------------------------------------------------------
+
+
+ def test_plain_text_invariant_with_multiple_entities():
+ """Stripping all tags from the output must yield exactly the input text."""
+ source = (
+ "Leonardo da Vinci was born in Vinci, Tuscany, "
+ "and later worked in Milan and Florence."
+ )
+ schema = _schema(
+ ("persName", "a person's name"),
+ ("placeName", "a place name"),
+ )
+
+ def call_fn(_):
+ return json.dumps([
+ {"element": "persName", "text": "Leonardo da Vinci",
+ "context": "Leonardo da Vinci was born in Vinci", "attrs": {}},
+ {"element": "placeName", "text": "Vinci",
+ "context": "born in Vinci, Tuscany", "attrs": {}},
+ {"element": "placeName", "text": "Tuscany",
+ "context": "Vinci, Tuscany, and later", "attrs": {}},
+ {"element": "placeName", "text": "Milan",
+ "context": "later worked in Milan and Florence", "attrs": {}},
+ {"element": "placeName", "text": "Florence",
+ "context": "in Milan and Florence.", "attrs": {}},
+ ])
+
+ result = _annotate(source, schema, call_fn)
+ assert _strip_tags(result.xml) == source
+ assert result.xml.count("<placeName>") == 4
+ assert "<persName>Leonardo da Vinci</persName>" in result.xml
+
+
+ # ---------------------------------------------------------------------------
+ # 10. Real GLiNER model (requires HuggingFace download)
+ # ---------------------------------------------------------------------------
+
+
+ def test_pipeline_with_real_gliner():
+ """Full pipeline: real GLiNER pre-detection + mocked LLM call_fn."""
+ schema = _schema(("persName", "a person's name"))
+
+ def mock_llm(_: str) -> str:
+ return json.dumps([
+ {
+ "element": "persName",
+ "text": "Albert Einstein",
+ "context": "Albert Einstein was born",
+ "attrs": {},
+ }
+ ])
+
+ result = _annotate(
+ "Albert Einstein was born in Ulm in 1879.",
+ schema,
+ mock_llm,
+ gliner_model="numind/NuNER_Zero",
+ )
+ assert "persName" in result.xml
+ assert "Albert Einstein" in result.xml
+ assert result.xml.count("Albert Einstein") == 1
tests/test_builder.py ADDED
@@ -0,0 +1,83 @@
+ import pytest
+
+ from tei_annotator.inference.endpoint import EndpointCapability
+ from tei_annotator.models.schema import TEIElement, TEISchema
+ from tei_annotator.models.spans import SpanDescriptor
+ from tei_annotator.prompting.builder import build_prompt, make_correction_prompt
+
+
+ def _schema():
+ return TEISchema(
+ elements=[
+ TEIElement(tag="persName", description="a person's name", attributes=[]),
+ TEIElement(tag="placeName", description="a place name", attributes=[]),
+ ]
+ )
+
+
+ def test_text_gen_prompt_contains_json_instruction():
+ prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
+ assert "JSON" in prompt or "json" in prompt
+
+
+ def test_text_gen_prompt_contains_example():
+ prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
+ # The template shows an example output array
+ assert "persName" in prompt or "element" in prompt
+
+
+ def test_text_gen_prompt_contains_schema_elements():
+ prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
+ assert "persName" in prompt
+ assert "placeName" in prompt
+
+
+ def test_text_gen_prompt_contains_source_text():
+ prompt = build_prompt("unique_source_42", _schema(), EndpointCapability.TEXT_GENERATION)
+ assert "unique_source_42" in prompt
+
+
+ def test_json_enforced_prompt_contains_schema():
+ prompt = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED)
+ assert "persName" in prompt
+ assert "placeName" in prompt
+
+
+ def test_json_enforced_prompt_shorter_than_text_gen():
+ text_gen = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION)
+ json_enf = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED)
+ assert len(json_enf) < len(text_gen)
+
+
+ def test_candidates_appear_in_prompt():
+ candidates = [
+ SpanDescriptor(element="persName", text="John", context="said John went", attrs={})
+ ]
+ prompt = build_prompt(
+ "said John went.",
+ _schema(),
+ EndpointCapability.TEXT_GENERATION,
+ candidates=candidates,
+ )
+ assert "John" in prompt
+
+
+ def test_no_candidate_section_when_none():
+ prompt = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=None)
+ assert "Pre-detected" not in prompt
+
+
+ def test_empty_candidates_list_no_section():
+ prompt = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=[])
+ assert "Pre-detected" not in prompt
+
+
+ def test_extraction_raises():
+ with pytest.raises(ValueError):
+ build_prompt("text", _schema(), EndpointCapability.EXTRACTION)
+
+
+ def test_correction_prompt_contains_original_response():
+ prompt = make_correction_prompt("bad_json_here", "JSONDecodeError")
+ assert "bad_json_here" in prompt
+ assert "JSONDecodeError" in prompt
tests/test_chunker.py ADDED
@@ -0,0 +1,79 @@
+ from tei_annotator.chunking.chunker import Chunk, chunk_text
+
+
+ def test_short_text_single_chunk():
+ text = "Short text."
+ chunks = chunk_text(text, chunk_size=1500)
+ assert len(chunks) == 1
+ assert chunks[0].text == text
+ assert chunks[0].start_offset == 0
+
+
+ def test_long_text_multiple_chunks():
+ text = "word " * 400 # 2000 chars
+ chunks = chunk_text(text, chunk_size=500, overlap=50)
+ assert len(chunks) > 1
+ for i, chunk in enumerate(chunks):
+ assert chunk.start_offset >= 0
+ if i > 0:
+ assert chunk.start_offset > chunks[i - 1].start_offset
+
+
+ def test_chunk_start_offsets_correct():
+ """Every chunk's text must match a slice of the original at start_offset."""
+ text = "hello world " * 200
+ chunks = chunk_text(text, chunk_size=300, overlap=50)
+ for chunk in chunks:
+ assert (
+ text[chunk.start_offset : chunk.start_offset + len(chunk.text)]
+ == chunk.text
+ )
+
+
+ def test_long_text_covers_all_characters():
+ """Union of all chunk ranges must cover the entire source text."""
+ text = "abcdefghij" * 200 # 2000 chars
+ chunks = chunk_text(text, chunk_size=400, overlap=80)
+ covered: set[int] = set()
+ for chunk in chunks:
+ for j in range(chunk.start_offset, chunk.start_offset + len(chunk.text)):
+ covered.add(j)
+ assert covered == set(range(len(text)))
+
+
+ def test_chunk_boundary_does_not_split_xml_tag():
+ """A chunk boundary must never fall inside an XML tag."""
+ # Place a tag that straddles the natural 500-char boundary
+ prefix = "a" * 495
+ tag = "<someElement>"
+ suffix = "b" * 600
+ text = prefix + tag + suffix
+
+ chunks = chunk_text(text, chunk_size=500, overlap=0)
+
+ for chunk in chunks:
+ # Each chunk must be self-consistent XML-tag-wise:
+ # count of '<' must equal count of '>' within the chunk text
+ # (a split tag would have an unbalanced '<' or '>')
+ assert chunk.text.count("<") == chunk.text.count(">"), (
+ f"Chunk at offset {chunk.start_offset} has unbalanced angle brackets: "
+ f"{chunk.text!r}"
+ )
+
+
+ def test_exact_chunk_size_no_overflow():
+ text = "x" * 1500
+ chunks = chunk_text(text, chunk_size=1500, overlap=0)
+ assert len(chunks) == 1
+ assert chunks[0].text == text
+
+
+ def test_overlap_produces_repeated_content():
+ """With positive overlap, the end of chunk N overlaps with the start of chunk N+1."""
+ text = "word " * 300 # 1500 chars
+ chunks = chunk_text(text, chunk_size=500, overlap=100)
+ assert len(chunks) >= 2
+ # The end of chunk 0 and the start of chunk 1 must share content
+ c0_end = chunks[0].start_offset + len(chunks[0].text)
+ c1_start = chunks[1].start_offset
+ assert c1_start < c0_end, "Expected overlapping content between consecutive chunks"
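Taken together, these chunker tests pin down a clear contract: accurate `start_offset`s, full character coverage, configurable overlap, and no boundary inside an XML tag. A minimal sketch that would satisfy them — not the library's actual implementation, and assuming tags are shorter than `chunk_size` — looks like this:

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    text: str
    start_offset: int


def chunk_text(text: str, chunk_size: int = 1500, overlap: int = 0) -> list[Chunk]:
    chunks: list[Chunk] = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        # Don't cut inside an XML tag: if the window ends after an unclosed
        # '<', pull the boundary back to just before that tag.
        lt = text.rfind("<", start, end)
        if end < len(text) and lt > start and text.find(">", lt, end) == -1:
            end = lt
        chunks.append(Chunk(text=text[start:end], start_offset=start))
        if end >= len(text):
            break
        # Step forward, keeping `overlap` chars; always make progress
        start = max(end - overlap, start + 1)
    return chunks
```

Each chunk records the offset of its first character, which is what lets spans resolved inside a chunk be mapped back to absolute positions in the source (the invariant checked by the "chunk offset was not applied" assertion in the e2e test).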
tests/test_injector.py ADDED
@@ -0,0 +1,102 @@
+ import pytest
+
+ from tei_annotator.models.spans import ResolvedSpan
+ from tei_annotator.postprocessing.injector import _build_nesting_tree, inject_xml
+
+
+ def _span(element, start, end, attrs=None):
+ return ResolvedSpan(element=element, start=start, end=end, attrs=attrs or {})
+
+
+ def test_no_spans_returns_source():
+ assert inject_xml("hello world", []) == "hello world"
+
+
+ def test_single_span():
+ source = "He said John Smith yesterday."
+ # "John Smith" = [8:18]
+ span = _span("persName", 8, 18)
+ result = inject_xml(source, [span])
+ assert result == "He said <persName>John Smith</persName> yesterday."
+
+
+ def test_two_non_overlapping_spans():
+ source = "John met Mary."
+ # "John" = [0:4], "Mary" = [9:13]
+ spans = [_span("persName", 0, 4), _span("persName", 9, 13)]
+ result = inject_xml(source, spans)
+ assert result == "<persName>John</persName> met <persName>Mary</persName>."
+
+
+ def test_nested_spans():
+ # "Dr. Smith" = outer, "Dr." = inner
+ source = "He met Dr. Smith today."
+ # "Dr. Smith" = [7:16], "Dr." = [7:10]
+ spans = [_span("persName", 7, 16), _span("roleName", 7, 10)]
+ result = inject_xml(source, spans)
+ assert "<persName>" in result
+ assert "<roleName>" in result
+ # roleName must appear inside persName
+ assert result.index("<roleName>") > result.index("<persName>")
+ assert result.index("</roleName>") < result.index("</persName>")
+ # Text is split by the inner tag; check exact output structure
+ assert result == "He met <persName><roleName>Dr.</roleName> Smith</persName> today."
+
+
+ def test_attrs_rendered_in_tag():
+ source = "Visit Paris."
+ span = _span("placeName", 6, 11, {"ref": "http://example.com/paris"})
+ result = inject_xml(source, [span])
+ assert 'ref="http://example.com/paris"' in result
+ assert "<placeName" in result
+ assert "Paris" in result
+
+
+ def test_span_at_start_of_text():
+ source = "John went home."
+ span = _span("persName", 0, 4)
+ result = inject_xml(source, [span])
+ assert result.startswith("<persName>John</persName>")
+
+
+ def test_span_covering_entire_text():
+ source = "John Smith"
+ span = _span("persName", 0, 10)
+ result = inject_xml(source, [span])
+ assert result == "<persName>John Smith</persName>"
+
+
+ def test_span_at_end_of_text():
+ source = "He visited Paris"
+ span = _span("placeName", 11, 16)
+ result = inject_xml(source, [span])
+ assert result.endswith("<placeName>Paris</placeName>")
+
+
+ def test_overlapping_spans_warns_and_skips():
+ source = "Hello World"
+ # Partial overlap: [0,7] and [5,11]
+ spans = [_span("a", 0, 7), _span("b", 5, 11)]
+ with pytest.warns(UserWarning, match="Overlapping"):
+ result = inject_xml(source, spans)
+ # Only the first span should be present
+ assert "<a>" in result
+ assert "<b>" not in result
+
+
+ def test_build_nesting_tree_simple():
+ outer = _span("persName", 0, 20)
+ inner = _span("roleName", 0, 5)
+ roots = _build_nesting_tree([outer, inner])
+ assert len(roots) == 1
+ assert roots[0].element == "persName"
+ assert len(roots[0].children) == 1
+ assert roots[0].children[0].element == "roleName"
+
+
+ def test_build_nesting_tree_siblings():
+ a = _span("a", 0, 5)
+ b = _span("b", 6, 10)
+ roots = _build_nesting_tree([a, b])
+ assert len(roots) == 2
+ assert all(len(r.children) == 0 for r in roots)
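For the non-nested case these tests exercise, the core injection trick is simply to splice tags in from the rightmost span first, so earlier character offsets stay valid. A simplified stand-in (`inject_flat` is a hypothetical name; the real `inject_xml` additionally handles nesting, attributes, and overlap warnings):

```python
def inject_flat(source: str, spans: list[tuple[str, int, int]]) -> str:
    """spans: (element, start, end) tuples, assumed non-overlapping and non-nested."""
    out = source
    # Process right-to-left so insertions never shift offsets still to be used
    for element, start, end in sorted(spans, key=lambda s: s[1], reverse=True):
        out = out[:start] + f"<{element}>" + out[start:end] + f"</{element}>" + out[end:]
    return out
```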
tests/test_parser.py ADDED
@@ -0,0 +1,100 @@
+ import json
+
+ import pytest
+
+ from tei_annotator.postprocessing.parser import _strip_fences, parse_response
+
+ VALID_JSON = json.dumps(
+ [{"element": "persName", "text": "John Smith", "context": "said John Smith", "attrs": {}}]
+ )
+
+
+ # ---- _strip_fences ----------------------------------------------------------
+
+
+ def test_strip_fences_json_lang():
+ fenced = f"```json\n{VALID_JSON}\n```"
+ assert _strip_fences(fenced) == VALID_JSON
+
+
+ def test_strip_fences_no_lang():
+ fenced = f"```\n{VALID_JSON}\n```"
+ assert _strip_fences(fenced) == VALID_JSON
+
+
+ def test_strip_fences_no_fences():
+ assert _strip_fences(VALID_JSON) == VALID_JSON
+
+
+ def test_strip_fences_with_preamble():
+ text = f"Here is the JSON:\n```json\n{VALID_JSON}\n```"
+ assert _strip_fences(text) == VALID_JSON
+
+
+ # ---- parse_response ---------------------------------------------------------
+
+
+ def test_valid_json_parsed_directly():
+ spans = parse_response(VALID_JSON)
+ assert len(spans) == 1
+ assert spans[0].element == "persName"
+ assert spans[0].text == "John Smith"
+
+
+ def test_markdown_fenced_json_parsed():
+ spans = parse_response(f"```json\n{VALID_JSON}\n```")
+ assert len(spans) == 1
+
+
+ def test_invalid_json_no_retry_raises():
+ with pytest.raises(ValueError):
+ parse_response("not json at all")
+
+
+ def test_retry_triggered_on_first_failure():
+ call_count = [0]
+
+ def retry_fn(prompt: str) -> str:
+ call_count[0] += 1
+ return VALID_JSON
+
+ def correction_fn(bad: str, err: str) -> str:
+ return f"fix: {bad}"
+
+ spans = parse_response("bad json", call_fn=retry_fn, make_correction_prompt=correction_fn)
+ assert call_count[0] == 1
+ assert len(spans) == 1
+
+
+ def test_retry_still_invalid_raises():
+ def retry_fn(prompt: str) -> str:
+ return "still bad"
+
+ def correction_fn(bad: str, err: str) -> str:
+ return "fix it"
+
+ with pytest.raises(ValueError):
+ parse_response("bad", call_fn=retry_fn, make_correction_prompt=correction_fn)
+
+
+ def test_missing_fields_items_skipped():
+ raw = json.dumps(
+ [
+ {"element": "persName"}, # missing text and context → skip
+ {"element": "persName", "text": "John", "context": "John went"}, # valid
+ ]
+ )
+ spans = parse_response(raw)
+ assert len(spans) == 1
+ assert spans[0].text == "John"
+
+
+ def test_non_list_response_raises():
+ with pytest.raises(ValueError):
+ parse_response(json.dumps({"element": "persName"}))
+
+
+ def test_attrs_defaults_to_empty_dict():
+ raw = json.dumps([{"element": "persName", "text": "x", "context": "x"}])
+ spans = parse_response(raw)
+ assert spans[0].attrs == {}
tests/test_pipeline.py ADDED
@@ -0,0 +1,137 @@
+ import json
+
+ import pytest
+
+ from tei_annotator.inference.endpoint import EndpointCapability, EndpointConfig
+ from tei_annotator.models.schema import TEIElement, TEISchema
+ from tei_annotator.pipeline import annotate
+
+
+ def _schema():
+ return TEISchema(
+ elements=[
+ TEIElement(
+ tag="persName",
+ description="a person's name",
+ allowed_children=[],
+ attributes=[],
+ )
+ ]
+ )
+
+
+ def _mock_call_fn(prompt: str) -> str:
+ return json.dumps(
+ [
+ {
+ "element": "persName",
+ "text": "John Smith",
+ "context": "said John Smith yesterday",
+ "attrs": {},
+ }
+ ]
+ )
+
+
+ def test_annotate_smoke():
+ result = annotate(
+ text="He said John Smith yesterday.",
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.JSON_ENFORCED,
+ call_fn=_mock_call_fn,
+ ),
+ gliner_model=None,
+ )
+ assert "persName" in result.xml
+ assert "John Smith" in result.xml
+ assert result.xml.count("John Smith") == 1 # text not duplicated
+
+
+ def test_annotate_empty_response():
+ result = annotate(
+ text="No entities here.",
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.JSON_ENFORCED,
+ call_fn=lambda _: "[]",
+ ),
+ gliner_model=None,
+ )
+ assert result.xml == "No entities here."
+ assert result.fuzzy_spans == []
+
+
+ def test_annotate_preserves_existing_xml():
+ # Pre-existing <b> tag must survive
+ def call_fn(prompt: str) -> str:
+ return json.dumps(
+ [
+ {
+ "element": "persName",
+ "text": "John Smith",
+ "context": "said John Smith yesterday",
+ "attrs": {},
+ }
+ ]
+ )
+
+ result = annotate(
+ text="He said <b>John Smith</b> yesterday.",
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.JSON_ENFORCED, call_fn=call_fn
+ ),
+ gliner_model=None,
+ )
+ assert "<b>" in result.xml
+ assert "John Smith" in result.xml
+
+
+ def test_annotate_fuzzy_spans_surfaced():
+ """Spans flagged as fuzzy appear in AnnotationResult.fuzzy_spans."""
+ # We cannot force a fuzzy match easily without mocking internals,
+ # so we just verify the field exists and is a list.
+ result = annotate(
+ text="He said John Smith yesterday.",
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.JSON_ENFORCED,
+ call_fn=_mock_call_fn,
+ ),
+ gliner_model=None,
+ )
+ assert isinstance(result.fuzzy_spans, list)
+
+
+ def test_annotate_text_generation_endpoint():
+ """TEXT_GENERATION capability path (with retry logic enabled) works end-to-end."""
+ result = annotate(
+ text="He said John Smith yesterday.",
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.TEXT_GENERATION,
+ call_fn=_mock_call_fn,
+ ),
+ gliner_model=None,
+ )
+ assert "persName" in result.xml
+
+
+ def test_annotate_no_text_modification():
+ """The original text characters must all appear in the output (no hallucination)."""
+ original = "He said John Smith yesterday."
+ result = annotate(
+ text=original,
+ schema=_schema(),
+ endpoint=EndpointConfig(
+ capability=EndpointCapability.JSON_ENFORCED,
+ call_fn=_mock_call_fn,
+ ),
+ gliner_model=None,
+ )
+ # Strip all tags from output; plain text should equal original
+ import re
+
+ plain = re.sub(r"<[^>]+>", "", result.xml)
+ assert plain == original
tests/test_resolver.py ADDED
@@ -0,0 +1,65 @@
+ import pytest
+
+ from tei_annotator.models.spans import SpanDescriptor
+ from tei_annotator.postprocessing.resolver import resolve_spans
+
+ SOURCE = "He said John Smith yesterday, and John Smith agreed."
+
+
+ def _span(element, text, context, attrs=None):
+ return SpanDescriptor(element=element, text=text, context=context, attrs=attrs or {})
+
+
+ def test_exact_context_match():
+ span = _span("persName", "John Smith", "said John Smith yesterday")
+ resolved = resolve_spans(SOURCE, [span])
+ assert len(resolved) == 1
+ rs = resolved[0]
+ assert rs.start == SOURCE.index("John Smith")
+ assert rs.end == rs.start + len("John Smith")
+ assert not rs.fuzzy_match
+
+
+ def test_context_not_found_rejected():
+ span = _span("persName", "John Smith", "this context does not exist xyz987")
+ assert resolve_spans(SOURCE, [span]) == []
+
+
+ def test_text_not_in_context_window_rejected():
+ span = _span("persName", "Jane Doe", "said John Smith yesterday")
+ assert resolve_spans(SOURCE, [span]) == []
+
+
+ def test_source_slice_verified():
+ span = _span("persName", "John Smith", "said John Smith yesterday")
+ resolved = resolve_spans(SOURCE, [span])
+ assert len(resolved) == 1
+ rs = resolved[0]
+ assert SOURCE[rs.start : rs.end] == "John Smith"
+
+
+ def test_attrs_preserved():
+ span = _span("persName", "John Smith", "said John Smith yesterday", {"ref": "#js"})
+ resolved = resolve_spans(SOURCE, [span])
+ assert len(resolved) == 1
+ assert resolved[0].attrs == {"ref": "#js"}
+
+
+ def test_multiple_spans_resolved():
+ spans = [
+ _span("persName", "John Smith", "He said John Smith yesterday"),
+ _span("persName", "John Smith", "and John Smith agreed"),
+ ]
+ resolved = resolve_spans(SOURCE, spans)
+ assert len(resolved) == 2
+ assert resolved[0].start != resolved[1].start
+
+
+ def test_empty_span_list():
+ assert resolve_spans(SOURCE, []) == []
+
+
+ def test_children_start_empty():
+ span = _span("persName", "John Smith", "said John Smith yesterday")
+ resolved = resolve_spans(SOURCE, [span])
+ assert resolved[0].children == []
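The resolution strategy these tests pin down is: anchor on the `context` string first (so repeated entity text like the two "John Smith" occurrences lands at the right occurrence), then locate the span `text` inside that window, rejecting spans whose context never matches. A rough sketch of that logic, with `resolve_offset` as a hypothetical name and stdlib `difflib` standing in for the rapidfuzz scorer the project depends on:

```python
import difflib


def resolve_offset(source: str, text: str, context: str, threshold: float = 0.92):
    """Return (start, end, fuzzy) for `text` located via its `context`, or None."""
    pos = source.find(context)
    fuzzy = False
    if pos == -1:
        # Fuzzy fallback: score every window of the context's length
        best_pos, best_score = -1, 0.0
        for i in range(max(len(source) - len(context), 0) + 1):
            score = difflib.SequenceMatcher(
                None, source[i : i + len(context)], context
            ).ratio()
            if score > best_score:
                best_pos, best_score = i, score
        if best_score <= threshold:
            return None  # hallucinated context: reject the span entirely
        pos, fuzzy = best_pos, True
    # Locate the verbatim text inside (or just after) the matched window
    inner = source.find(text, pos, pos + len(context) + len(text))
    if inner == -1:
        return None
    return inner, inner + len(text), fuzzy
```

Returning `None` rather than guessing is what guarantees the pipeline's "source text is never modified" invariant: an unverifiable span simply produces no tag.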
tests/test_validator.py ADDED
@@ -0,0 +1,82 @@
+ from tei_annotator.models.schema import TEIAttribute, TEIElement, TEISchema
+ from tei_annotator.models.spans import ResolvedSpan
+ from tei_annotator.postprocessing.validator import validate_spans
+
+ SOURCE = "He met John Smith."
+
+
+ def _schema():
+ return TEISchema(
+ elements=[
+ TEIElement(
+ tag="persName",
+ description="a person's name",
+ attributes=[
+ TEIAttribute(name="ref", description="reference URI"),
+ TEIAttribute(
+ name="cert",
+ description="certainty",
+ allowed_values=["high", "low"],
+ ),
+ ],
+ )
+ ]
+ )
+
+
+ def _span(element, start, end, attrs=None):
+ return ResolvedSpan(element=element, start=start, end=end, attrs=attrs or {})
+
+
+ # SOURCE: "He met John Smith."
+ # positions: H=0 e=1 ' '=2 m=3 e=4 t=5 ' '=6 J=7 o=8 h=9 n=10 ' '=11 S=12 m=13 i=14 t=15 h=16 .=17
+ # "John Smith" => [7:17]
+
+
+ def test_valid_span_passes():
+ result = validate_spans([_span("persName", 7, 17)], _schema(), SOURCE)
+ assert len(result) == 1
+
+
+ def test_unknown_element_rejected():
+ result = validate_spans([_span("orgName", 7, 17)], _schema(), SOURCE)
+ assert len(result) == 0
+
+
+ def test_unknown_attribute_rejected():
+ result = validate_spans(
+ [_span("persName", 7, 17, {"unknown_attr": "val"})], _schema(), SOURCE
+ )
+ assert len(result) == 0
+
+
+ def test_invalid_attribute_value_rejected():
+ result = validate_spans(
+ [_span("persName", 7, 17, {"cert": "medium"})], _schema(), SOURCE
+ )
+ assert len(result) == 0
+
+
+ def test_valid_constrained_attribute_passes():
+ result = validate_spans(
+ [_span("persName", 7, 17, {"cert": "high"})], _schema(), SOURCE
+ )
+ assert len(result) == 1
+
+
+ def test_free_string_attribute_passes():
+ result = validate_spans(
+ [_span("persName", 7, 17, {"ref": "http://example.com/p/1"})], _schema(), SOURCE
+ )
+ assert len(result) == 1
+
+
+ def test_out_of_bounds_span_rejected():
+ result = validate_spans([_span("persName", -1, 5)], _schema(), SOURCE)
+ assert len(result) == 0
+ result2 = validate_spans([_span("persName", 5, 200)], _schema(), SOURCE)
+ assert len(result2) == 0
+
+
+ def test_empty_span_list():
+ assert validate_spans([], _schema(), SOURCE) == []
uv.lock ADDED
The diff for this file is too large to render. See raw diff