| """ |
| Unit tests — no real model loaded (GLiNER/torch stubbed via conftest.py). |
| Covers: |
| - API contract (health, validation, threshold, language field) |
| - Dual-model routing (ZH → BERT backend, EN/AR/mixed → GLiNER backend) |
| - Optional labels + bilingual auto-expansion |
| - English / Chinese / Arabic / mixed-language scenarios |
| - labels_used echo in response |
| """ |
| from unittest.mock import MagicMock, patch, PropertyMock |
|
|
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from app.main import app |
| from app.models import Entity |
| from app.labels import DEFAULT_LABELS, expand_bilingual, labels_to_bert_types |
| from app.ner import detect_language |
|
|
|
|
| |
|
|
def _ents(*args: Entity) -> tuple[list[Entity], list[str]]:
    """Bundle entities into the ``(entities, labels_used)`` pair expected by NERService.

    ``labels_used`` holds each entity's ``label`` in argument order,
    duplicates included.
    """
    entities = [*args]
    labels_used = [entity.label for entity in entities]
    return entities, labels_used
|
|
|
|
| |
|
|
@pytest.fixture()
def client():
    """Yield ``(test_client, mock_ner)`` with ``NERService`` replaced by a mock.

    No model is loaded: ``app.main.NERService`` is patched with a factory
    that always returns the same ``MagicMock``. ``mock_ner.extract()``
    returns ``([], [])`` unless a test overrides its ``return_value``.
    """
    mock_ner = MagicMock()
    mock_ner.extract.return_value = ([], [])
    # ``context()`` is a classmethod, so no throwaway MonkeyPatch instance
    # is needed; the patch is undone when the ``with`` block exits.
    with pytest.MonkeyPatch.context() as mp:
        # Accept keyword arguments as well as positional ones so the stub
        # keeps working however the application constructs the service.
        mp.setattr("app.main.NERService", lambda *_args, **_kwargs: mock_ner)
        # The TestClient is opened inside the patch so the replacement is
        # active for the whole lifetime of the client.
        with TestClient(app) as c:
            yield c, mock_ner
|
|
|
|
| |
|
|
def test_health(client):
    """The health endpoint reports ``{"status": "ok"}``."""
    http, _ = client
    payload = http.get("/api/v1/health").json()
    assert payload == {"status": "ok"}
|
|
|
|
def test_extract_empty_text_returns_empty(client):
    """Empty input text yields HTTP 200 and an empty entity list."""
    http, _ = client
    request_body = {"text": "", "labels": ["person"]}
    response = http.post("/api/v1/extract", json=request_body)
    assert response.status_code == 200
    assert response.json()["entities"] == []
|
|
|
|
def test_extract_labels_optional(client):
    """Omitting the ``labels`` field entirely should still return 200."""
    http, ner = client
    ner.extract.return_value = ([], DEFAULT_LABELS)
    response = http.post("/api/v1/extract", json={"text": "Some text."})
    assert response.status_code == 200
    labels_used = response.json()["labels_used"]
    assert len(labels_used) > 0
|
|
|
|
def test_extract_empty_labels_uses_defaults(client):
    """With ``labels=[]`` the default bilingual label set should be used."""
    http, ner = client
    ner.extract.return_value = ([], DEFAULT_LABELS)
    request_body = {"text": "Hello world.", "labels": []}
    response = http.post("/api/v1/extract", json=request_body)
    assert response.status_code == 200
    assert response.json()["labels_used"] == DEFAULT_LABELS
|
|
|
|
def test_extract_threshold_forwarded(client):
    """The request's ``threshold`` is passed through to ``extract``."""
    http, ner = client
    request_body = {"text": "Hello", "labels": ["person"], "threshold": 0.8}
    http.post("/api/v1/extract", json=request_body)
    expected_kwargs = {"language": "auto", "min_entities": None}
    ner.extract.assert_called_once_with(
        "Hello", ["person"], 0.8, **expected_kwargs
    )
|
|
|
|
def test_extract_invalid_threshold(client):
    """A threshold of 1.5 is rejected with HTTP 422."""
    http, _ = client
    response = http.post("/api/v1/extract",
                         json={"text": "x", "threshold": 1.5})
    assert response.status_code == 422
