# nerserver/tests/test_extract.py
# Last commit 2288fd7 (Robin): "feat: smarter fallback — heuristic sufficiency + result merging (v3.1)"
"""
Unit tests — no real model loaded (GLiNER/torch stubbed via conftest.py).
Covers:
- API contract (health, validation, threshold, language field)
- Dual-model routing (ZH → BERT backend, EN/AR/mixed → GLiNER backend)
- Optional labels + bilingual auto-expansion
- English / Chinese / Arabic / mixed-language scenarios
- labels_used echo in response
"""
from unittest.mock import MagicMock, patch, PropertyMock
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.models import Entity
from app.labels import DEFAULT_LABELS, expand_bilingual, labels_to_bert_types
from app.ner import detect_language
# ── Helpers ───────────────────────────────────────────────────────────────────
def _ents(*args: Entity) -> tuple[list[Entity], list[str]]:
    """Package entities as the ``(entities, labels_used)`` pair NERService returns.

    ``labels_used`` is each entity's label in the same order (duplicates kept).
    """
    entities = list(args)
    labels = [entity.label for entity in entities]
    return entities, labels
# ── Fixture ───────────────────────────────────────────────────────────────────
@pytest.fixture()
def client():
    """Yield ``(TestClient, mock_ner)`` with NERService replaced by a MagicMock.

    ``app.main.NERService`` is patched for the lifetime of the TestClient so no
    model is actually loaded. ``mock_ner.extract()`` returns ``([], [])``
    unless a test overrides ``return_value``.
    """
    mock_ner = MagicMock()
    mock_ner.extract.return_value = ([], [])
    # MonkeyPatch.context() is a classmethod; every patch is undone on exit.
    with pytest.MonkeyPatch.context() as mp:
        # Accept positional AND keyword constructor arguments so the fixture
        # keeps working if NERService is constructed with keyword config.
        mp.setattr("app.main.NERService", lambda *_args, **_kwargs: mock_ner)
        with TestClient(app) as c:
            yield c, mock_ner
# ── System / API contract ─────────────────────────────────────────────────────
def test_health(client):
    """The health endpoint reports an ok status."""
    http, _ner = client
    payload = http.get("/api/v1/health").json()
    assert payload == {"status": "ok"}
def test_extract_empty_text_returns_empty(client):
    """An empty text yields HTTP 200 and an empty entity list."""
    http, _ner = client
    response = http.post("/api/v1/extract", json={"text": "", "labels": ["person"]})
    assert response.status_code == 200
    assert response.json()["entities"] == []
def test_extract_labels_optional(client):
    """Omitting the ``labels`` field entirely must still return HTTP 200.

    The request is forwarded to NERService with an empty label list and the
    default threshold/language. Whatever labels the (mocked) service reports
    are echoed back in ``labels_used``.
    """
    c, mock_ner = client
    mock_ner.extract.return_value = ([], DEFAULT_LABELS)
    resp = c.post("/api/v1/extract", json={"text": "Some text."})
    assert resp.status_code == 200
    # The mock fixes the reported labels, so pin the exact echo, not just "non-empty".
    assert resp.json()["labels_used"] == DEFAULT_LABELS
    # An omitted labels field is forwarded as [] with threshold 0.4 and language "auto".
    mock_ner.extract.assert_called_once_with(
        "Some text.", [], 0.4, language="auto", min_entities=None
    )
def test_extract_empty_labels_uses_defaults(client):
    """An explicit ``labels=[]`` should use the default bilingual label set.

    NERService is mocked to report DEFAULT_LABELS, so at this layer the test
    checks that the API returns HTTP 200 and echoes those labels.
    """
    http, ner = client
    ner.extract.return_value = ([], DEFAULT_LABELS)
    response = http.post("/api/v1/extract", json={"text": "Hello world.", "labels": []})
    assert response.status_code == 200
    echoed = response.json()["labels_used"]
    assert echoed == DEFAULT_LABELS
def test_extract_threshold_forwarded(client):
    """A caller-supplied threshold reaches NERService.extract unchanged."""
    http, ner = client
    payload = {"text": "Hello", "labels": ["person"], "threshold": 0.8}
    http.post("/api/v1/extract", json=payload)
    ner.extract.assert_called_once_with(
        "Hello", ["person"], 0.8, language="auto", min_entities=None,
    )
