Spaces:
Sleeping
Sleeping
File size: 2,850 Bytes
37eaffd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | import pytest
from tei_annotator.inference.endpoint import EndpointCapability
from tei_annotator.models.schema import TEIElement, TEISchema
from tei_annotator.models.spans import SpanDescriptor
from tei_annotator.prompting.builder import build_prompt, make_correction_prompt
def _schema():
    """Build the small TEI schema shared by every prompt test.

    It declares two attribute-less elements, ``persName`` and ``placeName``.
    """
    person = TEIElement(tag="persName", description="a person's name", attributes=[])
    place = TEIElement(tag="placeName", description="a place name", attributes=[])
    return TEISchema(elements=[person, place])
def test_text_gen_prompt_contains_json_instruction():
    """The text-generation prompt mentions JSON (upper- or lower-case)."""
    text_gen_prompt = build_prompt(
        "Some text.", _schema(), EndpointCapability.TEXT_GENERATION
    )
    # Either spelling is accepted; any other casing is not.
    assert any(marker in text_gen_prompt for marker in ("JSON", "json"))
def test_text_gen_prompt_contains_example():
    """The text-generation prompt should include an example output array.

    ``"persName"`` is not a usable signal for this: it is always present because
    ``_schema()`` declares it, so the former ``"persName" in prompt or ...``
    assertion could never fail and did not test the example at all.
    """
    prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
    # NOTE(review): assumes the template's example output objects carry an
    # ``element`` key, mirroring ``SpanDescriptor.element`` — confirm against
    # the prompt template and tighten (e.g. to '"element"') if possible.
    assert "element" in prompt
def test_text_gen_prompt_contains_schema_elements():
    """Every tag declared by the schema is listed in the text-generation prompt."""
    rendered = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION)
    for tag in ("persName", "placeName"):
        assert tag in rendered
def test_text_gen_prompt_contains_source_text():
    """The text to annotate is embedded verbatim in the prompt."""
    source_marker = "unique_source_42"
    rendered = build_prompt(source_marker, _schema(), EndpointCapability.TEXT_GENERATION)
    assert source_marker in rendered
def test_json_enforced_prompt_contains_schema():
    """The JSON-enforced prompt still lists every schema tag."""
    json_prompt = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED)
    for tag in ("persName", "placeName"):
        assert tag in json_prompt
def test_json_enforced_prompt_shorter_than_text_gen():
    """For the same input, the JSON-enforced prompt is strictly shorter."""
    text_gen_len = len(
        build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION)
    )
    json_enforced_len = len(
        build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED)
    )
    assert json_enforced_len < text_gen_len
def test_candidates_appear_in_prompt():
    """Supplying candidates adds a pre-detected section to the prompt.

    Checking for ``"John"`` alone is insufficient: it already occurs in the
    source text, so that check passes even if ``candidates`` is ignored.
    """
    candidates = [
        SpanDescriptor(element="persName", text="John", context="said John went", attrs={})
    ]
    prompt = build_prompt(
        "said John went.",
        _schema(),
        EndpointCapability.TEXT_GENERATION,
        candidates=candidates,
    )
    # Same section marker whose absence the no-candidate tests assert.
    assert "Pre-detected" in prompt
    assert "John" in prompt
def test_no_candidate_section_when_none():
    """``candidates=None`` leaves the pre-detected section out of the prompt."""
    rendered = build_prompt(
        "text",
        _schema(),
        EndpointCapability.TEXT_GENERATION,
        candidates=None,
    )
    assert "Pre-detected" not in rendered
def test_empty_candidates_list_no_section():
    """An empty candidate list is treated like no candidates: no extra section."""
    no_candidates = []
    rendered = build_prompt(
        "text",
        _schema(),
        EndpointCapability.TEXT_GENERATION,
        candidates=no_candidates,
    )
    assert "Pre-detected" not in rendered
def test_extraction_raises():
    """``build_prompt`` rejects the EXTRACTION capability with ``ValueError``."""
    capability = EndpointCapability.EXTRACTION
    with pytest.raises(ValueError):
        build_prompt("text", _schema(), capability)
def test_correction_prompt_contains_original_response():
    """The correction prompt echoes both the bad response and the error text."""
    bad_response = "bad_json_here"
    error_text = "JSONDecodeError"
    correction = make_correction_prompt(bad_response, error_text)
    for fragment in (bad_response, error_text):
        assert fragment in correction
|