|
|
|
|
def test_extract_language_field_forwarded(client):
    """An explicit ``language`` is passed through to ``extract``."""
    http, ner = client
    request_body = {
        "text": "北京协和医院",
        "labels": ["医院名称"],
        "language": "zh",
    }
    http.post("/api/v1/extract", json=request_body)
    expected_kwargs = {"language": "zh", "min_entities": None}
    ner.extract.assert_called_once_with(
        "北京协和医院", ["医院名称"], 0.4, **expected_kwargs
    )
|
|
|
|
def test_extract_min_entities_forwarded(client):
    """The request's ``min_entities`` is passed through to ``extract``."""
    http, ner = client
    request_body = {"text": "马云在杭州。", "language": "zh", "min_entities": 5}
    http.post("/api/v1/extract", json=request_body)
    expected_kwargs = {"language": "zh", "min_entities": 5}
    ner.extract.assert_called_once_with(
        "马云在杭州。", [], 0.4, **expected_kwargs
    )
|
|
|
|
def test_extract_negative_min_entities_rejected(client):
    """A negative ``min_entities`` is rejected with HTTP 422."""
    http, _ = client
    request_body = {"text": "x", "min_entities": -1}
    response = http.post("/api/v1/extract", json=request_body)
    assert response.status_code == 422
|
|
|
|
def test_extract_invalid_language(client):
    """The language code ``"jp"`` is rejected with HTTP 422."""
    http, _ = client
    response = http.post("/api/v1/extract",
                         json={"text": "x", "language": "jp"})
    assert response.status_code == 422
|
|
|
|
def test_labels_used_echoed(client):
    """``labels_used`` in the response is whatever ``extract`` returned."""
    http, ner = client
    used = ["人名或姓名", "地名或城市"]
    ner.extract.return_value = ([], used)
    request_body = {"text": "马云在杭州。", "labels": ["人名或姓名"]}
    response = http.post("/api/v1/extract", json=request_body)
    assert response.json()["labels_used"] == used
|
|
|
|
def test_entity_response_fields(client):
    """Each returned entity exposes text, label, score, start and end."""
    http, ner = client
    apple = Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
    ner.extract.return_value = _ents(apple)
    request_body = {"text": "Apple is great.", "labels": ["organization"]}
    response = http.post("/api/v1/extract", json=request_body)
    first = response.json()["entities"][0]
    required_fields = {"text", "label", "score", "start", "end"}
    assert required_fields <= first.keys()
    assert 0.0 <= first["score"] <= 1.0
    assert first["start"] < first["end"]
|
|
|
|
| |
|
|
def test_detect_language_english():
    """An English sentence is detected as ``"en"``."""
    sample = "Elon Musk founded SpaceX in California."
    detected = detect_language(sample)
    assert detected == "en"
|
|
|
|
def test_detect_language_chinese():
    """A Chinese sentence is detected as ``"zh"``."""
    sample = "阿里巴巴集团创始人马云于杭州卸任。"
    detected = detect_language(sample)
    assert detected == "zh"
|
|
|
|
def test_detect_language_arabic():
    """An Arabic sentence is detected as ``"ar"``."""
    sample = "أعلن الرئيس محمد بن سلمان عن مشروع نيوم."
    detected = detect_language(sample)
    assert detected == "ar"
|
|
|
|
def test_detect_language_mixed():
    """A sentence mixing Chinese and Latin-script words is detected as ``"mixed"``."""
    sample = "张伟加入了 Google 北京研发中心,负责 Android 优化。"
    detected = detect_language(sample)
    assert detected == "mixed"
|
|
|
|
def test_detect_language_empty():
    """An empty string is detected as ``"en"``."""
    detected = detect_language("")
    assert detected == "en"
|
|
|
|
| |
|
|
def test_expand_bilingual_adds_english_for_chinese():
    """Expanding a Chinese person label adds its English counterpart."""
    expanded = expand_bilingual(["人名或姓名"])
    assert "full name of a person" in expanded
|
|
|
|
def test_expand_bilingual_adds_chinese_for_english():
    """Expanding an English organization label adds its Chinese counterpart."""
    expanded = expand_bilingual(["company or organization name"])
    assert "公司或组织机构名称" in expanded