def test_extract_invalid_threshold(client):
    """A threshold above 1.0 fails request validation with HTTP 422."""
    http, _ner = client
    response = http.post("/api/v1/extract", json={"text": "x", "threshold": 1.5})
    assert response.status_code == 422
def test_extract_language_field_forwarded(client):
    """An explicit ``language`` is forwarded; threshold falls back to 0.4."""
    http, ner = client
    payload = {"text": "北京协和医院", "labels": ["医院名称"], "language": "zh"}
    http.post("/api/v1/extract", json=payload)
    ner.extract.assert_called_once_with(
        "北京协和医院", ["医院名称"], 0.4, language="zh", min_entities=None,
    )
def test_extract_min_entities_forwarded(client):
    """``min_entities`` is forwarded; omitted labels arrive as an empty list."""
    http, ner = client
    payload = {"text": "马云在杭州。", "language": "zh", "min_entities": 5}
    http.post("/api/v1/extract", json=payload)
    ner.extract.assert_called_once_with(
        "马云在杭州。", [], 0.4, language="zh", min_entities=5,
    )
def test_extract_negative_min_entities_rejected(client):
    """A negative ``min_entities`` fails validation with HTTP 422."""
    http, _ner = client
    payload = {"text": "x", "min_entities": -1}
    assert http.post("/api/v1/extract", json=payload).status_code == 422
def test_extract_invalid_language(client):
    """An unsupported language code fails validation with HTTP 422."""
    http, _ner = client
    response = http.post("/api/v1/extract", json={"text": "x", "language": "jp"})
    assert response.status_code == 422
def test_labels_used_echoed(client):
    """``labels_used`` in the response is exactly what NERService reported."""
    http, ner = client
    reported = ["人名或姓名", "地名或城市"]
    ner.extract.return_value = ([], reported)
    body = {"text": "马云在杭州。", "labels": ["人名或姓名"]}
    echoed = http.post("/api/v1/extract", json=body).json()["labels_used"]
    assert echoed == reported
def test_entity_response_fields(client):
    """Each serialized entity carries text/label/score/start/end with sane values."""
    http, ner = client
    ner.extract.return_value = _ents(
        Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
    )
    body = {"text": "Apple is great.", "labels": ["organization"]}
    first = http.post("/api/v1/extract", json=body).json()["entities"][0]
    required_keys = {"text", "label", "score", "start", "end"}
    assert required_keys <= first.keys()
    assert 0.0 <= first["score"] <= 1.0
    assert first["start"] < first["end"]
# ── Language detection ────────────────────────────────────────────────────────
def test_detect_language_english():
    """ASCII-only English prose is detected as ``en``."""
    sentence = "Elon Musk founded SpaceX in California."
    assert detect_language(sentence) == "en"
def test_detect_language_chinese():
    """Pure Chinese prose is detected as ``zh``."""
    sentence = "阿里巴巴集团创始人马云于杭州卸任。"
    assert detect_language(sentence) == "zh"
def test_detect_language_arabic():
    """Arabic-script prose is detected as ``ar``."""
    sentence = "أعلن الرئيس محمد بن سلمان عن مشروع نيوم."
    assert detect_language(sentence) == "ar"
def test_detect_language_mixed():
    """Chinese text interleaved with English brand names is ``mixed``."""
    sentence = "张伟加入了 Google 北京研发中心,负责 Android 优化。"
    assert detect_language(sentence) == "mixed"
def test_detect_language_empty():
    """An empty string falls back to ``en``."""
    empty_text = ""
    assert detect_language(empty_text) == "en"
# ── Bilingual label expansion ─────────────────────────────────────────────────
def test_expand_bilingual_adds_english_for_chinese():
    """A Chinese label gains its English counterpart after expansion."""
    expanded = expand_bilingual(["人名或姓名"])
    assert "full name of a person" in expanded
def test_expand_bilingual_adds_chinese_for_english():
    """An English label gains its Chinese counterpart after expansion."""
    expanded = expand_bilingual(["company or organization name"])
    assert "公司或组织机构名称" in expanded
