File size: 2,850 Bytes
37eaffd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import pytest

from tei_annotator.inference.endpoint import EndpointCapability
from tei_annotator.models.schema import TEIElement, TEISchema
from tei_annotator.models.spans import SpanDescriptor
from tei_annotator.prompting.builder import build_prompt, make_correction_prompt


def _schema():
    """Build a minimal two-element TEI schema fixture shared by these tests."""
    elements = [
        TEIElement(tag=tag, description=description, attributes=[])
        for tag, description in (
            ("persName", "a person's name"),
            ("placeName", "a place name"),
        )
    ]
    return TEISchema(elements=elements)


def test_text_gen_prompt_contains_json_instruction():
    """The text-generation template must tell the model to produce JSON."""
    rendered = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
    assert any(marker in rendered for marker in ("JSON", "json"))


def test_text_gen_prompt_contains_example():
    """The template should show an example output array (or mention elements)."""
    rendered = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
    has_example = "persName" in rendered
    mentions_element = "element" in rendered
    assert has_example or mentions_element


def test_text_gen_prompt_contains_schema_elements():
    """Every schema tag must be surfaced in the text-generation prompt."""
    rendered = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
    for tag in ("persName", "placeName"):
        assert tag in rendered


def test_text_gen_prompt_contains_source_text():
    """The caller's source text must be embedded verbatim in the prompt."""
    marker = "unique_source_42"
    rendered = build_prompt(marker, _schema(), EndpointCapability.TEXT_GENERATION)
    assert marker in rendered


def test_json_enforced_prompt_contains_schema():
    """Schema tags also appear in the JSON-enforced prompt variant."""
    rendered = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED)
    for tag in ("persName", "placeName"):
        assert tag in rendered


def test_json_enforced_prompt_shorter_than_text_gen():
    """JSON-enforced endpoints skip format instructions, so the prompt is shorter."""
    schema = _schema()
    verbose_prompt = build_prompt("text", schema, EndpointCapability.TEXT_GENERATION)
    terse_prompt = build_prompt("text", schema, EndpointCapability.JSON_ENFORCED)
    assert len(terse_prompt) < len(verbose_prompt)


def test_candidates_appear_in_prompt():
    """Pre-detected span candidates are rendered into the prompt text."""
    span = SpanDescriptor(
        element="persName",
        text="John",
        context="said John went",
        attrs={},
    )
    rendered = build_prompt(
        "said John went.",
        _schema(),
        EndpointCapability.TEXT_GENERATION,
        candidates=[span],
    )
    assert "John" in rendered


def test_no_candidate_section_when_none():
    """Passing candidates=None must omit the pre-detected-spans section."""
    rendered = build_prompt(
        "text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=None
    )
    assert "Pre-detected" not in rendered


def test_empty_candidates_list_no_section():
    """An empty candidates list behaves like None: no candidate section."""
    rendered = build_prompt(
        "text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=[]
    )
    assert "Pre-detected" not in rendered


def test_extraction_raises():
    """EXTRACTION endpoints have no prompt template, so build_prompt rejects them."""
    schema = _schema()
    with pytest.raises(ValueError):
        build_prompt("text", schema, EndpointCapability.EXTRACTION)


def test_correction_prompt_contains_original_response():
    """The correction prompt echoes back both the bad output and the error."""
    rendered = make_correction_prompt("bad_json_here", "JSONDecodeError")
    for fragment in ("bad_json_here", "JSONDecodeError"):
        assert fragment in rendered