|
|
|
|
def test_expand_bilingual_no_duplicate():
    """Supplying both halves of a label pair does not duplicate either one."""
    zh_label = "人名或姓名"
    en_label = "full name of a person"
    expanded = expand_bilingual([zh_label, en_label])
    assert expanded.count(zh_label) == 1
    assert expanded.count(en_label) == 1
|
|
|
|
def test_expand_bilingual_custom_label_preserved():
    """A custom label is still present after expansion."""
    custom = "my custom label"
    expanded = expand_bilingual([custom])
    assert custom in expanded
|
|
|
|
def test_default_labels_bilingual():
    """DEFAULT_LABELS has at least one all-ASCII label and one with CJK characters."""

    def is_ascii_only(label):
        return all(ord(ch) < 128 for ch in label)

    def contains_cjk(label):
        return any('一' <= ch <= '鿿' for ch in label)

    has_en = any(map(is_ascii_only, DEFAULT_LABELS))
    has_zh = any(map(contains_cjk, DEFAULT_LABELS))
    assert has_en and has_zh
|
|
|
|
| |
|
|
def test_labels_to_bert_types_chinese_label():
    """The Chinese person label maps to a set of types containing ``PER``."""
    bert_types = labels_to_bert_types(["人名或姓名"])
    assert "PER" in bert_types
|
|
|
|
def test_labels_to_bert_types_english_label():
    """The English location label maps to a set of types containing ``LOC``."""
    bert_types = labels_to_bert_types(["geographical location"])
    assert "LOC" in bert_types
|
|
|
|
def test_labels_to_bert_types_empty_returns_none():
    """An empty label list maps to ``None``."""
    bert_types = labels_to_bert_types([])
    assert bert_types is None
|
|
|
|
def test_labels_to_bert_types_unmapped_returns_none():
    """A label with no known mapping yields ``None``."""
    bert_types = labels_to_bert_types(["some unknown label"])
    assert bert_types is None
|
|
|
|
| |
|
|
def test_english_person_org(client):
    """English text: stubbed person and organization entities reach the response."""
    http, ner = client
    musk = Entity(text="Elon Musk", label="person", score=0.98, start=0, end=9)
    tesla = Entity(text="Tesla", label="organization", score=0.96, start=18, end=23)
    spacex = Entity(text="SpaceX", label="organization", score=0.97, start=28, end=34)
    ner.extract.return_value = _ents(musk, tesla, spacex)
    request_body = {
        "text": "Elon Musk is the CEO of Tesla and founded SpaceX.",
        "labels": ["full name of a person", "company or organization name"],
        "language": "en",
    }
    response = http.post("/api/v1/extract", json=request_body)
    assert response.status_code == 200
    found = {item["text"] for item in response.json()["entities"]}
    assert {"Elon Musk", "Tesla", "SpaceX"}.issubset(found)