def test_expand_bilingual_no_duplicate():
    """Passing both halves of a bilingual pair must not duplicate either one."""
    expanded = expand_bilingual(["人名或姓名", "full name of a person"])
    for label in ("人名或姓名", "full name of a person"):
        assert expanded.count(label) == 1
def test_expand_bilingual_custom_label_preserved():
    """Labels without a known translation survive expansion untouched."""
    custom = "my custom label"
    assert custom in expand_bilingual([custom])
def test_default_labels_bilingual():
    """DEFAULT_LABELS contains both pure-ASCII (English) and CJK (Chinese) labels."""
    def is_ascii(label):
        return all(ord(ch) < 128 for ch in label)

    def has_cjk(label):
        # U+4E00..U+9FFF: CJK Unified Ideographs block.
        return any('一' <= ch <= '鿿' for ch in label)

    english_present = any(map(is_ascii, DEFAULT_LABELS))
    chinese_present = any(map(has_cjk, DEFAULT_LABELS))
    assert english_present and chinese_present
# ── BERT label mapping ────────────────────────────────────────────────────────
def test_labels_to_bert_types_chinese_label():
    """The Chinese person label maps to the BERT ``PER`` tag."""
    bert_types = labels_to_bert_types(["人名或姓名"])
    assert "PER" in bert_types
def test_labels_to_bert_types_english_label():
    """The English location label maps to the BERT ``LOC`` tag."""
    bert_types = labels_to_bert_types(["geographical location"])
    assert "LOC" in bert_types
def test_labels_to_bert_types_empty_returns_none():
    """No labels means no type filter (``None``)."""
    result = labels_to_bert_types([])
    assert result is None
def test_labels_to_bert_types_unmapped_returns_none():
    """Labels with no BERT mapping disable filtering (``None``)."""
    result = labels_to_bert_types(["some unknown label"])
    assert result is None
