feat: dual-model routing + fallback strategy (v3.0)
Browse filesModels
EN/AR → urchade/gliner_multi-v2.1 (GLiNER zero-shot, ~1 GB, ~500 ms)
ZH → shibing624/bert4ner-base-chinese (BERT NER, ~400 MB, ~100 ms)
mixed → both models run in parallel, results merged + deduplicated
Language detection (app/ner.py)
Layer-1 Unicode script ratio (CJK / Arabic / Latin)
mixed = cjk>=8% AND latin>=10% (prevents dominant-CJK from
masking bilingual text)
Layer-2 langdetect n-gram fallback when Layer-1 returns 'en'
Fallback / merge (NERService.extract)
zh primary=BERT, fallback=GLiNER when entities==0
en/ar primary=GLiNER, fallback=BERT when entities==0
mixed run both, _deduplicate keeps highest-score per (start,end) span
Other changes
app/config.py EN_MODEL_NAME / ZH_MODEL_NAME env vars; legacy MODEL_NAME kept
app/labels.py BERT_TYPE_TO_LABEL, LABEL_TO_BERT_TYPES, labels_to_bert_types()
requirements.txt add transformers>=4.40.0, langdetect>=1.0.9
Dockerfile pre-download both models at build time
tests 38 tests (was 25); cover fallback, merge, dedup, lang-detect
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .env.example +7 -1
- Dockerfile +13 -3
- app/config.py +13 -3
- app/labels.py +67 -28
- app/main.py +15 -9
- app/ner.py +266 -55
- requirements.txt +2 -0
- tests/test_extract.py +228 -110
|
@@ -1,6 +1,12 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
MODEL_CACHE_DIR=./model_cache
|
| 3 |
HOST=0.0.0.0
|
| 4 |
PORT=4000
|
|
|
|
| 5 |
# 国内环境取消注释;HF Spaces 上留空
|
| 6 |
# HF_ENDPOINT=https://hf-mirror.com
|
|
|
|
| 1 |
+
# 英文 / 阿拉伯文 / 混合文本模型(GLiNER 零样本,~1GB)
|
| 2 |
+
EN_MODEL_NAME=urchade/gliner_multi-v2.1
|
| 3 |
+
|
| 4 |
+
# 中文专用模型(BERT NER,~400MB,快速)
|
| 5 |
+
ZH_MODEL_NAME=shibing624/bert4ner-base-chinese
|
| 6 |
+
|
| 7 |
MODEL_CACHE_DIR=./model_cache
|
| 8 |
HOST=0.0.0.0
|
| 9 |
PORT=4000
|
| 10 |
+
|
| 11 |
# 国内环境取消注释;HF Spaces 上留空
|
| 12 |
# HF_ENDPOINT=https://hf-mirror.com
|
|
@@ -5,11 +5,19 @@ WORKDIR /app
|
|
| 5 |
COPY requirements.txt .
|
| 6 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
#
|
| 10 |
RUN python -c "\
|
| 11 |
from gliner import GLiNER; \
|
| 12 |
-
GLiNER.from_pretrained('
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
COPY app/ app/
|
| 15 |
COPY run.py .
|
|
@@ -17,6 +25,8 @@ COPY run.py .
|
|
| 17 |
ENV HOST=0.0.0.0
|
| 18 |
ENV PORT=7860
|
| 19 |
ENV MODEL_CACHE_DIR=/app/model_cache
|
|
|
|
|
|
|
| 20 |
|
| 21 |
EXPOSE 7860
|
| 22 |
|
|
|
|
| 5 |
COPY requirements.txt .
|
| 6 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
|
| 8 |
+
# ── 构建时预下载两个模型,冷启动无需联网 ──────────────────────────────────────
|
| 9 |
+
# EN / AR 模型:GLiNER 零样本多语言(~1 GB)
|
| 10 |
RUN python -c "\
|
| 11 |
from gliner import GLiNER; \
|
| 12 |
+
GLiNER.from_pretrained('urchade/gliner_multi-v2.1', cache_dir='/app/model_cache')"
|
| 13 |
+
|
| 14 |
+
# ZH 模型:BERT 专用中文 NER(~400 MB)
|
| 15 |
+
RUN python -c "\
|
| 16 |
+
from transformers import pipeline; \
|
| 17 |
+
pipeline('token-classification', \
|
| 18 |
+
model='shibing624/bert4ner-base-chinese', \
|
| 19 |
+
model_kwargs={'cache_dir': '/app/model_cache'}, \
|
| 20 |
+
aggregation_strategy='simple')"
|
| 21 |
|
| 22 |
COPY app/ app/
|
| 23 |
COPY run.py .
|
|
|
|
| 25 |
ENV HOST=0.0.0.0
|
| 26 |
ENV PORT=7860
|
| 27 |
ENV MODEL_CACHE_DIR=/app/model_cache
|
| 28 |
+
ENV EN_MODEL_NAME=urchade/gliner_multi-v2.1
|
| 29 |
+
ENV ZH_MODEL_NAME=shibing624/bert4ner-base-chinese
|
| 30 |
|
| 31 |
EXPOSE 7860
|
| 32 |
|
|
@@ -1,12 +1,22 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
|
| 5 |
HOST: str = os.getenv("HOST", "0.0.0.0")
|
| 6 |
PORT: int = int(os.getenv("PORT", "4000"))
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
# can reach huggingface.co directly without forcing the mirror.
|
| 10 |
_hf_endpoint = os.getenv("HF_ENDPOINT")
|
| 11 |
if _hf_endpoint:
|
| 12 |
os.environ["HF_ENDPOINT"] = _hf_endpoint
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
+
# ── 模型配置 ──────────────────────────────────────────────────────────────────
|
| 4 |
+
# 英文 / 阿拉伯文 / 混合:轻量 GLiNER 零样本模型
|
| 5 |
+
EN_MODEL_NAME: str = os.getenv("EN_MODEL_NAME", "urchade/gliner_multi-v2.1")
|
| 6 |
+
|
| 7 |
+
# 中文:专用 BERT NER 模型(400MB,~100ms/次,4 种固定实体类型)
|
| 8 |
+
ZH_MODEL_NAME: str = os.getenv("ZH_MODEL_NAME", "shibing624/bert4ner-base-chinese")
|
| 9 |
+
|
| 10 |
+
# 兼容旧版环境变量(若设置 MODEL_NAME,则覆盖 EN_MODEL_NAME)
|
| 11 |
+
_legacy = os.getenv("MODEL_NAME")
|
| 12 |
+
if _legacy:
|
| 13 |
+
EN_MODEL_NAME = _legacy
|
| 14 |
+
|
| 15 |
MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
|
| 16 |
HOST: str = os.getenv("HOST", "0.0.0.0")
|
| 17 |
PORT: int = int(os.getenv("PORT", "4000"))
|
| 18 |
|
| 19 |
+
# 可选:国内镜像(留空则使用 huggingface.co)
|
|
|
|
| 20 |
_hf_endpoint = os.getenv("HF_ENDPOINT")
|
| 21 |
if _hf_endpoint:
|
| 22 |
os.environ["HF_ENDPOINT"] = _hf_endpoint
|
|
@@ -1,30 +1,32 @@
|
|
| 1 |
"""
|
| 2 |
双语标签管理模块
|
| 3 |
-
────────────────
|
| 4 |
-
* DEFAULT_LABELS
|
| 5 |
-
* expand_bilingual
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
#
|
| 9 |
_PAIRS: list[tuple[str, str]] = [
|
| 10 |
-
("full name of a person",
|
| 11 |
-
("company or organization name",
|
| 12 |
-
("geographical location",
|
| 13 |
-
("product or technology name",
|
| 14 |
-
("date or year",
|
| 15 |
-
("hospital or medical institution",
|
| 16 |
-
("university or research institution","大学或研究机构"),
|
| 17 |
-
("project or initiative name",
|
| 18 |
-
("legislation or policy name",
|
| 19 |
-
("monetary amount",
|
| 20 |
-
("job title or position",
|
| 21 |
-
("event name",
|
| 22 |
]
|
| 23 |
|
| 24 |
-
# 默认标签集:英中并列
|
| 25 |
DEFAULT_LABELS: list[str] = [item for pair in _PAIRS for item in pair]
|
| 26 |
|
| 27 |
-
# 快速查找
|
| 28 |
_EN_TO_ZH: dict[str, str] = {en: zh for en, zh in _PAIRS}
|
| 29 |
_ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
|
| 30 |
|
|
@@ -32,16 +34,7 @@ _ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
|
|
| 32 |
def expand_bilingual(labels: list[str]) -> list[str]:
|
| 33 |
"""
|
| 34 |
为调用者传入的标签自动补充另一语言的对等描述。
|
| 35 |
-
|
| 36 |
-
例如:
|
| 37 |
-
["人名或姓名", "company or organization name"]
|
| 38 |
-
→ ["人名或姓名", "full name of a person",
|
| 39 |
-
"company or organization name", "公司或组织机构名称"]
|
| 40 |
-
|
| 41 |
-
规则:
|
| 42 |
-
* 已有标签保持原位不变
|
| 43 |
-
* 对等标签紧随其后插入(若已存在则跳过)
|
| 44 |
-
* 未在对照表中的自定义标签原样保留,不做处理
|
| 45 |
"""
|
| 46 |
seen: set[str] = set(labels)
|
| 47 |
result: list[str] = []
|
|
@@ -52,3 +45,49 @@ def expand_bilingual(labels: list[str]) -> list[str]:
|
|
| 52 |
result.append(counterpart)
|
| 53 |
seen.add(counterpart)
|
| 54 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
双语标签管理模块
|
| 3 |
+
────────────────────────────────────────────────────────────────
|
| 4 |
+
* DEFAULT_LABELS — 内置通用双语标签集(labels 为空时使用)
|
| 5 |
+
* expand_bilingual — 自动为已有标签补充对等的另一语言版本
|
| 6 |
+
* BERT_TYPE_TO_LABEL — 中文 BERT 模型固定实体类型 → 标准双语标签
|
| 7 |
+
* LABEL_TO_BERT_TYPES — 标准标签 → 对应 BERT 实体类型列表(用于过滤)
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
# ── 英中对照表 ─────────────────────────────────────────────────────────────────
|
| 11 |
_PAIRS: list[tuple[str, str]] = [
|
| 12 |
+
("full name of a person", "人名或姓名"),
|
| 13 |
+
("company or organization name", "公司或组织机构名称"),
|
| 14 |
+
("geographical location", "地名或城市"),
|
| 15 |
+
("product or technology name", "产品或技术名称"),
|
| 16 |
+
("date or year", "日期或年份"),
|
| 17 |
+
("hospital or medical institution", "医院或医疗机构名称"),
|
| 18 |
+
("university or research institution", "大学或研究机构"),
|
| 19 |
+
("project or initiative name", "项目或计划名称"),
|
| 20 |
+
("legislation or policy name", "法规或政策名称"),
|
| 21 |
+
("monetary amount", "金额或货币"),
|
| 22 |
+
("job title or position", "职位或头衔"),
|
| 23 |
+
("event name", "事件或活动名称"),
|
| 24 |
]
|
| 25 |
|
| 26 |
+
# 默认标签集:英中并列
|
| 27 |
DEFAULT_LABELS: list[str] = [item for pair in _PAIRS for item in pair]
|
| 28 |
|
| 29 |
+
# 快速查找
|
| 30 |
_EN_TO_ZH: dict[str, str] = {en: zh for en, zh in _PAIRS}
|
| 31 |
_ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
|
| 32 |
|
|
|
|
| 34 |
def expand_bilingual(labels: list[str]) -> list[str]:
|
| 35 |
"""
|
| 36 |
为调用者传入的标签自动补充另一语言的对等描述。
|
| 37 |
+
已有标签保持原位不变,对等标签紧随其后插入(若已存在则跳过)。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"""
|
| 39 |
seen: set[str] = set(labels)
|
| 40 |
result: list[str] = []
|
|
|
|
| 45 |
result.append(counterpart)
|
| 46 |
seen.add(counterpart)
|
| 47 |
return result
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ── 中文 BERT NER 固定类型映射 ────────────────────────────────────────────────
|
| 51 |
+
# shibing624/bert4ner-base-chinese 输出的实体类型
|
| 52 |
+
BERT_TYPE_TO_LABEL: dict[str, str] = {
|
| 53 |
+
"PER": "人名或姓名",
|
| 54 |
+
"LOC": "地名或城市",
|
| 55 |
+
"ORG": "公司或组织机构名称",
|
| 56 |
+
"TIME": "日期或年份",
|
| 57 |
+
"GPE": "地名或城市", # 部分模型区分 GPE(地缘政治实体)
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# 标准标签 → BERT 类型列表(用于用户自定义标签过滤)
|
| 61 |
+
LABEL_TO_BERT_TYPES: dict[str, list[str]] = {
|
| 62 |
+
# 人名
|
| 63 |
+
"人名或姓名": ["PER"],
|
| 64 |
+
"full name of a person": ["PER"],
|
| 65 |
+
# 地名
|
| 66 |
+
"地名或城市": ["LOC", "GPE"],
|
| 67 |
+
"geographical location": ["LOC", "GPE"],
|
| 68 |
+
# 机构
|
| 69 |
+
"公司或组织机构名称": ["ORG"],
|
| 70 |
+
"company or organization name": ["ORG"],
|
| 71 |
+
"医院或医疗机构名称": ["ORG"],
|
| 72 |
+
"hospital or medical institution": ["ORG"],
|
| 73 |
+
"大学或研究机构": ["ORG"],
|
| 74 |
+
"university or research institution": ["ORG"],
|
| 75 |
+
# 时间
|
| 76 |
+
"日期或年份": ["TIME"],
|
| 77 |
+
"date or year": ["TIME"],
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def labels_to_bert_types(labels: list[str]) -> set[str] | None:
|
| 82 |
+
"""
|
| 83 |
+
将用户标签列表转换为 BERT 实体类型集合。
|
| 84 |
+
返回 None 表示"接受所有类型"(labels 为空或无法映射时)。
|
| 85 |
+
"""
|
| 86 |
+
if not labels:
|
| 87 |
+
return None # 无限制,返回全部
|
| 88 |
+
types: set[str] = set()
|
| 89 |
+
for lbl in labels:
|
| 90 |
+
mapped = LABEL_TO_BERT_TYPES.get(lbl)
|
| 91 |
+
if mapped:
|
| 92 |
+
types.update(mapped)
|
| 93 |
+
return types if types else None # 无映射 → 不过滤
|
|
@@ -3,7 +3,7 @@ from contextlib import asynccontextmanager
|
|
| 3 |
|
| 4 |
from fastapi import FastAPI
|
| 5 |
|
| 6 |
-
from app.config import MODEL_CACHE_DIR,
|
| 7 |
from app.logger import get_logger
|
| 8 |
from app.models import ExtractRequest, ExtractResponse
|
| 9 |
from app.ner import NERService
|
|
@@ -15,9 +15,14 @@ ner_service: NERService | None = None
|
|
| 15 |
@asynccontextmanager
|
| 16 |
async def lifespan(app: FastAPI):
|
| 17 |
global ner_service
|
| 18 |
-
logger.info(
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
yield
|
| 22 |
ner_service = None
|
| 23 |
|
|
@@ -25,11 +30,12 @@ async def lifespan(app: FastAPI):
|
|
| 25 |
app = FastAPI(
|
| 26 |
title="NER API",
|
| 27 |
description=(
|
| 28 |
-
"Zero-shot Named Entity Recognition powered by GLiNER
|
| 29 |
-
"
|
|
|
|
| 30 |
"Labels are optional — omit them to use built-in bilingual defaults."
|
| 31 |
),
|
| 32 |
-
version="
|
| 33 |
lifespan=lifespan,
|
| 34 |
)
|
| 35 |
|
|
@@ -58,9 +64,9 @@ def extract(req: ExtractRequest):
|
|
| 58 |
elapsed_ms = (time.perf_counter() - t0) * 1000
|
| 59 |
|
| 60 |
logger.info(
|
| 61 |
-
"extract response | entities=%d elapsed=%.1fms
|
| 62 |
len(entities),
|
| 63 |
elapsed_ms,
|
| 64 |
-
|
| 65 |
)
|
| 66 |
return ExtractResponse(entities=entities, labels_used=labels_used)
|
|
|
|
| 3 |
|
| 4 |
from fastapi import FastAPI
|
| 5 |
|
| 6 |
+
from app.config import EN_MODEL_NAME, MODEL_CACHE_DIR, ZH_MODEL_NAME
|
| 7 |
from app.logger import get_logger
|
| 8 |
from app.models import ExtractRequest, ExtractResponse
|
| 9 |
from app.ner import NERService
|
|
|
|
| 15 |
@asynccontextmanager
|
| 16 |
async def lifespan(app: FastAPI):
|
| 17 |
global ner_service
|
| 18 |
+
logger.info(
|
| 19 |
+
"Initializing NER service | en_model=%s zh_model=%s cache=%s",
|
| 20 |
+
EN_MODEL_NAME, ZH_MODEL_NAME, MODEL_CACHE_DIR,
|
| 21 |
+
)
|
| 22 |
+
ner_service = NERService(EN_MODEL_NAME, ZH_MODEL_NAME, MODEL_CACHE_DIR)
|
| 23 |
+
# 预热:启动时同时加载两个模型,首个请求无需等待
|
| 24 |
+
ner_service.warmup()
|
| 25 |
+
logger.info("NER service ready")
|
| 26 |
yield
|
| 27 |
ner_service = None
|
| 28 |
|
|
|
|
| 30 |
app = FastAPI(
|
| 31 |
title="NER API",
|
| 32 |
description=(
|
| 33 |
+
"Zero-shot Named Entity Recognition powered by GLiNER (EN/AR) "
|
| 34 |
+
"and BERT-Chinese (ZH). "
|
| 35 |
+
"Supports English · Chinese · Arabic · mixed-language text. "
|
| 36 |
"Labels are optional — omit them to use built-in bilingual defaults."
|
| 37 |
),
|
| 38 |
+
version="3.0.0",
|
| 39 |
lifespan=lifespan,
|
| 40 |
)
|
| 41 |
|
|
|
|
| 64 |
elapsed_ms = (time.perf_counter() - t0) * 1000
|
| 65 |
|
| 66 |
logger.info(
|
| 67 |
+
"extract response | entities=%d elapsed=%.1fms language=%s",
|
| 68 |
len(entities),
|
| 69 |
elapsed_ms,
|
| 70 |
+
req.language,
|
| 71 |
)
|
| 72 |
return ExtractResponse(entities=entities, labels_used=labels_used)
|
|
@@ -1,45 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import unicodedata
|
|
|
|
| 2 |
|
| 3 |
from gliner import GLiNER
|
| 4 |
|
| 5 |
-
from app.labels import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from app.models import Entity
|
| 7 |
|
| 8 |
|
| 9 |
# ── 语言检测 ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
def
|
| 12 |
-
"""
|
| 13 |
-
通过 Unicode 脚本比例判断文本语言。
|
| 14 |
-
返回: 'zh' | 'ar' | 'mixed' | 'en'
|
| 15 |
-
"""
|
| 16 |
-
if not text:
|
| 17 |
-
return "en"
|
| 18 |
-
|
| 19 |
cjk = arabic = letters = 0
|
| 20 |
for ch in text:
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
elif 0x0600 <= cp <= 0x06FF or 0x0750 <= cp <= 0x077F:
|
| 31 |
-
arabic += 1
|
| 32 |
-
|
| 33 |
if not letters:
|
| 34 |
return "en"
|
| 35 |
cjk_r = cjk / letters
|
| 36 |
ar_r = arabic / letters
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return "zh"
|
| 39 |
-
if ar_r >= 0.20
|
| 40 |
return "ar"
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
return "en"
|
| 44 |
|
| 45 |
|
|
@@ -47,8 +110,8 @@ def _detect_language(text: str) -> str:
|
|
| 47 |
|
| 48 |
def _deduplicate(entities: list[Entity]) -> list[Entity]:
|
| 49 |
"""
|
| 50 |
-
双语标签可能
|
| 51 |
-
保留置信度最高的那条,并按位置排序。
|
| 52 |
"""
|
| 53 |
best: dict[tuple[int, int], Entity] = {}
|
| 54 |
for e in entities:
|
|
@@ -58,12 +121,155 @@ def _deduplicate(entities: list[Entity]) -> list[Entity]:
|
|
| 58 |
return sorted(best.values(), key=lambda x: x.start)
|
| 59 |
|
| 60 |
|
| 61 |
-
# ──
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
class NERService:
|
| 64 |
def __init__(self, model_name: str, cache_dir: str) -> None:
|
| 65 |
self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def extract(
|
| 68 |
self,
|
| 69 |
text: str,
|
|
@@ -74,35 +280,40 @@ class NERService:
|
|
| 74 |
"""
|
| 75 |
返回 (entities, labels_used)。
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
- language='auto' 时自动检测
|
| 83 |
-
- 中文 / 混合文本若传入默认 threshold(0.4) 则不调整(已足够低)
|
| 84 |
"""
|
| 85 |
if not text:
|
| 86 |
return [], labels
|
| 87 |
|
| 88 |
-
|
| 89 |
-
eff_lang = language if language != "auto" else _detect_language(text)
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
]
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NER 服务层 — 双模型路由 + 兜底策略
|
| 3 |
+
──────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
语言检测(两层):
|
| 5 |
+
1. Unicode 脚本比例:快速,适合中文 / 阿拉伯文等脚本明显的语言
|
| 6 |
+
2. langdetect 库兜底:覆盖纯英文及边界文本
|
| 7 |
+
|
| 8 |
+
路由 & 兜底规则:
|
| 9 |
+
┌──────────┬──────────────────┬──────────────────────────────┐
|
| 10 |
+
│ language │ 主模型 │ 兜底条件 │
|
| 11 |
+
├──────────┼──────────────────┼──────────────────────────────┤
|
| 12 |
+
│ zh │ ChineseBERT │ 实体数=0 → 补充 GLiNER 结果 │
|
| 13 |
+
│ en / ar │ GLiNER │ 实体数=0 → 补充 BERT 结果 │
|
| 14 |
+
│ mixed │ GLiNER + BERT │ 同时运行两个模型,结果合并 │
|
| 15 |
+
│ auto │ 先检测语言再路由 │ │
|
| 16 |
+
└──────────┴──────────────────┴──────────────────────────────┘
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import threading
|
| 20 |
import unicodedata
|
| 21 |
+
from abc import ABC, abstractmethod
|
| 22 |
|
| 23 |
from gliner import GLiNER
|
| 24 |
|
| 25 |
+
from app.labels import (
|
| 26 |
+
DEFAULT_LABELS,
|
| 27 |
+
BERT_TYPE_TO_LABEL,
|
| 28 |
+
expand_bilingual,
|
| 29 |
+
labels_to_bert_types,
|
| 30 |
+
)
|
| 31 |
from app.models import Entity
|
| 32 |
|
| 33 |
|
| 34 |
# ── 语言检测 ──────────────────────────────────────────────────────────────────
|
| 35 |
+
#
|
| 36 |
+
# 两层策略:
|
| 37 |
+
# Layer-1 Unicode 脚本比例
|
| 38 |
+
# · 遍历文本中所有字母字符,统计 CJK / Arabic 脚本占比
|
| 39 |
+
# · 优点:零依赖、极快;缺点:对极短或纯拉丁文本判断力弱
|
| 40 |
+
#
|
| 41 |
+
# Layer-2 langdetect(仅 Layer-1 返回 'en' 时作为校验)
|
| 42 |
+
# · 基于 n-gram 概率模型,原理同 Google CLD2
|
| 43 |
+
# · 对短文本(<20 字)仍有一定误判率,以 Layer-1 为主
|
| 44 |
+
# · 若 langdetect 检测到中文/日文/韩文 → 返回 'zh'
|
| 45 |
+
# · 失败时静默回退到 Layer-1 结果
|
| 46 |
|
| 47 |
+
def _unicode_script_ratio(text: str) -> str:
|
| 48 |
+
"""Layer-1:基于 Unicode 脚本比例的语言分类。"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
cjk = arabic = letters = 0
|
| 50 |
for ch in text:
|
| 51 |
+
if not unicodedata.category(ch).startswith("L"):
|
| 52 |
+
continue
|
| 53 |
+
letters += 1
|
| 54 |
+
cp = ord(ch)
|
| 55 |
+
if (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
|
| 56 |
+
0xF900 <= cp <= 0xFAFF or 0x20000 <= cp <= 0x2A6DF):
|
| 57 |
+
cjk += 1
|
| 58 |
+
elif 0x0600 <= cp <= 0x06FF or 0x0750 <= cp <= 0x077F:
|
| 59 |
+
arabic += 1
|
|
|
|
|
|
|
|
|
|
| 60 |
if not letters:
|
| 61 |
return "en"
|
| 62 |
cjk_r = cjk / letters
|
| 63 |
ar_r = arabic / letters
|
| 64 |
+
latin_r = (letters - cjk - arabic) / letters
|
| 65 |
+
|
| 66 |
+
# 中文+拉丁都显著 → mixed(优先级高于单纯 zh 判断)
|
| 67 |
+
if cjk_r >= 0.08 and latin_r >= 0.10:
|
| 68 |
+
return "mixed"
|
| 69 |
+
# 阿拉伯+拉丁都显著 → mixed
|
| 70 |
+
if ar_r >= 0.08 and latin_r >= 0.10:
|
| 71 |
+
return "mixed"
|
| 72 |
+
# 单脚本主导
|
| 73 |
+
if cjk_r >= 0.20:
|
| 74 |
return "zh"
|
| 75 |
+
if ar_r >= 0.20:
|
| 76 |
return "ar"
|
| 77 |
+
return "en"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def detect_language(text: str) -> str:
|
| 81 |
+
"""
|
| 82 |
+
两层语言检测,返回 'zh' | 'ar' | 'mixed' | 'en'。
|
| 83 |
+
|
| 84 |
+
Layer-1 优先(Unicode 脚本比例);Layer-1 返回 'en' 时,
|
| 85 |
+
用 langdetect 做一次二次确认,防止把中文误判为英文。
|
| 86 |
+
"""
|
| 87 |
+
if not text:
|
| 88 |
+
return "en"
|
| 89 |
+
|
| 90 |
+
layer1 = _unicode_script_ratio(text)
|
| 91 |
+
if layer1 != "en": # 已明确是非英文,直接返回
|
| 92 |
+
return layer1
|
| 93 |
+
|
| 94 |
+
# Layer-2:langdetect 校验(仅对 Layer-1='en' 的文本)
|
| 95 |
+
try:
|
| 96 |
+
from langdetect import detect, DetectorFactory
|
| 97 |
+
DetectorFactory.seed = 0 # 保证结果���定
|
| 98 |
+
lang_code = detect(text) # e.g. 'zh-cn', 'ar', 'en', 'ja' …
|
| 99 |
+
if lang_code.startswith("zh") or lang_code in ("ja", "ko"):
|
| 100 |
+
return "zh"
|
| 101 |
+
if lang_code == "ar":
|
| 102 |
+
return "ar"
|
| 103 |
+
except Exception:
|
| 104 |
+
pass # langdetect 失败时静默回退
|
| 105 |
+
|
| 106 |
return "en"
|
| 107 |
|
| 108 |
|
|
|
|
| 110 |
|
| 111 |
def _deduplicate(entities: list[Entity]) -> list[Entity]:
|
| 112 |
"""
|
| 113 |
+
双语标签或模型合并时可能产生同一 (start, end) 的重复结果,
|
| 114 |
+
保留置信度最高的那条,并按起始位置排序。
|
| 115 |
"""
|
| 116 |
best: dict[tuple[int, int], Entity] = {}
|
| 117 |
for e in entities:
|
|
|
|
| 121 |
return sorted(best.values(), key=lambda x: x.start)
|
| 122 |
|
| 123 |
|
| 124 |
+
# ── 后端基类 ──────────────────────────────────────────────────────────────────
|
| 125 |
+
|
| 126 |
+
class _Backend(ABC):
|
| 127 |
+
@abstractmethod
|
| 128 |
+
def predict(
|
| 129 |
+
self, text: str, labels: list[str], threshold: float
|
| 130 |
+
) -> tuple[list[Entity], list[str]]:
|
| 131 |
+
"""返回 (entities, labels_used)"""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ── GLiNER 后端(英文 / 阿拉伯文 / 混合) ─────────────────────────────────────
|
| 135 |
+
|
| 136 |
+
class GLiNERBackend(_Backend):
|
| 137 |
+
"""
|
| 138 |
+
零样本 NER:urchade/gliner_multi-v2.1
|
| 139 |
+
• 支持英文、阿拉伯文及混合文本
|
| 140 |
+
• 自动做双语标签扩展,提升召回率
|
| 141 |
+
"""
|
| 142 |
|
|
|
|
| 143 |
def __init__(self, model_name: str, cache_dir: str) -> None:
|
| 144 |
self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
|
| 145 |
|
| 146 |
+
def predict(
|
| 147 |
+
self, text: str, labels: list[str], threshold: float
|
| 148 |
+
) -> tuple[list[Entity], list[str]]:
|
| 149 |
+
eff_labels = expand_bilingual(labels) if labels else DEFAULT_LABELS
|
| 150 |
+
raw = self._model.predict_entities(text, eff_labels, threshold=threshold)
|
| 151 |
+
entities = [
|
| 152 |
+
Entity(
|
| 153 |
+
text=e["text"],
|
| 154 |
+
label=e["label"],
|
| 155 |
+
score=round(e["score"], 4),
|
| 156 |
+
start=e["start"],
|
| 157 |
+
end=e["end"],
|
| 158 |
+
)
|
| 159 |
+
for e in raw
|
| 160 |
+
]
|
| 161 |
+
return _deduplicate(entities), eff_labels
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ── 中文 BERT 后端 ─────────────────────────────────────────────────────────────
|
| 165 |
+
|
| 166 |
+
class ChineseBERTBackend(_Backend):
|
| 167 |
+
"""
|
| 168 |
+
专用中文 NER:shibing624/bert4ner-base-chinese
|
| 169 |
+
• 模型大小:~400 MB(BERT-base)
|
| 170 |
+
• 推理速度:~100 ms
|
| 171 |
+
• 固定实体类型:PER / LOC / ORG / TIME → 映射为双语标签
|
| 172 |
+
• 用户传入标签时按标签类型过滤;无法映射的自定义标签不过滤(返回全部)
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def __init__(self, model_name: str, cache_dir: str) -> None:
|
| 176 |
+
# 延迟导入:避免顶层 import 在测试收集阶段触发 torch.__spec__ 检测
|
| 177 |
+
from transformers import pipeline as hf_pipeline
|
| 178 |
+
self._pipe = hf_pipeline(
|
| 179 |
+
"token-classification",
|
| 180 |
+
model=model_name,
|
| 181 |
+
model_kwargs={"cache_dir": cache_dir},
|
| 182 |
+
aggregation_strategy="simple",
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
def predict(
|
| 186 |
+
self, text: str, labels: list[str], threshold: float
|
| 187 |
+
) -> tuple[list[Entity], list[str]]:
|
| 188 |
+
raw = self._pipe(text)
|
| 189 |
+
allowed_types = labels_to_bert_types(labels) # None = 不过滤
|
| 190 |
+
|
| 191 |
+
entities: list[Entity] = []
|
| 192 |
+
labels_seen: set[str] = set()
|
| 193 |
+
|
| 194 |
+
for r in raw:
|
| 195 |
+
score = float(r["score"])
|
| 196 |
+
if score < threshold:
|
| 197 |
+
continue
|
| 198 |
+
|
| 199 |
+
bert_type = r.get("entity_group", r.get("entity", ""))
|
| 200 |
+
bert_type = bert_type.lstrip("BI-").strip() # 去掉可能的 B-/I- 前缀
|
| 201 |
+
|
| 202 |
+
if allowed_types is not None and bert_type not in allowed_types:
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
std_label = BERT_TYPE_TO_LABEL.get(bert_type, bert_type)
|
| 206 |
+
labels_seen.add(std_label)
|
| 207 |
+
entities.append(Entity(
|
| 208 |
+
text=r["word"],
|
| 209 |
+
label=std_label,
|
| 210 |
+
score=round(score, 4),
|
| 211 |
+
start=r["start"],
|
| 212 |
+
end=r["end"],
|
| 213 |
+
))
|
| 214 |
+
|
| 215 |
+
used = list(labels_seen) if labels_seen else list(BERT_TYPE_TO_LABEL.values())
|
| 216 |
+
return entities, used
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ── NER 服务(路由 + 兜底) ─────────────────────────────���──────────────────────
|
| 220 |
+
|
| 221 |
+
class NERService:
|
| 222 |
+
"""
|
| 223 |
+
持有两个后端,按检测到的语言分发请求。
|
| 224 |
+
|
| 225 |
+
兜底规则(召回为空时):
|
| 226 |
+
zh 主模型 BERT 无结果 → 用 GLiNER 补充
|
| 227 |
+
en/ar 主模型 GLiNER 无结果 → 用 BERT 补充
|
| 228 |
+
mixed 同时运行两个模型,合并去重后返回
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
def __init__(self, en_model_name: str, zh_model_name: str, cache_dir: str) -> None:
|
| 232 |
+
self._en_name = en_model_name
|
| 233 |
+
self._zh_name = zh_model_name
|
| 234 |
+
self._cache_dir = cache_dir
|
| 235 |
+
|
| 236 |
+
self._en_backend: GLiNERBackend | None = None
|
| 237 |
+
self._zh_backend: ChineseBERTBackend | None = None
|
| 238 |
+
self._en_lock = threading.Lock()
|
| 239 |
+
self._zh_lock = threading.Lock()
|
| 240 |
+
|
| 241 |
+
# ── 懒加载 ────────────────────────────────────────────────────────────────
|
| 242 |
+
|
| 243 |
+
def _en(self) -> GLiNERBackend:
|
| 244 |
+
if self._en_backend is None:
|
| 245 |
+
with self._en_lock:
|
| 246 |
+
if self._en_backend is None:
|
| 247 |
+
self._en_backend = GLiNERBackend(self._en_name, self._cache_dir)
|
| 248 |
+
return self._en_backend
|
| 249 |
+
|
| 250 |
+
def _zh(self) -> ChineseBERTBackend:
|
| 251 |
+
if self._zh_backend is None:
|
| 252 |
+
with self._zh_lock:
|
| 253 |
+
if self._zh_backend is None:
|
| 254 |
+
self._zh_backend = ChineseBERTBackend(self._zh_name, self._cache_dir)
|
| 255 |
+
return self._zh_backend
|
| 256 |
+
|
| 257 |
+
# ── 兜底合并 ──────────────────────────────────────────────────────────────
|
| 258 |
+
|
| 259 |
+
def _merge(
|
| 260 |
+
self,
|
| 261 |
+
primary: tuple[list[Entity], list[str]],
|
| 262 |
+
fallback: tuple[list[Entity], list[str]],
|
| 263 |
+
) -> tuple[list[Entity], list[str]]:
|
| 264 |
+
"""合并两个模型的结果,去重后按位置排序。"""
|
| 265 |
+
p_ents, p_labels = primary
|
| 266 |
+
f_ents, f_labels = fallback
|
| 267 |
+
merged = _deduplicate(p_ents + f_ents)
|
| 268 |
+
used = list(dict.fromkeys(p_labels + f_labels)) # 保序去重
|
| 269 |
+
return merged, used
|
| 270 |
+
|
| 271 |
+
# ── 主入口 ────────────────────────────────────────────────────────────────
|
| 272 |
+
|
| 273 |
def extract(
|
| 274 |
self,
|
| 275 |
text: str,
|
|
|
|
| 280 |
"""
|
| 281 |
返回 (entities, labels_used)。
|
| 282 |
|
| 283 |
+
路由逻辑:
|
| 284 |
+
auto → 检测语言 → 路由
|
| 285 |
+
zh → BERT 主,GLiNER 兜底(主模型无结果时补充)
|
| 286 |
+
en/ar → GLiNER 主,BERT 兜底(主模型无结果时补充)
|
| 287 |
+
mixed → 两模型同时运行,结果合并去重
|
|
|
|
|
|
|
| 288 |
"""
|
| 289 |
if not text:
|
| 290 |
return [], labels
|
| 291 |
|
| 292 |
+
lang = language if language != "auto" else detect_language(text)
|
|
|
|
| 293 |
|
| 294 |
+
if lang == "mixed":
|
| 295 |
+
# 同时运行两个模型,合并结果
|
| 296 |
+
en_result = self._en().predict(text, labels, threshold)
|
| 297 |
+
zh_result = self._zh().predict(text, labels, threshold)
|
| 298 |
+
return self._merge(en_result, zh_result)
|
| 299 |
|
| 300 |
+
if lang == "zh":
|
| 301 |
+
primary_result = self._zh().predict(text, labels, threshold)
|
| 302 |
+
if not primary_result[0]: # 主模型无结果 → GLiNER 兜底
|
| 303 |
+
fallback_result = self._en().predict(text, labels, threshold)
|
| 304 |
+
if fallback_result[0]:
|
| 305 |
+
return fallback_result
|
| 306 |
+
return primary_result
|
| 307 |
+
|
| 308 |
+
# en / ar / 其他
|
| 309 |
+
primary_result = self._en().predict(text, labels, threshold)
|
| 310 |
+
if not primary_result[0]: # 主模型无结果 → BERT 兜底
|
| 311 |
+
fallback_result = self._zh().predict(text, labels, threshold)
|
| 312 |
+
if fallback_result[0]:
|
| 313 |
+
return fallback_result
|
| 314 |
+
return primary_result
|
| 315 |
+
|
| 316 |
+
def warmup(self) -> None:
|
| 317 |
+
"""启动时预热两个模型,首个请求无需等待。"""
|
| 318 |
+
self._en()
|
| 319 |
+
self._zh()
|
|
@@ -1,3 +1,5 @@
|
|
| 1 |
fastapi>=0.111.0
|
| 2 |
uvicorn[standard]>=0.29.0
|
| 3 |
gliner>=0.2.0
|
|
|
|
|
|
|
|
|
| 1 |
fastapi>=0.111.0
|
| 2 |
uvicorn[standard]>=0.29.0
|
| 3 |
gliner>=0.2.0
|
| 4 |
+
transformers>=4.40.0
|
| 5 |
+
langdetect>=1.0.9
|
|
@@ -1,27 +1,39 @@
|
|
| 1 |
"""
|
| 2 |
-
Unit tests — no real model loaded (GLiNER/torch stubbed
|
| 3 |
Covers:
|
| 4 |
-
- API contract (health, validation, threshold
|
| 5 |
-
-
|
| 6 |
-
-
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
-
from unittest.mock import MagicMock, patch
|
| 9 |
|
| 10 |
import pytest
|
| 11 |
from fastapi.testclient import TestClient
|
| 12 |
|
| 13 |
-
import app.main as main_module
|
| 14 |
from app.main import app
|
| 15 |
from app.models import Entity
|
| 16 |
-
from app.labels import DEFAULT_LABELS, expand_bilingual
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
# ── Fixture ───────────────────────────────────────────────────────────────────
|
| 20 |
|
| 21 |
@pytest.fixture()
|
| 22 |
def client():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
mock_ner = MagicMock()
|
| 24 |
-
# Default: extract() returns ([], [])
|
| 25 |
mock_ner.extract.return_value = ([], [])
|
| 26 |
with pytest.MonkeyPatch().context() as mp:
|
| 27 |
mp.setattr("app.main.NERService", lambda *_: mock_ner)
|
|
@@ -29,61 +41,49 @@ def client():
|
|
| 29 |
yield c, mock_ner
|
| 30 |
|
| 31 |
|
| 32 |
-
def _ents(*args) -> tuple[list[Entity], list[str]]:
|
| 33 |
-
"""Helper: wrap Entity list in the (entities, labels_used) tuple."""
|
| 34 |
-
entities = list(args)
|
| 35 |
-
labels = [e.label for e in entities]
|
| 36 |
-
return entities, labels
|
| 37 |
-
|
| 38 |
-
|
| 39 |
# ── System / API contract ─────────────────────────────────────────────────────
|
| 40 |
|
| 41 |
def test_health(client):
|
| 42 |
c, _ = client
|
| 43 |
-
|
| 44 |
-
assert resp.status_code == 200
|
| 45 |
-
assert resp.json() == {"status": "ok"}
|
| 46 |
|
| 47 |
|
| 48 |
-
def
|
| 49 |
-
c,
|
| 50 |
resp = c.post("/api/v1/extract", json={"text": "", "labels": ["person"]})
|
| 51 |
assert resp.status_code == 200
|
| 52 |
assert resp.json()["entities"] == []
|
| 53 |
|
| 54 |
|
| 55 |
-
def
|
| 56 |
-
"""labels
|
| 57 |
c, mock_ner = client
|
| 58 |
mock_ner.extract.return_value = ([], DEFAULT_LABELS)
|
| 59 |
-
resp = c.post("/api/v1/extract", json={"text": "
|
| 60 |
assert resp.status_code == 200
|
| 61 |
-
|
| 62 |
-
assert "entities" in data
|
| 63 |
-
assert "labels_used" in data
|
| 64 |
-
assert len(data["labels_used"]) > 0
|
| 65 |
|
| 66 |
|
| 67 |
-
def
|
| 68 |
-
"""labels
|
| 69 |
c, mock_ner = client
|
| 70 |
mock_ner.extract.return_value = ([], DEFAULT_LABELS)
|
| 71 |
-
resp = c.post("/api/v1/extract", json={"text": "
|
| 72 |
assert resp.status_code == 200
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def test_extract_threshold_forwarded(client):
|
| 76 |
c, mock_ner = client
|
| 77 |
c.post("/api/v1/extract",
|
| 78 |
-
json={"text": "Hello
|
| 79 |
-
mock_ner.extract.assert_called_once_with("Hello
|
| 80 |
|
| 81 |
|
| 82 |
def test_extract_invalid_threshold(client):
|
| 83 |
c, _ = client
|
| 84 |
-
|
| 85 |
-
json={"text": "
|
| 86 |
-
assert resp.status_code == 422
|
| 87 |
|
| 88 |
|
| 89 |
def test_extract_language_field_forwarded(client):
|
|
@@ -94,86 +94,117 @@ def test_extract_language_field_forwarded(client):
|
|
| 94 |
|
| 95 |
|
| 96 |
def test_extract_invalid_language(client):
|
| 97 |
-
"""不支持的 language 值应返回 422。"""
|
| 98 |
c, _ = client
|
| 99 |
-
|
| 100 |
-
json={"text": "
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def test_entity_response_fields(client):
|
| 105 |
-
"""每个实体包含全部必填字段且值合法。"""
|
| 106 |
c, mock_ner = client
|
| 107 |
mock_ner.extract.return_value = _ents(
|
| 108 |
Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
|
| 109 |
)
|
| 110 |
resp = c.post("/api/v1/extract",
|
| 111 |
json={"text": "Apple is great.", "labels": ["organization"]})
|
| 112 |
-
assert resp.status_code == 200
|
| 113 |
e = resp.json()["entities"][0]
|
| 114 |
-
assert
|
| 115 |
assert 0.0 <= e["score"] <= 1.0
|
| 116 |
assert e["start"] < e["end"]
|
| 117 |
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
assert
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
|
| 130 |
-
# ── Bilingual label expansion
|
| 131 |
|
| 132 |
def test_expand_bilingual_adds_english_for_chinese():
|
| 133 |
result = expand_bilingual(["人名或姓名"])
|
| 134 |
-
assert "人名或姓名" in result
|
| 135 |
assert "full name of a person" in result
|
| 136 |
|
| 137 |
|
| 138 |
def test_expand_bilingual_adds_chinese_for_english():
|
| 139 |
result = expand_bilingual(["company or organization name"])
|
| 140 |
-
assert "company or organization name" in result
|
| 141 |
assert "公司或组织机构名称" in result
|
| 142 |
|
| 143 |
|
| 144 |
def test_expand_bilingual_no_duplicate():
|
| 145 |
-
|
| 146 |
-
result = expand_bilingual(labels)
|
| 147 |
assert result.count("人名或姓名") == 1
|
| 148 |
assert result.count("full name of a person") == 1
|
| 149 |
|
| 150 |
|
| 151 |
def test_expand_bilingual_custom_label_preserved():
|
| 152 |
-
"""自定义标签(不在对照表中)原样保留。"""
|
| 153 |
result = expand_bilingual(["my custom label"])
|
| 154 |
assert "my custom label" in result
|
| 155 |
|
| 156 |
|
| 157 |
-
def
|
| 158 |
-
assert len(DEFAULT_LABELS) > 0
|
| 159 |
-
# 必须包含中英文各至少一个
|
| 160 |
has_en = any(all(ord(c) < 128 for c in lbl) for lbl in DEFAULT_LABELS)
|
| 161 |
has_zh = any(any('一' <= c <= '鿿' for c in lbl) for lbl in DEFAULT_LABELS)
|
| 162 |
assert has_en and has_zh
|
| 163 |
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
# ── English ───────────────────────────────────────────────────────────────────
|
| 166 |
|
| 167 |
def test_english_person_org(client):
|
| 168 |
c, mock_ner = client
|
| 169 |
mock_ner.extract.return_value = _ents(
|
| 170 |
-
Entity(text="Elon Musk",
|
| 171 |
-
Entity(text="Tesla",
|
| 172 |
-
Entity(text="SpaceX",
|
| 173 |
)
|
| 174 |
resp = c.post("/api/v1/extract",
|
| 175 |
json={"text": "Elon Musk is the CEO of Tesla and founded SpaceX.",
|
| 176 |
-
"labels": ["full name of a person", "company or organization name"]
|
|
|
|
| 177 |
assert resp.status_code == 200
|
| 178 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 179 |
assert {"Elon Musk", "Tesla", "SpaceX"} <= texts
|
|
@@ -188,28 +219,24 @@ def test_english_location_date(client):
|
|
| 188 |
)
|
| 189 |
resp = c.post("/api/v1/extract",
|
| 190 |
json={"text": "The summit was held in Paris in 2024, in France.",
|
| 191 |
-
"labels": ["geographical location", "date or year"]
|
| 192 |
-
|
| 193 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 194 |
assert {"Paris", "France", "2024"} <= texts
|
| 195 |
|
| 196 |
|
| 197 |
-
def
|
| 198 |
c, mock_ner = client
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
json={"text": "NASA explored the Moon.",
|
| 204 |
-
"labels": ["company or organization name"],
|
| 205 |
-
"threshold": 0.8})
|
| 206 |
-
assert resp.status_code == 200
|
| 207 |
mock_ner.extract.assert_called_once_with(
|
| 208 |
-
"NASA explored the Moon.", ["company or organization name"], 0.8, language="
|
| 209 |
)
|
| 210 |
|
| 211 |
|
| 212 |
-
# ── Chinese ────────────────────────────────────────────────────
|
| 213 |
|
| 214 |
def test_chinese_person_org(client):
|
| 215 |
c, mock_ner = client
|
|
@@ -222,13 +249,12 @@ def test_chinese_person_org(client):
|
|
| 222 |
json={"text": "阿里巴巴集团创始人马云卸任,由张勇接任。",
|
| 223 |
"labels": ["人名或姓名", "公司或组织机构名称"],
|
| 224 |
"language": "zh"})
|
| 225 |
-
assert resp.status_code == 200
|
| 226 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 227 |
assert {"马云", "张勇", "阿里巴巴"} <= texts
|
| 228 |
|
| 229 |
|
| 230 |
def test_chinese_entity_boundary(client):
|
| 231 |
-
"""实体边界不
|
| 232 |
c, mock_ner = client
|
| 233 |
mock_ner.extract.return_value = _ents(
|
| 234 |
Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2),
|
|
@@ -236,8 +262,7 @@ def test_chinese_entity_boundary(client):
|
|
| 236 |
)
|
| 237 |
resp = c.post("/api/v1/extract",
|
| 238 |
json={"text": "尤氏来请,王熙凤笑道:'你来了。'",
|
| 239 |
-
"labels": ["人名或姓名"]})
|
| 240 |
-
assert resp.status_code == 200
|
| 241 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 242 |
assert "尤氏" in texts
|
| 243 |
assert "王熙凤" in texts
|
|
@@ -248,33 +273,42 @@ def test_chinese_entity_boundary(client):
|
|
| 248 |
def test_chinese_location_product(client):
|
| 249 |
c, mock_ner = client
|
| 250 |
mock_ner.extract.return_value = _ents(
|
| 251 |
-
Entity(text="杭州", label="地名或城市", score=0.93, start=
|
| 252 |
-
Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=
|
| 253 |
-
Entity(text="天猫", label="产品或品牌名称", score=0.92, start=
|
| 254 |
-
Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=
|
| 255 |
)
|
| 256 |
resp = c.post("/api/v1/extract",
|
| 257 |
json={"text": "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。",
|
| 258 |
-
"labels": ["地名或城市", "产品或品牌名称"]
|
| 259 |
-
|
| 260 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 261 |
assert {"杭州", "淘宝", "天猫", "支付宝"} <= texts
|
| 262 |
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
# ── Arabic ────────────────────────────────────────────────────────────────────
|
| 265 |
|
| 266 |
def test_arabic_person_location(client):
|
| 267 |
-
"""阿拉伯语:识别人名与地名。"""
|
| 268 |
c, mock_ner = client
|
| 269 |
mock_ner.extract.return_value = _ents(
|
| 270 |
Entity(text="محمد بن سلمان", label="full name of a person", score=0.82, start=12, end=26),
|
| 271 |
Entity(text="المملكة العربية السعودية", label="geographical location", score=0.85, start=44, end=68),
|
| 272 |
)
|
| 273 |
resp = c.post("/api/v1/extract",
|
| 274 |
-
json={"text": "أعلن الرئيس محمد بن سلمان
|
| 275 |
"labels": ["full name of a person", "geographical location"],
|
| 276 |
"language": "ar"})
|
| 277 |
-
assert resp.status_code == 200
|
| 278 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 279 |
assert "محمد بن سلمان" in texts
|
| 280 |
assert "المملكة العربية السعودية" in texts
|
|
@@ -295,29 +329,12 @@ def test_mixed_entities_both_scripts(client):
|
|
| 295 |
"labels": ["full name of a person", "人名或姓名",
|
| 296 |
"company or organization name", "公司或组织机构名称",
|
| 297 |
"geographical location", "地名或城市",
|
| 298 |
-
"product or technology name"]
|
| 299 |
-
|
| 300 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 301 |
assert {"张伟", "Google", "北京", "Android"} <= texts
|
| 302 |
|
| 303 |
|
| 304 |
-
def test_mixed_labels_chinese_and_english(client):
|
| 305 |
-
c, mock_ner = client
|
| 306 |
-
mock_ner.extract.return_value = _ents(
|
| 307 |
-
Entity(text="李明", label="人名或姓名", score=0.94, start=0, end=2),
|
| 308 |
-
Entity(text="Tesla", label="人名或姓名", score=0.96, start=10, end=15),
|
| 309 |
-
Entity(text="上海", label="地名或城市", score=0.92, start=22, end=24),
|
| 310 |
-
)
|
| 311 |
-
resp = c.post("/api/v1/extract",
|
| 312 |
-
json={"text": "李明在上海加入了 Tesla。",
|
| 313 |
-
"labels": ["人名或姓名", "full name of a person",
|
| 314 |
-
"地名或城市", "geographical location",
|
| 315 |
-
"company or organization name"]})
|
| 316 |
-
assert resp.status_code == 200
|
| 317 |
-
texts = {e["text"] for e in resp.json()["entities"]}
|
| 318 |
-
assert {"李明", "Tesla", "上海"} <= texts
|
| 319 |
-
|
| 320 |
-
|
| 321 |
def test_mixed_no_cross_language_contamination(client):
|
| 322 |
c, mock_ner = client
|
| 323 |
mock_ner.extract.return_value = _ents(
|
|
@@ -327,7 +344,108 @@ def test_mixed_no_cross_language_contamination(client):
|
|
| 327 |
resp = c.post("/api/v1/extract",
|
| 328 |
json={"text": "他在 OpenAI 工作,同事王芳也在同一部门。",
|
| 329 |
"labels": ["person", "organization"]})
|
| 330 |
-
assert resp.status_code == 200
|
| 331 |
entities = resp.json()["entities"]
|
| 332 |
assert any(e["text"] == "OpenAI" and e["label"] == "organization" for e in entities)
|
| 333 |
assert any(e["text"] == "王芳" and e["label"] == "person" for e in entities)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Unit tests — no real model loaded (GLiNER/torch stubbed via conftest.py).
|
| 3 |
Covers:
|
| 4 |
+
- API contract (health, validation, threshold, language field)
|
| 5 |
+
- Dual-model routing (ZH → BERT backend, EN/AR/mixed → GLiNER backend)
|
| 6 |
+
- Optional labels + bilingual auto-expansion
|
| 7 |
+
- English / Chinese / Arabic / mixed-language scenarios
|
| 8 |
+
- labels_used echo in response
|
| 9 |
"""
|
| 10 |
+
from unittest.mock import MagicMock, patch, PropertyMock
|
| 11 |
|
| 12 |
import pytest
|
| 13 |
from fastapi.testclient import TestClient
|
| 14 |
|
|
|
|
| 15 |
from app.main import app
|
| 16 |
from app.models import Entity
|
| 17 |
+
from app.labels import DEFAULT_LABELS, expand_bilingual, labels_to_bert_types
|
| 18 |
+
from app.ner import detect_language
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 22 |
+
|
| 23 |
+
def _ents(*args: Entity) -> tuple[list[Entity], list[str]]:
|
| 24 |
+
"""Wrap entities in (entities, labels_used) tuple expected by NERService."""
|
| 25 |
+
return list(args), [e.label for e in args]
|
| 26 |
|
| 27 |
|
| 28 |
# ── Fixture ───────────────────────────────────────────────────────────────────
|
| 29 |
|
| 30 |
@pytest.fixture()
|
| 31 |
def client():
|
| 32 |
+
"""
|
| 33 |
+
Patch NERService so no model is actually loaded.
|
| 34 |
+
mock_ner.extract() returns ([], []) by default.
|
| 35 |
+
"""
|
| 36 |
mock_ner = MagicMock()
|
|
|
|
| 37 |
mock_ner.extract.return_value = ([], [])
|
| 38 |
with pytest.MonkeyPatch().context() as mp:
|
| 39 |
mp.setattr("app.main.NERService", lambda *_: mock_ner)
|
|
|
|
| 41 |
yield c, mock_ner
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
# ── System / API contract ─────────────────────────────────────────────────────
|
| 45 |
|
| 46 |
def test_health(client):
|
| 47 |
c, _ = client
|
| 48 |
+
assert c.get("/api/v1/health").json() == {"status": "ok"}
|
|
|
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
+
def test_extract_empty_text_returns_empty(client):
|
| 52 |
+
c, _ = client
|
| 53 |
resp = c.post("/api/v1/extract", json={"text": "", "labels": ["person"]})
|
| 54 |
assert resp.status_code == 200
|
| 55 |
assert resp.json()["entities"] == []
|
| 56 |
|
| 57 |
|
| 58 |
+
def test_extract_labels_optional(client):
|
| 59 |
+
"""labels 字段完全不传应正常返回 200。"""
|
| 60 |
c, mock_ner = client
|
| 61 |
mock_ner.extract.return_value = ([], DEFAULT_LABELS)
|
| 62 |
+
resp = c.post("/api/v1/extract", json={"text": "Some text."})
|
| 63 |
assert resp.status_code == 200
|
| 64 |
+
assert len(resp.json()["labels_used"]) > 0
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
+
def test_extract_empty_labels_uses_defaults(client):
|
| 68 |
+
"""labels=[] 时应使用默认双语标签集。"""
|
| 69 |
c, mock_ner = client
|
| 70 |
mock_ner.extract.return_value = ([], DEFAULT_LABELS)
|
| 71 |
+
resp = c.post("/api/v1/extract", json={"text": "Hello world.", "labels": []})
|
| 72 |
assert resp.status_code == 200
|
| 73 |
+
assert resp.json()["labels_used"] == DEFAULT_LABELS
|
| 74 |
|
| 75 |
|
| 76 |
def test_extract_threshold_forwarded(client):
|
| 77 |
c, mock_ner = client
|
| 78 |
c.post("/api/v1/extract",
|
| 79 |
+
json={"text": "Hello", "labels": ["person"], "threshold": 0.8})
|
| 80 |
+
mock_ner.extract.assert_called_once_with("Hello", ["person"], 0.8, language="auto")
|
| 81 |
|
| 82 |
|
| 83 |
def test_extract_invalid_threshold(client):
|
| 84 |
c, _ = client
|
| 85 |
+
assert c.post("/api/v1/extract",
|
| 86 |
+
json={"text": "x", "threshold": 1.5}).status_code == 422
|
|
|
|
| 87 |
|
| 88 |
|
| 89 |
def test_extract_language_field_forwarded(client):
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def test_extract_invalid_language(client):
|
|
|
|
| 97 |
c, _ = client
|
| 98 |
+
assert c.post("/api/v1/extract",
|
| 99 |
+
json={"text": "x", "language": "jp"}).status_code == 422
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def test_labels_used_echoed(client):
|
| 103 |
+
c, mock_ner = client
|
| 104 |
+
used = ["人名或姓名", "地名或城市"]
|
| 105 |
+
mock_ner.extract.return_value = ([], used)
|
| 106 |
+
resp = c.post("/api/v1/extract", json={"text": "马云在杭州。", "labels": ["人名或姓名"]})
|
| 107 |
+
assert resp.json()["labels_used"] == used
|
| 108 |
|
| 109 |
|
| 110 |
def test_entity_response_fields(client):
|
|
|
|
| 111 |
c, mock_ner = client
|
| 112 |
mock_ner.extract.return_value = _ents(
|
| 113 |
Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
|
| 114 |
)
|
| 115 |
resp = c.post("/api/v1/extract",
|
| 116 |
json={"text": "Apple is great.", "labels": ["organization"]})
|
|
|
|
| 117 |
e = resp.json()["entities"][0]
|
| 118 |
+
assert {"text", "label", "score", "start", "end"} <= e.keys()
|
| 119 |
assert 0.0 <= e["score"] <= 1.0
|
| 120 |
assert e["start"] < e["end"]
|
| 121 |
|
| 122 |
|
| 123 |
+
# ── Language detection ────────────────────────────────────────────────────────
|
| 124 |
+
|
| 125 |
+
def test_detect_language_english():
|
| 126 |
+
assert detect_language("Elon Musk founded SpaceX in California.") == "en"
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_detect_language_chinese():
|
| 130 |
+
assert detect_language("阿里巴巴集团创始人马云于杭州卸任。") == "zh"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def test_detect_language_arabic():
|
| 134 |
+
assert detect_language("أعلن الرئيس محمد بن سلمان عن مشروع نيوم.") == "ar"
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_detect_language_mixed():
|
| 138 |
+
assert detect_language("张伟加入了 Google 北京研发中心,负责 Android 优化。") == "mixed"
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def test_detect_language_empty():
|
| 142 |
+
assert detect_language("") == "en"
|
| 143 |
|
| 144 |
|
| 145 |
+
# ── Bilingual label expansion ─────────────────────────────────────────────────
|
| 146 |
|
| 147 |
def test_expand_bilingual_adds_english_for_chinese():
|
| 148 |
result = expand_bilingual(["人名或姓名"])
|
|
|
|
| 149 |
assert "full name of a person" in result
|
| 150 |
|
| 151 |
|
| 152 |
def test_expand_bilingual_adds_chinese_for_english():
|
| 153 |
result = expand_bilingual(["company or organization name"])
|
|
|
|
| 154 |
assert "公司或组织机构名称" in result
|
| 155 |
|
| 156 |
|
| 157 |
def test_expand_bilingual_no_duplicate():
|
| 158 |
+
result = expand_bilingual(["人名或姓名", "full name of a person"])
|
|
|
|
| 159 |
assert result.count("人名或姓名") == 1
|
| 160 |
assert result.count("full name of a person") == 1
|
| 161 |
|
| 162 |
|
| 163 |
def test_expand_bilingual_custom_label_preserved():
|
|
|
|
| 164 |
result = expand_bilingual(["my custom label"])
|
| 165 |
assert "my custom label" in result
|
| 166 |
|
| 167 |
|
| 168 |
+
def test_default_labels_bilingual():
|
|
|
|
|
|
|
| 169 |
has_en = any(all(ord(c) < 128 for c in lbl) for lbl in DEFAULT_LABELS)
|
| 170 |
has_zh = any(any('一' <= c <= '鿿' for c in lbl) for lbl in DEFAULT_LABELS)
|
| 171 |
assert has_en and has_zh
|
| 172 |
|
| 173 |
|
| 174 |
+
# ── BERT label mapping ────────────────────────────────────────────────────────
|
| 175 |
+
|
| 176 |
+
def test_labels_to_bert_types_chinese_label():
|
| 177 |
+
types = labels_to_bert_types(["人名或姓名"])
|
| 178 |
+
assert "PER" in types
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def test_labels_to_bert_types_english_label():
|
| 182 |
+
types = labels_to_bert_types(["geographical location"])
|
| 183 |
+
assert "LOC" in types
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def test_labels_to_bert_types_empty_returns_none():
|
| 187 |
+
assert labels_to_bert_types([]) is None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test_labels_to_bert_types_unmapped_returns_none():
|
| 191 |
+
# 无法映射的标签 → 不过滤(返回 None)
|
| 192 |
+
assert labels_to_bert_types(["some unknown label"]) is None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
# ── English ───────────────────────────────────────────────────────────────────
|
| 196 |
|
| 197 |
def test_english_person_org(client):
|
| 198 |
c, mock_ner = client
|
| 199 |
mock_ner.extract.return_value = _ents(
|
| 200 |
+
Entity(text="Elon Musk", label="person", score=0.98, start=0, end=9),
|
| 201 |
+
Entity(text="Tesla", label="organization", score=0.96, start=18, end=23),
|
| 202 |
+
Entity(text="SpaceX", label="organization", score=0.97, start=28, end=34),
|
| 203 |
)
|
| 204 |
resp = c.post("/api/v1/extract",
|
| 205 |
json={"text": "Elon Musk is the CEO of Tesla and founded SpaceX.",
|
| 206 |
+
"labels": ["full name of a person", "company or organization name"],
|
| 207 |
+
"language": "en"})
|
| 208 |
assert resp.status_code == 200
|
| 209 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 210 |
assert {"Elon Musk", "Tesla", "SpaceX"} <= texts
|
|
|
|
| 219 |
)
|
| 220 |
resp = c.post("/api/v1/extract",
|
| 221 |
json={"text": "The summit was held in Paris in 2024, in France.",
|
| 222 |
+
"labels": ["geographical location", "date or year"],
|
| 223 |
+
"language": "en"})
|
| 224 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 225 |
assert {"Paris", "France", "2024"} <= texts
|
| 226 |
|
| 227 |
|
| 228 |
+
def test_english_threshold_forwarded(client):
|
| 229 |
c, mock_ner = client
|
| 230 |
+
c.post("/api/v1/extract",
|
| 231 |
+
json={"text": "NASA explored the Moon.",
|
| 232 |
+
"labels": ["company or organization name"],
|
| 233 |
+
"threshold": 0.8, "language": "en"})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
mock_ner.extract.assert_called_once_with(
|
| 235 |
+
"NASA explored the Moon.", ["company or organization name"], 0.8, language="en"
|
| 236 |
)
|
| 237 |
|
| 238 |
|
| 239 |
+
# ── Chinese (BERT backend) ────────────────────────────────────────────────────
|
| 240 |
|
| 241 |
def test_chinese_person_org(client):
|
| 242 |
c, mock_ner = client
|
|
|
|
| 249 |
json={"text": "阿里巴巴集团创始人马云卸任,由张勇接任。",
|
| 250 |
"labels": ["人名或姓名", "公司或组织机构名称"],
|
| 251 |
"language": "zh"})
|
|
|
|
| 252 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 253 |
assert {"马云", "张勇", "阿里巴巴"} <= texts
|
| 254 |
|
| 255 |
|
| 256 |
def test_chinese_entity_boundary(client):
|
| 257 |
+
"""BERT NER 应精确截断实体边界,不含动词。"""
|
| 258 |
c, mock_ner = client
|
| 259 |
mock_ner.extract.return_value = _ents(
|
| 260 |
Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2),
|
|
|
|
| 262 |
)
|
| 263 |
resp = c.post("/api/v1/extract",
|
| 264 |
json={"text": "尤氏来请,王熙凤笑道:'你来了。'",
|
| 265 |
+
"labels": ["人名或姓名"], "language": "zh"})
|
|
|
|
| 266 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 267 |
assert "尤氏" in texts
|
| 268 |
assert "王熙凤" in texts
|
|
|
|
| 273 |
def test_chinese_location_product(client):
|
| 274 |
c, mock_ner = client
|
| 275 |
mock_ner.extract.return_value = _ents(
|
| 276 |
+
Entity(text="杭州", label="地名或城市", score=0.93, start=9, end=11),
|
| 277 |
+
Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=14, end=16),
|
| 278 |
+
Entity(text="天猫", label="产品或品牌名称", score=0.92, start=17, end=19),
|
| 279 |
+
Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=20, end=23),
|
| 280 |
)
|
| 281 |
resp = c.post("/api/v1/extract",
|
| 282 |
json={"text": "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。",
|
| 283 |
+
"labels": ["地名或城市", "产品或品牌名称"],
|
| 284 |
+
"language": "zh"})
|
| 285 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 286 |
assert {"杭州", "淘宝", "天猫", "支付宝"} <= texts
|
| 287 |
|
| 288 |
|
| 289 |
+
def test_chinese_auto_routes_to_zh(client):
|
| 290 |
+
"""auto 检测到中文应路由到 ZH 模型(language 透传为 'auto',内部检测为 zh)。"""
|
| 291 |
+
c, mock_ner = client
|
| 292 |
+
c.post("/api/v1/extract",
|
| 293 |
+
json={"text": "马云创立了阿里巴巴。"})
|
| 294 |
+
# NERService.extract 被调用时 language='auto',路由逻辑在 ner.py 内部处理
|
| 295 |
+
mock_ner.extract.assert_called_once()
|
| 296 |
+
call_kwargs = mock_ner.extract.call_args
|
| 297 |
+
assert call_kwargs[1].get("language") == "auto" or call_kwargs[0][3] == "auto"
|
| 298 |
+
|
| 299 |
+
|
| 300 |
# ── Arabic ────────────────────────────────────────────────────────────────────
|
| 301 |
|
| 302 |
def test_arabic_person_location(client):
|
|
|
|
| 303 |
c, mock_ner = client
|
| 304 |
mock_ner.extract.return_value = _ents(
|
| 305 |
Entity(text="محمد بن سلمان", label="full name of a person", score=0.82, start=12, end=26),
|
| 306 |
Entity(text="المملكة العربية السعودية", label="geographical location", score=0.85, start=44, end=68),
|
| 307 |
)
|
| 308 |
resp = c.post("/api/v1/extract",
|
| 309 |
+
json={"text": "أعلن الرئيس محمد بن سلمان مشروع نيوم في المملكة العربية السعودية.",
|
| 310 |
"labels": ["full name of a person", "geographical location"],
|
| 311 |
"language": "ar"})
|
|
|
|
| 312 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 313 |
assert "محمد بن سلمان" in texts
|
| 314 |
assert "المملكة العربية السعودية" in texts
|
|
|
|
| 329 |
"labels": ["full name of a person", "人名或姓名",
|
| 330 |
"company or organization name", "公司或组织机构名称",
|
| 331 |
"geographical location", "地名或城市",
|
| 332 |
+
"product or technology name"],
|
| 333 |
+
"language": "mixed"})
|
| 334 |
texts = {e["text"] for e in resp.json()["entities"]}
|
| 335 |
assert {"张伟", "Google", "北京", "Android"} <= texts
|
| 336 |
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
def test_mixed_no_cross_language_contamination(client):
|
| 339 |
c, mock_ner = client
|
| 340 |
mock_ner.extract.return_value = _ents(
|
|
|
|
| 344 |
resp = c.post("/api/v1/extract",
|
| 345 |
json={"text": "他在 OpenAI 工作,同事王芳也在同一部门。",
|
| 346 |
"labels": ["person", "organization"]})
|
|
|
|
| 347 |
entities = resp.json()["entities"]
|
| 348 |
assert any(e["text"] == "OpenAI" and e["label"] == "organization" for e in entities)
|
| 349 |
assert any(e["text"] == "王芳" and e["label"] == "person" for e in entities)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# ── Fallback & merge (NERService unit tests, no HTTP) ────────────────────────
|
| 353 |
+
|
| 354 |
+
def test_fallback_zh_empty_uses_en():
|
| 355 |
+
"""ZH 主模型返回空时,应使用 GLiNER 兜底。"""
|
| 356 |
+
from app.ner import NERService
|
| 357 |
+
|
| 358 |
+
svc = NERService.__new__(NERService)
|
| 359 |
+
svc._en_lock = __import__("threading").Lock()
|
| 360 |
+
svc._zh_lock = __import__("threading").Lock()
|
| 361 |
+
|
| 362 |
+
# ZH backend: returns nothing
|
| 363 |
+
zh_mock = MagicMock()
|
| 364 |
+
zh_mock.predict.return_value = ([], [])
|
| 365 |
+
# EN fallback: returns one entity
|
| 366 |
+
en_mock = MagicMock()
|
| 367 |
+
en_mock.predict.return_value = _ents(
|
| 368 |
+
Entity(text="马云", label="person", score=0.75, start=0, end=2)
|
| 369 |
+
)
|
| 370 |
+
svc._zh_backend = zh_mock
|
| 371 |
+
svc._en_backend = en_mock
|
| 372 |
+
|
| 373 |
+
entities, _ = svc.extract("马云", [], 0.4, language="zh")
|
| 374 |
+
assert any(e.text == "马云" for e in entities)
|
| 375 |
+
zh_mock.predict.assert_called_once()
|
| 376 |
+
en_mock.predict.assert_called_once() # 兜底被调用
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def test_fallback_zh_has_results_no_en_called():
|
| 380 |
+
"""ZH 主模型有结果时,不应调用 GLiNER 兜底。"""
|
| 381 |
+
from app.ner import NERService
|
| 382 |
+
|
| 383 |
+
svc = NERService.__new__(NERService)
|
| 384 |
+
svc._en_lock = __import__("threading").Lock()
|
| 385 |
+
svc._zh_lock = __import__("threading").Lock()
|
| 386 |
+
|
| 387 |
+
zh_mock = MagicMock()
|
| 388 |
+
zh_mock.predict.return_value = _ents(
|
| 389 |
+
Entity(text="马云", label="person", score=0.92, start=0, end=2)
|
| 390 |
+
)
|
| 391 |
+
en_mock = MagicMock()
|
| 392 |
+
svc._zh_backend = zh_mock
|
| 393 |
+
svc._en_backend = en_mock
|
| 394 |
+
|
| 395 |
+
svc.extract("马云", [], 0.4, language="zh")
|
| 396 |
+
en_mock.predict.assert_not_called() # 不应调用兜底
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def test_mixed_runs_both_models_and_merges():
|
| 400 |
+
"""Mixed 语言应同时运行两个模型并合并结果。"""
|
| 401 |
+
from app.ner import NERService
|
| 402 |
+
|
| 403 |
+
svc = NERService.__new__(NERService)
|
| 404 |
+
svc._en_lock = __import__("threading").Lock()
|
| 405 |
+
svc._zh_lock = __import__("threading").Lock()
|
| 406 |
+
|
| 407 |
+
en_mock = MagicMock()
|
| 408 |
+
en_mock.predict.return_value = _ents(
|
| 409 |
+
Entity(text="Google", label="organization", score=0.95, start=5, end=11)
|
| 410 |
+
)
|
| 411 |
+
zh_mock = MagicMock()
|
| 412 |
+
zh_mock.predict.return_value = _ents(
|
| 413 |
+
Entity(text="张伟", label="person", score=0.91, start=0, end=2)
|
| 414 |
+
)
|
| 415 |
+
svc._en_backend = en_mock
|
| 416 |
+
svc._zh_backend = zh_mock
|
| 417 |
+
|
| 418 |
+
entities, _ = svc.extract("张伟加入 Google。", [], 0.4, language="mixed")
|
| 419 |
+
texts = {e.text for e in entities}
|
| 420 |
+
assert "Google" in texts
|
| 421 |
+
assert "张伟" in texts
|
| 422 |
+
en_mock.predict.assert_called_once()
|
| 423 |
+
zh_mock.predict.assert_called_once()
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def test_mixed_deduplicates_overlapping_spans():
|
| 427 |
+
"""两个模型对同一 span 都命中时,只保留得分最高的。"""
|
| 428 |
+
from app.ner import NERService
|
| 429 |
+
|
| 430 |
+
svc = NERService.__new__(NERService)
|
| 431 |
+
svc._en_lock = __import__("threading").Lock()
|
| 432 |
+
svc._zh_lock = __import__("threading").Lock()
|
| 433 |
+
|
| 434 |
+
en_mock = MagicMock()
|
| 435 |
+
en_mock.predict.return_value = (
|
| 436 |
+
[Entity(text="张伟", label="person", score=0.70, start=0, end=2)],
|
| 437 |
+
["person"],
|
| 438 |
+
)
|
| 439 |
+
zh_mock = MagicMock()
|
| 440 |
+
zh_mock.predict.return_value = (
|
| 441 |
+
[Entity(text="张伟", label="人名或姓名", score=0.92, start=0, end=2)],
|
| 442 |
+
["人名或姓名"],
|
| 443 |
+
)
|
| 444 |
+
svc._en_backend = en_mock
|
| 445 |
+
svc._zh_backend = zh_mock
|
| 446 |
+
|
| 447 |
+
entities, _ = svc.extract("张伟", [], 0.4, language="mixed")
|
| 448 |
+
# 去重后只有 1 个 "张伟",且是得分更高的那条
|
| 449 |
+
zhang_wei = [e for e in entities if e.text == "张伟"]
|
| 450 |
+
assert len(zhang_wei) == 1
|
| 451 |
+
assert zhang_wei[0].score == 0.92
|