|
|
|
|
def test_english_location_date(client):
    """English text: stubbed location and date entities reach the response."""
    http, ner = client
    paris = Entity(text="Paris", label="location", score=0.94, start=20, end=25)
    year = Entity(text="2024", label="date", score=0.91, start=29, end=33)
    france = Entity(text="France", label="location", score=0.93, start=38, end=44)
    ner.extract.return_value = _ents(paris, year, france)
    request_body = {
        "text": "The summit was held in Paris in 2024, in France.",
        "labels": ["geographical location", "date or year"],
        "language": "en",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    assert {"Paris", "France", "2024"}.issubset(found)
|
|
|
|
def test_english_threshold_forwarded(client):
    """With ``language="en"`` the custom threshold is still passed to ``extract``."""
    http, ner = client
    request_body = {
        "text": "NASA explored the Moon.",
        "labels": ["company or organization name"],
        "threshold": 0.8,
        "language": "en",
    }
    http.post("/api/v1/extract", json=request_body)
    expected_kwargs = {"language": "en", "min_entities": None}
    ner.extract.assert_called_once_with(
        "NASA explored the Moon.",
        ["company or organization name"],
        0.8,
        **expected_kwargs,
    )
|
|
|
|
| |
|
|
def test_chinese_person_org(client):
    """Chinese text: stubbed person and organization entities reach the response."""
    http, ner = client
    ma_yun = Entity(text="马云", label="人名或姓名", score=0.96, start=8, end=10)
    zhang_yong = Entity(text="张勇", label="人名或姓名", score=0.94, start=25, end=27)
    alibaba = Entity(text="阿里巴巴", label="公司或组织机构名称", score=0.97, start=0, end=4)
    ner.extract.return_value = _ents(ma_yun, zhang_yong, alibaba)
    request_body = {
        "text": "阿里巴巴集团创始人马云卸任,由张勇接任。",
        "labels": ["人名或姓名", "公司或组织机构名称"],
        "language": "zh",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    assert {"马云", "张勇", "阿里巴巴"}.issubset(found)
|
|
|
|
def test_chinese_entity_boundary(client):
    """BERT NER should cut entity boundaries precisely, without the trailing verb."""
    http, ner = client
    you_shi = Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2)
    wang_xifeng = Entity(text="王熙凤", label="人名或姓名", score=0.95, start=8, end=11)
    ner.extract.return_value = _ents(you_shi, wang_xifeng)
    request_body = {
        "text": "尤氏来请,王熙凤笑道:'你来了。'",
        "labels": ["人名或姓名"],
        "language": "zh",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    # Bare names must be present ...
    assert "尤氏" in found
    assert "王熙凤" in found
    # ... and spans that swallow the following verb must be absent.
    assert "尤氏来请" not in found
    assert "王熙凤笑道" not in found
|
|
|
|
def test_chinese_location_product(client):
    """Chinese text: stubbed location and product entities reach the response."""
    http, ner = client
    stubbed = [
        Entity(text="杭州", label="地名或城市", score=0.93, start=9, end=11),
        Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=14, end=16),
        Entity(text="天猫", label="产品或品牌名称", score=0.92, start=17, end=19),
        Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=20, end=23),
    ]
    ner.extract.return_value = _ents(*stubbed)
    request_body = {
        "text": "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。",
        "labels": ["地名或城市", "产品或品牌名称"],
        "language": "zh",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    assert {"杭州", "淘宝", "天猫", "支付宝"}.issubset(found)
|
|
|
|
def test_chinese_auto_routes_to_zh(client):
    """Chinese text sent without ``language`` forwards ``"auto"`` to ``extract``.

    The service is mocked, so only the pass-through of ``"auto"`` is checked
    here; the routing to the ZH model happens inside the service and is not
    exercised by this test.
    """
    c, mock_ner = client
    c.post("/api/v1/extract",
           json={"text": "马云创立了阿里巴巴。"})

    mock_ner.extract.assert_called_once()
    call = mock_ner.extract.call_args
    # Accept either calling convention, but never index past the end of the
    # positional tuple: doing so would raise IndexError instead of producing
    # a clean assertion failure when ``language`` is passed by keyword.
    if "language" in call.kwargs:
        forwarded = call.kwargs["language"]
    elif len(call.args) > 3:
        forwarded = call.args[3]
    else:
        forwarded = None
    assert forwarded == "auto"
|
|
|
|
| |
|
|
def test_arabic_person_location(client):
    """Arabic text: stubbed person and location entities reach the response."""
    http, ner = client
    person_name = "محمد بن سلمان"
    place_name = "المملكة العربية السعودية"
    person = Entity(text=person_name, label="full name of a person", score=0.82, start=12, end=26)
    place = Entity(text=place_name, label="geographical location", score=0.85, start=44, end=68)
    ner.extract.return_value = _ents(person, place)
    request_body = {
        "text": "أعلن الرئيس محمد بن سلمان مشروع نيوم في المملكة العربية السعودية.",
        "labels": ["full name of a person", "geographical location"],
        "language": "ar",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    assert person_name in found
    assert place_name in found
|
|
|
|
| |
|
|
def test_mixed_entities_both_scripts(client):
    """Mixed text: stubbed Chinese and Latin-script entities all reach the response."""
    http, ner = client
    stubbed = [
        Entity(text="张伟", label="person", score=0.95, start=0, end=2),
        Entity(text="Google", label="organization", score=0.97, start=9, end=15),
        Entity(text="北京", label="location", score=0.93, start=25, end=27),
        Entity(text="Android", label="product", score=0.91, start=33, end=40),
    ]
    ner.extract.return_value = _ents(*stubbed)
    requested_labels = [
        "full name of a person", "人名或姓名",
        "company or organization name", "公司或组织机构名称",
        "geographical location", "地名或城市",
        "product or technology name",
    ]
    request_body = {
        "text": "张伟入职了 Google,驻扎在北京,负责 Android 研发。",
        "labels": requested_labels,
        "language": "mixed",
    }
    response = http.post("/api/v1/extract", json=request_body)
    found = {item["text"] for item in response.json()["entities"]}
    assert {"张伟", "Google", "北京", "Android"}.issubset(found)
|
|
|
|
def test_mixed_no_cross_language_contamination(client):
    """Each stubbed entity keeps its own label in the response."""
    http, ner = client
    openai = Entity(text="OpenAI", label="organization", score=0.97, start=5, end=11)
    wang_fang = Entity(text="王芳", label="person", score=0.93, start=15, end=17)
    ner.extract.return_value = _ents(openai, wang_fang)
    request_body = {
        "text": "他在 OpenAI 工作,同事王芳也在同一部门。",
        "labels": ["person", "organization"],
    }
    response = http.post("/api/v1/extract", json=request_body)
    entities = response.json()["entities"]

    def has_entity(text, label):
        return any(
            item["text"] == text and item["label"] == label for item in entities
        )

    assert has_entity("OpenAI", "organization")
    assert has_entity("王芳", "person")
|
|
|
|
| |
|
|
def _build_svc():
    """Create a NERService without running ``__init__``, with mock backends and real locks."""
    import threading
    from app.ner import NERService

    # ``__new__`` skips ``__init__``; the attributes are then set by hand.
    service = NERService.__new__(NERService)
    service._en_lock = threading.Lock()
    service._zh_lock = threading.Lock()
    service._en_backend = MagicMock()
    service._zh_backend = MagicMock()
    return service
|
|
|
|
| |
|
|
def test_expected_min_short_text():
    """A two-character text with no labels gives an expected minimum of 1."""
    from app.ner import NERService

    expected_min = NERService._expected_min("马云", [])
    assert expected_min == 1
|
|
|
|
def test_expected_min_medium_text():
    """A 50-character text with no labels gives an expected minimum of 2."""
    from app.ner import NERService

    medium_text = "x" * 50
    expected_min = NERService._expected_min(medium_text, [])
    assert expected_min == 2
|
|
|
|
def test_expected_min_long_text():
    """A 350-character text with no labels gives an expected minimum of 4."""
    from app.ner import NERService

    long_text = "x" * 350
    expected_min = NERService._expected_min(long_text, [])
    assert expected_min == 4
|
|
|
|
def test_expected_min_label_floor_takes_over():
    """9 labels -> ceil(9/3)=3, above the short-text length floor of 1, so the result is 3."""
    from app.ner import NERService

    nine_labels = ["l" + str(index) for index in range(9)]
    expected_min = NERService._expected_min("马云", nine_labels)
    assert expected_min == 3
|
|
|
|
| |
|
|
def test_zh_empty_triggers_fallback_and_adds():
    """ZH primary model returns nothing -> fallback is triggered -> fallback results are returned."""
    service = _build_svc()
    service._zh_backend.predict.return_value = ([], [])
    fallback_hit = Entity(text="马云", label="person", score=0.75, start=0, end=2)
    service._en_backend.predict.return_value = _ents(fallback_hit)

    entities, _ = service.extract("马云", [], 0.4, language="zh")

    assert "马云" in [entity.text for entity in entities]
    service._zh_backend.predict.assert_called_once()
    service._en_backend.predict.assert_called_once()
|
|
|
|
def test_zh_sufficient_no_fallback():
    """ZH primary entity count >= expected_min (1 for short text) -> fallback is not called."""
    service = _build_svc()
    primary_hit = Entity(text="马云", label="person", score=0.92, start=0, end=2)
    service._zh_backend.predict.return_value = _ents(primary_hit)

    service.extract("马云", [], 0.4, language="zh")

    service._en_backend.predict.assert_not_called()
|
|
|
|
def test_zh_insufficient_triggers_fallback_and_results_added():
    """
    Key test: ZH returns 1 entity but the text is long -> expected_min=4, so the
    result is insufficient -> fallback runs, and primary plus fallback results
    are returned together (added, not replaced).
    """
    service = _build_svc()
    long_text = "马云" + "x" * 350
    primary_hit = Entity(text="马云", label="人名或姓名", score=0.95, start=0, end=2)
    fallback_hits = [
        Entity(text="Tesla", label="organization", score=0.90, start=10, end=15),
        Entity(text="2024", label="date", score=0.88, start=20, end=24),
    ]
    service._zh_backend.predict.return_value = _ents(primary_hit)
    service._en_backend.predict.return_value = _ents(*fallback_hits)

    entities, _ = service.extract(long_text, [], 0.4, language="zh")

    found = {entity.text for entity in entities}
    # The primary hit survives alongside both fallback hits.
    assert "马云" in found
    assert "Tesla" in found
    assert "2024" in found
    assert len(entities) == 3
|
|
|
|
def test_user_min_entities_overrides_heuristic():
    """A request-supplied min_entities=5 overrides the heuristic: 3 primary hits still trigger fallback."""
    service = _build_svc()
    primary_hits = [
        Entity(text="马云", label="person", score=0.95, start=0, end=2),
        Entity(text="张勇", label="person", score=0.93, start=4, end=6),
        Entity(text="杭州", label="location", score=0.91, start=8, end=10),
    ]
    fallback_hit = Entity(text="Tesla", label="organization", score=0.85, start=15, end=20)
    service._zh_backend.predict.return_value = _ents(*primary_hits)
    service._en_backend.predict.return_value = _ents(fallback_hit)

    entities, _ = service.extract(
        "马云、张勇、杭州 Tesla", [], 0.4, language="zh", min_entities=5
    )

    # Three primary hits plus one fallback hit.
    assert len(entities) == 4
    service._en_backend.predict.assert_called_once()
|
|
|
|
def test_user_min_entities_zero_disables_fallback():
    """With min_entities=0 the fallback is not triggered even when the primary model returns nothing."""
    service = _build_svc()
    service._zh_backend.predict.return_value = ([], [])

    service.extract("马云", [], 0.4, language="zh", min_entities=0)

    service._en_backend.predict.assert_not_called()
|
|
|
|
| |
|
|
def test_en_insufficient_triggers_zh_fallback_and_adds():
    """EN primary result is insufficient -> ZH fallback is called -> results are added together."""
    service = _build_svc()
    long_text = "Tesla" + "x" * 350
    primary_hit = Entity(text="Tesla", label="organization", score=0.95, start=0, end=5)
    fallback_hit = Entity(text="马云", label="人名或姓名", score=0.91, start=10, end=12)
    service._en_backend.predict.return_value = _ents(primary_hit)
    service._zh_backend.predict.return_value = _ents(fallback_hit)

    entities, _ = service.extract(long_text, [], 0.4, language="en")

    found = {entity.text for entity in entities}
    assert "Tesla" in found and "马云" in found
    assert len(entities) == 2
|
|
|
|
| |
|
|
def test_mixed_always_runs_both_models():
    """For ``mixed`` both models always run, regardless of sufficiency, and results are merged."""
    service = _build_svc()
    en_hit = Entity(text="Google", label="organization", score=0.95, start=5, end=11)
    zh_hit = Entity(text="张伟", label="人名或姓名", score=0.91, start=0, end=2)
    service._en_backend.predict.return_value = _ents(en_hit)
    service._zh_backend.predict.return_value = _ents(zh_hit)

    entities, _ = service.extract("张伟加入 Google。", [], 0.4, language="mixed")

    found = {entity.text for entity in entities}
    assert {"Google", "张伟"}.issubset(found)
    service._en_backend.predict.assert_called_once()
    service._zh_backend.predict.assert_called_once()
|
|
|
|
def test_merge_deduplicates_overlapping_spans():
    """When both models hit the same span, only the highest-scoring entity is kept."""
    service = _build_svc()
    en_hit = Entity(text="张伟", label="person", score=0.70, start=0, end=2)
    zh_hit = Entity(text="张伟", label="人名或姓名", score=0.92, start=0, end=2)
    service._en_backend.predict.return_value = ([en_hit], ["person"])
    service._zh_backend.predict.return_value = ([zh_hit], ["人名或姓名"])

    entities, _ = service.extract("张伟", [], 0.4, language="mixed")

    survivors = [entity for entity in entities if entity.text == "张伟"]
    assert len(survivors) == 1
    assert survivors[0].score == 0.92
|
|