# ── English ───────────────────────────────────────────────────────────────────
def test_english_person_org(client):
    """Person and organization entities from English text are returned intact.

    Mocked offsets are the real character positions of each surface form in
    ``text``, so the response can also be checked for offset consistency.
    """
    c, mock_ner = client
    text = "Elon Musk is the CEO of Tesla and founded SpaceX."
    mock_ner.extract.return_value = _ents(
        Entity(text="Elon Musk", label="person", score=0.98, start=0, end=9),
        Entity(text="Tesla", label="organization", score=0.96, start=24, end=29),
        Entity(text="SpaceX", label="organization", score=0.97, start=42, end=48),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["full name of a person", "company or organization name"],
                        "language": "en"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert {"Elon Musk", "Tesla", "SpaceX"} <= {e["text"] for e in entities}
    # Offsets must survive serialization and still point at the entity text.
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_english_location_date(client):
    """Location and date entities from English text are returned intact.

    Mocked offsets are the real character positions of each surface form in
    ``text``.
    """
    c, mock_ner = client
    text = "The summit was held in Paris in 2024, in France."
    mock_ner.extract.return_value = _ents(
        Entity(text="Paris", label="location", score=0.94, start=23, end=28),
        Entity(text="2024", label="date", score=0.91, start=32, end=36),
        Entity(text="France", label="location", score=0.93, start=41, end=47),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["geographical location", "date or year"],
                        "language": "en"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert {"Paris", "France", "2024"} <= {e["text"] for e in entities}
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_english_threshold_forwarded(client):
    """Threshold and an explicit ``en`` language are forwarded verbatim."""
    http, ner = client
    sentence = "NASA explored the Moon."
    requested = ["company or organization name"]
    http.post("/api/v1/extract",
              json={"text": sentence, "labels": requested,
                    "threshold": 0.8, "language": "en"})
    ner.extract.assert_called_once_with(
        sentence, requested, 0.8, language="en", min_entities=None,
    )
# ── Chinese (BERT backend) ────────────────────────────────────────────────────
def test_chinese_person_org(client):
    """Person and organization entities from Chinese text are returned intact.

    Mocked offsets are the real character positions of each surface form in
    ``text``.
    """
    c, mock_ner = client
    text = "阿里巴巴集团创始人马云卸任,由张勇接任。"
    mock_ner.extract.return_value = _ents(
        Entity(text="马云", label="人名或姓名", score=0.96, start=9, end=11),
        Entity(text="张勇", label="人名或姓名", score=0.94, start=15, end=17),
        Entity(text="阿里巴巴", label="公司或组织机构名称", score=0.97, start=0, end=4),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["人名或姓名", "公司或组织机构名称"],
                        "language": "zh"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert {"马云", "张勇", "阿里巴巴"} <= {e["text"] for e in entities}
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_chinese_entity_boundary(client):
    """BERT NER should cut entity boundaries exactly, without trailing verbs.

    Mocked offsets are the real character positions of each name in ``text``.
    """
    c, mock_ner = client
    text = "尤氏来请,王熙凤笑道:'你来了。'"
    mock_ner.extract.return_value = _ents(
        Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2),
        Entity(text="王熙凤", label="人名或姓名", score=0.95, start=5, end=8),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text, "labels": ["人名或姓名"], "language": "zh"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    texts = {e["text"] for e in entities}
    assert "尤氏" in texts
    assert "王熙凤" in texts
    # Spans that swallow the following verb must not appear.
    assert "尤氏来请" not in texts
    assert "王熙凤笑道" not in texts
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_chinese_location_product(client):
    """Location and product entities from Chinese text are returned intact.

    Mocked offsets are the real character positions of each surface form in
    ``text``.
    """
    c, mock_ner = client
    text = "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。"
    mock_ner.extract.return_value = _ents(
        Entity(text="杭州", label="地名或城市", score=0.93, start=8, end=10),
        Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=14, end=16),
        Entity(text="天猫", label="产品或品牌名称", score=0.92, start=17, end=19),
        Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=20, end=23),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["地名或城市", "产品或品牌名称"],
                        "language": "zh"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert {"杭州", "淘宝", "天猫", "支付宝"} <= {e["text"] for e in entities}
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_chinese_auto_routes_to_zh(client):
    """Chinese text with no ``language`` field should end up on the ZH model.

    NERService is mocked, so this checks the two observable halves:
    1. The API forwards the request with ``language="auto"``.
    2. ``app.ner.detect_language`` classifies the same text as ``zh``.
    """
    c, mock_ner = client
    text = "马云创立了阿里巴巴。"
    c.post("/api/v1/extract", json={"text": text})
    # Pin the whole call. The previous `call_args[0][3]` fallback raised
    # IndexError whenever language was passed as a keyword that didn't match.
    mock_ner.extract.assert_called_once_with(
        text, [], 0.4, language="auto", min_entities=None
    )
    assert detect_language(text) == "zh"
# ── Arabic ────────────────────────────────────────────────────────────────────
def test_arabic_person_location(client):
    """Person and location entities from Arabic text are returned intact.

    Offsets are derived with ``str.index`` instead of hard-coded numbers,
    because hand-counting code points in RTL text is error-prone.
    """
    c, mock_ner = client
    text = "أعلن الرئيس محمد بن سلمان مشروع نيوم في المملكة العربية السعودية."
    person = "محمد بن سلمان"
    place = "المملكة العربية السعودية"
    person_start = text.index(person)
    place_start = text.index(place)
    mock_ner.extract.return_value = _ents(
        Entity(text=person, label="full name of a person", score=0.82,
               start=person_start, end=person_start + len(person)),
        Entity(text=place, label="geographical location", score=0.85,
               start=place_start, end=place_start + len(place)),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["full name of a person", "geographical location"],
                        "language": "ar"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    texts = {e["text"] for e in entities}
    assert person in texts
    assert place in texts
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
# ── Mixed Chinese-English ─────────────────────────────────────────────────────
def test_mixed_entities_both_scripts(client):
    """Entities in both Chinese and Latin script survive a mixed-language request.

    Mocked offsets are the real character positions of each surface form in
    ``text``.
    """
    c, mock_ner = client
    text = "张伟入职了 Google,驻扎在北京,负责 Android 研发。"
    mock_ner.extract.return_value = _ents(
        Entity(text="张伟", label="person", score=0.95, start=0, end=2),
        Entity(text="Google", label="organization", score=0.97, start=6, end=12),
        Entity(text="北京", label="location", score=0.93, start=16, end=18),
        Entity(text="Android", label="product", score=0.91, start=22, end=29),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text,
                        "labels": ["full name of a person", "人名或姓名",
                                   "company or organization name", "公司或组织机构名称",
                                   "geographical location", "地名或城市",
                                   "product or technology name"],
                        "language": "mixed"})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert {"张伟", "Google", "北京", "Android"} <= {e["text"] for e in entities}
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
def test_mixed_no_cross_language_contamination(client):
    """Each entity keeps its own label in mixed Chinese/English text.

    Mocked offsets are the real character positions of each surface form in
    ``text``.
    """
    c, mock_ner = client
    text = "他在 OpenAI 工作,同事王芳也在同一部门。"
    mock_ner.extract.return_value = _ents(
        Entity(text="OpenAI", label="organization", score=0.97, start=3, end=9),
        Entity(text="王芳", label="person", score=0.93, start=15, end=17),
    )
    resp = c.post("/api/v1/extract",
                  json={"text": text, "labels": ["person", "organization"]})
    assert resp.status_code == 200
    entities = resp.json()["entities"]
    assert any(e["text"] == "OpenAI" and e["label"] == "organization" for e in entities)
    assert any(e["text"] == "王芳" and e["label"] == "person" for e in entities)
    assert all(text[e["start"]:e["end"]] == e["text"] for e in entities)
# ── Fallback & merge (NERService unit tests, no HTTP) ────────────────────────
def _build_svc():
    """Return an NERService instance whose two backends are MagicMocks.

    ``__init__`` is bypassed via ``__new__`` so no model is loaded. Only the
    lock and backend attributes are populated (presumably the ones
    ``NERService.extract`` reads; verify if extract starts using more state).
    """
    import threading
    from app.ner import NERService

    service = NERService.__new__(NERService)
    service._en_lock, service._zh_lock = threading.Lock(), threading.Lock()
    service._en_backend, service._zh_backend = MagicMock(), MagicMock()
    return service
# ── Sufficiency heuristic ────────────────────────────────────────────────────
def test_expected_min_short_text():
    """A two-character text with no labels expects at least one entity."""
    from app.ner import NERService
    expected = NERService._expected_min("马云", [])
    assert expected == 1
def test_expected_min_medium_text():
    """A 50-character text with no labels expects at least two entities."""
    from app.ner import NERService
    expected = NERService._expected_min("x" * 50, [])
    assert expected == 2
def test_expected_min_long_text():
    """A 350-character text with no labels expects at least four entities."""
    from app.ner import NERService
    expected = NERService._expected_min("x" * 350, [])
    assert expected == 4
def test_expected_min_label_floor_takes_over():
    """Nine labels give ceil(9/3) = 3, which beats the short-text length_floor of 1."""
    from app.ner import NERService
    nine_labels = ["l" + str(idx) for idx in range(9)]
    assert NERService._expected_min("马云", nine_labels) == 3
# ── ZH branch fallback ───────────────────────────────────────────────────────
def test_zh_empty_triggers_fallback_and_adds():
    """Zero entities from the ZH primary triggers the fallback, whose hits are returned."""
    service = _build_svc()
    service._zh_backend.predict.return_value = ([], [])
    service._en_backend.predict.return_value = _ents(
        Entity(text="马云", label="person", score=0.75, start=0, end=2)
    )
    found, _labels = service.extract("马云", [], 0.4, language="zh")
    assert "马云" in [ent.text for ent in found]
    for backend in (service._zh_backend, service._en_backend):
        backend.predict.assert_called_once()
def test_zh_sufficient_no_fallback():
    """ZH output meeting expected_min (1 for a short text) skips the fallback."""
    service = _build_svc()
    hit = Entity(text="马云", label="person", score=0.92, start=0, end=2)
    service._zh_backend.predict.return_value = _ents(hit)
    service.extract("马云", [], 0.4, language="zh")
    service._en_backend.predict.assert_not_called()
def test_zh_insufficient_triggers_fallback_and_results_added():
    """Insufficient ZH output triggers the fallback, and the results are combined.

    Key test: ZH returns a single entity, but the text is long, so
    expected_min is 4 and the result is insufficient. The fallback runs, and
    primary and fallback entities are returned together (added, not replaced).
    """
    service = _build_svc()
    long_text = "马云" + "x" * 350  # length_floor = 4
    service._zh_backend.predict.return_value = _ents(
        Entity(text="马云", label="人名或姓名", score=0.95, start=0, end=2),
    )
    service._en_backend.predict.return_value = _ents(
        Entity(text="Tesla", label="organization", score=0.90, start=10, end=15),
        Entity(text="2024", label="date", score=0.88, start=20, end=24),
    )
    merged, _labels = service.extract(long_text, [], 0.4, language="zh")
    surfaces = {ent.text for ent in merged}
    # The primary "马云" must be kept, and the fallback's Tesla / 2024 are added.
    assert {"马云", "Tesla", "2024"} <= surfaces
    assert len(merged) == 3  # 1 primary + 2 fallback: results are additive
def test_user_min_entities_overrides_heuristic():
    """A request-level ``min_entities=5`` overrides the heuristic.

    The ZH primary returns 3 entities, which is fewer than 5, so the fallback
    still runs.
    """
    service = _build_svc()
    service._zh_backend.predict.return_value = _ents(
        Entity(text="马云", label="person", score=0.95, start=0, end=2),
        Entity(text="张勇", label="person", score=0.93, start=4, end=6),
        Entity(text="杭州", label="location", score=0.91, start=8, end=10),
    )
    service._en_backend.predict.return_value = _ents(
        Entity(text="Tesla", label="organization", score=0.85, start=15, end=20),
    )
    merged, _labels = service.extract(
        "马云、张勇、杭州 Tesla", [], 0.4, language="zh", min_entities=5,
    )
    # 3 < 5, so the fallback runs; 3 primary + 1 fallback = 4.
    assert len(merged) == 4
    service._en_backend.predict.assert_called_once()
def test_user_min_entities_zero_disables_fallback():
    """``min_entities=0`` skips the fallback even when the primary finds nothing."""
    service = _build_svc()
    service._zh_backend.predict.return_value = ([], [])
    service.extract("马云", [], 0.4, language="zh", min_entities=0)
    service._en_backend.predict.assert_not_called()
# ── EN branch fallback (symmetric) ───────────────────────────────────────────
def test_en_insufficient_triggers_zh_fallback_and_adds():
    """Insufficient EN output triggers the ZH fallback, and the results are added."""
    service = _build_svc()
    long_text = "Tesla" + "x" * 350
    service._en_backend.predict.return_value = _ents(
        Entity(text="Tesla", label="organization", score=0.95, start=0, end=5),
    )
    service._zh_backend.predict.return_value = _ents(
        Entity(text="马云", label="人名或姓名", score=0.91, start=10, end=12),
    )
    merged, _labels = service.extract(long_text, [], 0.4, language="en")
    surfaces = {ent.text for ent in merged}
    assert "Tesla" in surfaces and "马云" in surfaces
    assert len(merged) == 2
# ── Mixed: always merge both ─────────────────────────────────────────────────
def test_mixed_always_runs_both_models():
    """``mixed`` ignores sufficiency: both models always run and are merged."""
    service = _build_svc()
    service._en_backend.predict.return_value = _ents(
        Entity(text="Google", label="organization", score=0.95, start=5, end=11)
    )
    service._zh_backend.predict.return_value = _ents(
        Entity(text="张伟", label="人名或姓名", score=0.91, start=0, end=2)
    )
    merged, _labels = service.extract("张伟加入 Google。", [], 0.4, language="mixed")
    assert {"Google", "张伟"} <= {ent.text for ent in merged}
    for backend in (service._en_backend, service._zh_backend):
        backend.predict.assert_called_once()
def test_merge_deduplicates_overlapping_spans():
    """When both models hit the same span, only the higher-scoring entity is kept."""
    service = _build_svc()
    en_hit = Entity(text="张伟", label="person", score=0.70, start=0, end=2)
    service._en_backend.predict.return_value = ([en_hit], ["person"])
    zh_hit = Entity(text="张伟", label="人名或姓名", score=0.92, start=0, end=2)
    service._zh_backend.predict.return_value = ([zh_hit], ["人名或姓名"])
    merged, _labels = service.extract("张伟", [], 0.4, language="mixed")
    same_span = [ent for ent in merged if ent.text == "张伟"]
    assert len(same_span) == 1
    assert same_span[0].score == 0.92  # the higher score wins