import pytest from tei_annotator.inference.endpoint import EndpointCapability from tei_annotator.models.schema import TEIElement, TEISchema from tei_annotator.models.spans import SpanDescriptor from tei_annotator.prompting.builder import build_prompt, make_correction_prompt def _schema(): return TEISchema( elements=[ TEIElement(tag="persName", description="a person's name", attributes=[]), TEIElement(tag="placeName", description="a place name", attributes=[]), ] ) def test_text_gen_prompt_contains_json_instruction(): prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION) assert "JSON" in prompt or "json" in prompt def test_text_gen_prompt_contains_example(): prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION) # The template shows an example output array assert "persName" in prompt or "element" in prompt def test_text_gen_prompt_contains_schema_elements(): prompt = build_prompt("Some text.", _schema(), EndpointCapability.TEXT_GENERATION) assert "persName" in prompt assert "placeName" in prompt def test_text_gen_prompt_contains_source_text(): prompt = build_prompt("unique_source_42", _schema(), EndpointCapability.TEXT_GENERATION) assert "unique_source_42" in prompt def test_json_enforced_prompt_contains_schema(): prompt = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED) assert "persName" in prompt assert "placeName" in prompt def test_json_enforced_prompt_shorter_than_text_gen(): text_gen = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION) json_enf = build_prompt("text", _schema(), EndpointCapability.JSON_ENFORCED) assert len(json_enf) < len(text_gen) def test_candidates_appear_in_prompt(): candidates = [ SpanDescriptor(element="persName", text="John", context="said John went", attrs={}) ] prompt = build_prompt( "said John went.", _schema(), EndpointCapability.TEXT_GENERATION, candidates=candidates, ) assert "John" in prompt def test_no_candidate_section_when_none(): prompt = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=None) assert "Pre-detected" not in prompt def test_empty_candidates_list_no_section(): prompt = build_prompt("text", _schema(), EndpointCapability.TEXT_GENERATION, candidates=[]) assert "Pre-detected" not in prompt def test_extraction_raises(): with pytest.raises(ValueError): build_prompt("text", _schema(), EndpointCapability.EXTRACTION) def test_correction_prompt_contains_original_response(): prompt = make_correction_prompt("bad_json_here", "JSONDecodeError") assert "bad_json_here" in prompt assert "JSONDecodeError" in prompt