Robin Claude Sonnet 4.6 commited on
Commit
d6faa4c
·
1 Parent(s): 372fe0c

feat: dual-model routing + fallback strategy (v3.0)

Browse files

Models
EN/AR → urchade/gliner_multi-v2.1 (GLiNER zero-shot, ~1 GB, ~500 ms)
ZH → shibing624/bert4ner-base-chinese (BERT NER, ~400 MB, ~100 ms)
mixed → both models run in parallel, results merged + deduplicated

Language detection (app/ner.py)
Layer-1 Unicode script ratio (CJK / Arabic / Latin)
mixed = cjk>=8% AND latin>=10% (prevents dominant-CJK from
masking bilingual text)
Layer-2 langdetect n-gram fallback when Layer-1 returns 'en'

Fallback / merge (NERService.extract)
zh primary=BERT, fallback=GLiNER when entities==0
en/ar primary=GLiNER, fallback=BERT when entities==0
mixed run both, _deduplicate keeps highest-score per (start,end) span

Other changes
app/config.py EN_MODEL_NAME / ZH_MODEL_NAME env vars; legacy MODEL_NAME kept
app/labels.py BERT_TYPE_TO_LABEL, LABEL_TO_BERT_TYPES, labels_to_bert_types()
requirements.txt add transformers>=4.40.0, langdetect>=1.0.9
Dockerfile pre-download both models at build time
tests 38 tests (was 25); cover fallback, merge, dedup, lang-detect

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (8) hide show
  1. .env.example +7 -1
  2. Dockerfile +13 -3
  3. app/config.py +13 -3
  4. app/labels.py +67 -28
  5. app/main.py +15 -9
  6. app/ner.py +266 -55
  7. requirements.txt +2 -0
  8. tests/test_extract.py +228 -110
.env.example CHANGED
@@ -1,6 +1,12 @@
1
- MODEL_NAME=knowledgator/gliner-multitask-large-v0.5
 
 
 
 
 
2
  MODEL_CACHE_DIR=./model_cache
3
  HOST=0.0.0.0
4
  PORT=4000
 
5
  # 国内环境取消注释;HF Spaces 上留空
6
  # HF_ENDPOINT=https://hf-mirror.com
 
1
+ # 英文 / 阿拉伯文 / 混合文本模型(GLiNER 零样本,~1GB)
2
+ EN_MODEL_NAME=urchade/gliner_multi-v2.1
3
+
4
+ # 中文专用模型(BERT NER,~400MB,快速)
5
+ ZH_MODEL_NAME=shibing624/bert4ner-base-chinese
6
+
7
  MODEL_CACHE_DIR=./model_cache
8
  HOST=0.0.0.0
9
  PORT=4000
10
+
11
  # 国内环境取消注释;HF Spaces 上留空
12
  # HF_ENDPOINT=https://hf-mirror.com
Dockerfile CHANGED
@@ -5,11 +5,19 @@ WORKDIR /app
5
  COPY requirements.txt .
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
8
- # Pre-download model at build time so cold-start needs no network access.
9
- # The image will be ~1.5 GB but startup is instant.
10
  RUN python -c "\
11
  from gliner import GLiNER; \
12
- GLiNER.from_pretrained('knowledgator/gliner-multitask-large-v0.5', cache_dir='/app/model_cache')"
 
 
 
 
 
 
 
 
13
 
14
  COPY app/ app/
15
  COPY run.py .
@@ -17,6 +25,8 @@ COPY run.py .
17
  ENV HOST=0.0.0.0
18
  ENV PORT=7860
19
  ENV MODEL_CACHE_DIR=/app/model_cache
 
 
20
 
21
  EXPOSE 7860
22
 
 
5
  COPY requirements.txt .
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
8
+ # ── 构建时预下载两个模型,冷启动无需联网 ──────────────────────────────────────
9
+ # EN / AR 模型:GLiNER 零样本多语言(~1 GB
10
  RUN python -c "\
11
  from gliner import GLiNER; \
12
+ GLiNER.from_pretrained('urchade/gliner_multi-v2.1', cache_dir='/app/model_cache')"
13
+
14
+ # ZH 模型:BERT 专用中文 NER(~400 MB)
15
+ RUN python -c "\
16
+ from transformers import pipeline; \
17
+ pipeline('token-classification', \
18
+ model='shibing624/bert4ner-base-chinese', \
19
+ model_kwargs={'cache_dir': '/app/model_cache'}, \
20
+ aggregation_strategy='simple')"
21
 
22
  COPY app/ app/
23
  COPY run.py .
 
25
  ENV HOST=0.0.0.0
26
  ENV PORT=7860
27
  ENV MODEL_CACHE_DIR=/app/model_cache
28
+ ENV EN_MODEL_NAME=urchade/gliner_multi-v2.1
29
+ ENV ZH_MODEL_NAME=shibing624/bert4ner-base-chinese
30
 
31
  EXPOSE 7860
32
 
app/config.py CHANGED
@@ -1,12 +1,22 @@
1
  import os
2
 
3
- MODEL_NAME: str = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
 
 
 
 
 
 
 
 
 
 
 
4
  MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
5
  HOST: str = os.getenv("HOST", "0.0.0.0")
6
  PORT: int = int(os.getenv("PORT", "4000"))
7
 
8
- # Only override HF_ENDPOINT when explicitly set (local mirror) so HF Spaces
9
- # can reach huggingface.co directly without forcing the mirror.
10
  _hf_endpoint = os.getenv("HF_ENDPOINT")
11
  if _hf_endpoint:
12
  os.environ["HF_ENDPOINT"] = _hf_endpoint
 
1
  import os
2
 
3
+ # ── 模型配置 ──────────────────────────────────────────────────────────────────
4
+ # 英文 / 阿拉伯文 / 混合:轻量 GLiNER 零样本模型
5
+ EN_MODEL_NAME: str = os.getenv("EN_MODEL_NAME", "urchade/gliner_multi-v2.1")
6
+
7
+ # 中文:专用 BERT NER 模型(400MB,~100ms/次,4 种固定实体类型)
8
+ ZH_MODEL_NAME: str = os.getenv("ZH_MODEL_NAME", "shibing624/bert4ner-base-chinese")
9
+
10
+ # 兼容旧版环境变量(若设置 MODEL_NAME,则覆盖 EN_MODEL_NAME)
11
+ _legacy = os.getenv("MODEL_NAME")
12
+ if _legacy:
13
+ EN_MODEL_NAME = _legacy
14
+
15
  MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
16
  HOST: str = os.getenv("HOST", "0.0.0.0")
17
  PORT: int = int(os.getenv("PORT", "4000"))
18
 
19
+ # 可选:国内镜像(留空则使用 huggingface.co)
 
20
  _hf_endpoint = os.getenv("HF_ENDPOINT")
21
  if _hf_endpoint:
22
  os.environ["HF_ENDPOINT"] = _hf_endpoint
app/labels.py CHANGED
@@ -1,30 +1,32 @@
1
  """
2
  双语标签管理模块
3
- ────────────────
4
- * DEFAULT_LABELS — 内置通用双语标签集(labels 为空时使用)
5
- * expand_bilingual — 自动为已有标签补充对等的另一语言版本
 
 
6
  """
7
 
8
- # 英文 对照表(顺序决定输出标签的语言偏好)
9
  _PAIRS: list[tuple[str, str]] = [
10
- ("full name of a person", "人名或姓名"),
11
- ("company or organization name", "公司或组织机构名称"),
12
- ("geographical location", "地名或城市"),
13
- ("product or technology name", "产品或技术名称"),
14
- ("date or year", "日期或年份"),
15
- ("hospital or medical institution", "医院或医疗机构名称"),
16
- ("university or research institution","大学或研究机构"),
17
- ("project or initiative name", "项目或计划名称"),
18
- ("legislation or policy name", "法规或政策名称"),
19
- ("monetary amount", "金额或货币"),
20
- ("job title or position", "职位或头衔"),
21
- ("event name", "事件或活动名称"),
22
  ]
23
 
24
- # 默认标签集:英中并列,涵盖最常用实体类型
25
  DEFAULT_LABELS: list[str] = [item for pair in _PAIRS for item in pair]
26
 
27
- # 快速查找:任意一种语言的标签 → 对等标签
28
  _EN_TO_ZH: dict[str, str] = {en: zh for en, zh in _PAIRS}
29
  _ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
30
 
@@ -32,16 +34,7 @@ _ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
32
  def expand_bilingual(labels: list[str]) -> list[str]:
33
  """
34
  为调用者传入的标签自动补充另一语言的对等描述。
35
-
36
- 例如:
37
- ["人名或姓名", "company or organization name"]
38
- → ["人名或姓名", "full name of a person",
39
- "company or organization name", "公司或组织机构名称"]
40
-
41
- 规则:
42
- * 已有标签保持原位不变
43
- * 对等标签紧随其后插入(若已存在则跳过)
44
- * 未在对照表中的自定义标签原样保留,不做处理
45
  """
46
  seen: set[str] = set(labels)
47
  result: list[str] = []
@@ -52,3 +45,49 @@ def expand_bilingual(labels: list[str]) -> list[str]:
52
  result.append(counterpart)
53
  seen.add(counterpart)
54
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  双语标签管理模块
3
+ ────────────────────────────────────────────────────────────────
4
+ * DEFAULT_LABELS — 内置通用双语标签集(labels 为空时使用)
5
+ * expand_bilingual — 自动为已有标签补充对等的另一语言版本
6
+ * BERT_TYPE_TO_LABEL — 中文 BERT 模型固定实体类型 → 标准双语标签
7
+ * LABEL_TO_BERT_TYPES — 标准标签 → 对应 BERT 实体类型列表(用于过滤)
8
  """
9
 
10
+ # ── 中对照表 ─────────────────────────────────────────────────────────────────
11
  _PAIRS: list[tuple[str, str]] = [
12
+ ("full name of a person", "人名或姓名"),
13
+ ("company or organization name", "公司或组织机构名称"),
14
+ ("geographical location", "地名或城市"),
15
+ ("product or technology name", "产品或技术名称"),
16
+ ("date or year", "日期或年份"),
17
+ ("hospital or medical institution", "医院或医疗机构名称"),
18
+ ("university or research institution", "大学或研究机构"),
19
+ ("project or initiative name", "项目或计划名称"),
20
+ ("legislation or policy name", "法规或政策名称"),
21
+ ("monetary amount", "金额或货币"),
22
+ ("job title or position", "职位或头衔"),
23
+ ("event name", "事件或活动名称"),
24
  ]
25
 
26
+ # 默认标签集:英中并列
27
  DEFAULT_LABELS: list[str] = [item for pair in _PAIRS for item in pair]
28
 
29
+ # 快速查找
30
  _EN_TO_ZH: dict[str, str] = {en: zh for en, zh in _PAIRS}
31
  _ZH_TO_EN: dict[str, str] = {zh: en for en, zh in _PAIRS}
32
 
 
34
  def expand_bilingual(labels: list[str]) -> list[str]:
35
  """
36
  为调用者传入的标签自动补充另一语言的对等描述。
37
+ 已有标签保持原位不变,对等标签紧随其后插入(若已存在则跳过)。
 
 
 
 
 
 
 
 
 
38
  """
39
  seen: set[str] = set(labels)
40
  result: list[str] = []
 
45
  result.append(counterpart)
46
  seen.add(counterpart)
47
  return result
48
+
49
+
50
+ # ── 中文 BERT NER 固定类型映射 ────────────────────────────────────────────────
51
+ # shibing624/bert4ner-base-chinese 输出的实体类型
52
+ BERT_TYPE_TO_LABEL: dict[str, str] = {
53
+ "PER": "人名或姓名",
54
+ "LOC": "地名或城市",
55
+ "ORG": "公司或组织机构名称",
56
+ "TIME": "日期或年份",
57
+ "GPE": "地名或城市", # 部分模型区分 GPE(地缘政治实体)
58
+ }
59
+
60
+ # 标准标签 → BERT 类型列表(用于用户自定义标签过滤)
61
+ LABEL_TO_BERT_TYPES: dict[str, list[str]] = {
62
+ # 人名
63
+ "人名或姓名": ["PER"],
64
+ "full name of a person": ["PER"],
65
+ # 地名
66
+ "地名或城市": ["LOC", "GPE"],
67
+ "geographical location": ["LOC", "GPE"],
68
+ # 机构
69
+ "公司或组织机构名称": ["ORG"],
70
+ "company or organization name": ["ORG"],
71
+ "医院或医疗机构名称": ["ORG"],
72
+ "hospital or medical institution": ["ORG"],
73
+ "大学或研究机构": ["ORG"],
74
+ "university or research institution": ["ORG"],
75
+ # 时间
76
+ "日期或年份": ["TIME"],
77
+ "date or year": ["TIME"],
78
+ }
79
+
80
+
81
+ def labels_to_bert_types(labels: list[str]) -> set[str] | None:
82
+ """
83
+ 将用户标签列表转换为 BERT 实体类型集合。
84
+ 返回 None 表示"接受所有类型"(labels 为空或无法映射时)。
85
+ """
86
+ if not labels:
87
+ return None # 无限制,返回全部
88
+ types: set[str] = set()
89
+ for lbl in labels:
90
+ mapped = LABEL_TO_BERT_TYPES.get(lbl)
91
+ if mapped:
92
+ types.update(mapped)
93
+ return types if types else None # 无映射 → 不过滤
app/main.py CHANGED
@@ -3,7 +3,7 @@ from contextlib import asynccontextmanager
3
 
4
  from fastapi import FastAPI
5
 
6
- from app.config import MODEL_CACHE_DIR, MODEL_NAME
7
  from app.logger import get_logger
8
  from app.models import ExtractRequest, ExtractResponse
9
  from app.ner import NERService
@@ -15,9 +15,14 @@ ner_service: NERService | None = None
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
  global ner_service
18
- logger.info("Loading model: %s (cache_dir=%s)", MODEL_NAME, MODEL_CACHE_DIR)
19
- ner_service = NERService(MODEL_NAME, MODEL_CACHE_DIR)
20
- logger.info("Model ready")
 
 
 
 
 
21
  yield
22
  ner_service = None
23
 
@@ -25,11 +30,12 @@ async def lifespan(app: FastAPI):
25
  app = FastAPI(
26
  title="NER API",
27
  description=(
28
- "Zero-shot Named Entity Recognition powered by GLiNER. "
29
- "Supports English, Chinese, Arabic and mixed-language text. "
 
30
  "Labels are optional — omit them to use built-in bilingual defaults."
31
  ),
32
- version="2.0.0",
33
  lifespan=lifespan,
34
  )
35
 
@@ -58,9 +64,9 @@ def extract(req: ExtractRequest):
58
  elapsed_ms = (time.perf_counter() - t0) * 1000
59
 
60
  logger.info(
61
- "extract response | entities=%d elapsed=%.1fms labels_used=%d",
62
  len(entities),
63
  elapsed_ms,
64
- len(labels_used),
65
  )
66
  return ExtractResponse(entities=entities, labels_used=labels_used)
 
3
 
4
  from fastapi import FastAPI
5
 
6
+ from app.config import EN_MODEL_NAME, MODEL_CACHE_DIR, ZH_MODEL_NAME
7
  from app.logger import get_logger
8
  from app.models import ExtractRequest, ExtractResponse
9
  from app.ner import NERService
 
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
  global ner_service
18
+ logger.info(
19
+ "Initializing NER service | en_model=%s zh_model=%s cache=%s",
20
+ EN_MODEL_NAME, ZH_MODEL_NAME, MODEL_CACHE_DIR,
21
+ )
22
+ ner_service = NERService(EN_MODEL_NAME, ZH_MODEL_NAME, MODEL_CACHE_DIR)
23
+ # 预热:启动时同时加载两个模型,首个请求无需等待
24
+ ner_service.warmup()
25
+ logger.info("NER service ready")
26
  yield
27
  ner_service = None
28
 
 
30
  app = FastAPI(
31
  title="NER API",
32
  description=(
33
+ "Zero-shot Named Entity Recognition powered by GLiNER (EN/AR) "
34
+ "and BERT-Chinese (ZH). "
35
+ "Supports English · Chinese · Arabic · mixed-language text. "
36
  "Labels are optional — omit them to use built-in bilingual defaults."
37
  ),
38
+ version="3.0.0",
39
  lifespan=lifespan,
40
  )
41
 
 
64
  elapsed_ms = (time.perf_counter() - t0) * 1000
65
 
66
  logger.info(
67
+ "extract response | entities=%d elapsed=%.1fms language=%s",
68
  len(entities),
69
  elapsed_ms,
70
+ req.language,
71
  )
72
  return ExtractResponse(entities=entities, labels_used=labels_used)
app/ner.py CHANGED
@@ -1,45 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unicodedata
 
2
 
3
  from gliner import GLiNER
4
 
5
- from app.labels import DEFAULT_LABELS, expand_bilingual
 
 
 
 
 
6
  from app.models import Entity
7
 
8
 
9
  # ── 语言检测 ──────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def _detect_language(text: str) -> str:
12
- """
13
- 通过 Unicode 脚本比例判断文本语言。
14
- 返回: 'zh' | 'ar' | 'mixed' | 'en'
15
- """
16
- if not text:
17
- return "en"
18
-
19
  cjk = arabic = letters = 0
20
  for ch in text:
21
- cat = unicodedata.category(ch)
22
- if cat.startswith("L"):
23
- letters += 1
24
- cp = ord(ch)
25
- if (0x4E00 <= cp <= 0x9FFF or # CJK Unified
26
- 0x3400 <= cp <= 0x4DBF or
27
- 0xF900 <= cp <= 0xFAFF or
28
- 0x20000 <= cp <= 0x2A6DF):
29
- cjk += 1
30
- elif 0x0600 <= cp <= 0x06FF or 0x0750 <= cp <= 0x077F:
31
- arabic += 1
32
-
33
  if not letters:
34
  return "en"
35
  cjk_r = cjk / letters
36
  ar_r = arabic / letters
37
- if cjk_r >= 0.20 and ar_r < 0.08:
 
 
 
 
 
 
 
 
 
38
  return "zh"
39
- if ar_r >= 0.20 and cjk_r < 0.08:
40
  return "ar"
41
- if cjk_r >= 0.08 or ar_r >= 0.08:
42
- return "mixed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return "en"
44
 
45
 
@@ -47,8 +110,8 @@ def _detect_language(text: str) -> str:
47
 
48
  def _deduplicate(entities: list[Entity]) -> list[Entity]:
49
  """
50
- 双语标签可能同一 (start, end) 跨度产生两条结果,
51
- 保留置信度最高的那条,并按位置排序。
52
  """
53
  best: dict[tuple[int, int], Entity] = {}
54
  for e in entities:
@@ -58,12 +121,155 @@ def _deduplicate(entities: list[Entity]) -> list[Entity]:
58
  return sorted(best.values(), key=lambda x: x.start)
59
 
60
 
61
- # ── NER 服务 ──────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- class NERService:
64
  def __init__(self, model_name: str, cache_dir: str) -> None:
65
  self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def extract(
68
  self,
69
  text: str,
@@ -74,35 +280,40 @@ class NERService:
74
  """
75
  返回 (entities, labels_used)。
76
 
77
- labels 处理逻辑:
78
- 1. labels 为空 使用内置双语默认标签集
79
- 2. labels 非空 → 自动补充双语对等标签(提升中文召回
80
-
81
- threshold 处理逻辑:
82
- - language='auto' 时自动检测
83
- - 中文 / 混合文本若传入默认 threshold(0.4) 则不调整(已足够低)
84
  """
85
  if not text:
86
  return [], labels
87
 
88
- # 确定有效语言
89
- eff_lang = language if language != "auto" else _detect_language(text)
90
 
91
- # 确定标签集
92
- if not labels:
93
- eff_labels = DEFAULT_LABELS
94
- else:
95
- eff_labels = expand_bilingual(labels)
96
 
97
- raw = self._model.predict_entities(text, eff_labels, threshold=threshold)
98
- entities = [
99
- Entity(
100
- text=e["text"],
101
- label=e["label"],
102
- score=round(e["score"], 4),
103
- start=e["start"],
104
- end=e["end"],
105
- )
106
- for e in raw
107
- ]
108
- return _deduplicate(entities), eff_labels
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NER 服务层 — 双模型路由 + 兜底策略
3
+ ──────────────────────────────────────────────────────────────────────────────
4
+ 语言检测(两层):
5
+ 1. Unicode 脚本比例:快速,适合中文 / 阿拉伯文等脚本明显的语言
6
+ 2. langdetect 库兜底:覆盖纯英文及边界文本
7
+
8
+ 路由 & 兜底规则:
9
+ ┌──────────┬──────────────────┬──────────────────────────────┐
10
+ │ language │ 主模型 │ 兜底条件 │
11
+ ├──────────┼──────────────────┼──────────────────────────────┤
12
+ │ zh │ ChineseBERT │ 实体数=0 → 补充 GLiNER 结果 │
13
+ │ en / ar │ GLiNER │ 实体数=0 → 补充 BERT 结果 │
14
+ │ mixed │ GLiNER + BERT │ 同时运行两个模型,结果合并 │
15
+ │ auto │ 先检测语言再路由 │ │
16
+ └──────────┴──────────────────┴──────────────────────────────┘
17
+ """
18
+
19
+ import threading
20
  import unicodedata
21
+ from abc import ABC, abstractmethod
22
 
23
  from gliner import GLiNER
24
 
25
+ from app.labels import (
26
+ DEFAULT_LABELS,
27
+ BERT_TYPE_TO_LABEL,
28
+ expand_bilingual,
29
+ labels_to_bert_types,
30
+ )
31
  from app.models import Entity
32
 
33
 
34
  # ── 语言检测 ──────────────────────────────────────────────────────────────────
35
+ #
36
+ # 两层策略:
37
+ # Layer-1 Unicode 脚本比例
38
+ # · 遍历文本中所有字母字符,统计 CJK / Arabic 脚本占比
39
+ # · 优点:零依赖、极快;缺点:对极短或纯拉丁文本判断力弱
40
+ #
41
+ # Layer-2 langdetect(仅 Layer-1 返回 'en' 时作为校验)
42
+ # · 基于 n-gram 概率模型,原理同 Google CLD2
43
+ # · 对短文本(<20 字)仍有一定误判率,以 Layer-1 为主
44
+ # · 若 langdetect 检测到中文/日文/韩文 → 返回 'zh'
45
+ # · 失败时静默回退到 Layer-1 结果
46
 
47
+ def _unicode_script_ratio(text: str) -> str:
48
+ """Layer-1:基于 Unicode 脚本比例的语言分类。"""
 
 
 
 
 
 
49
  cjk = arabic = letters = 0
50
  for ch in text:
51
+ if not unicodedata.category(ch).startswith("L"):
52
+ continue
53
+ letters += 1
54
+ cp = ord(ch)
55
+ if (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
56
+ 0xF900 <= cp <= 0xFAFF or 0x20000 <= cp <= 0x2A6DF):
57
+ cjk += 1
58
+ elif 0x0600 <= cp <= 0x06FF or 0x0750 <= cp <= 0x077F:
59
+ arabic += 1
 
 
 
60
  if not letters:
61
  return "en"
62
  cjk_r = cjk / letters
63
  ar_r = arabic / letters
64
+ latin_r = (letters - cjk - arabic) / letters
65
+
66
+ # 中文+拉丁都显著 → mixed(优先级高于单纯 zh 判断)
67
+ if cjk_r >= 0.08 and latin_r >= 0.10:
68
+ return "mixed"
69
+ # 阿拉伯+拉丁都显著 → mixed
70
+ if ar_r >= 0.08 and latin_r >= 0.10:
71
+ return "mixed"
72
+ # 单脚本主导
73
+ if cjk_r >= 0.20:
74
  return "zh"
75
+ if ar_r >= 0.20:
76
  return "ar"
77
+ return "en"
78
+
79
+
80
+ def detect_language(text: str) -> str:
81
+ """
82
+ 两层语言检测,返回 'zh' | 'ar' | 'mixed' | 'en'。
83
+
84
+ Layer-1 优先(Unicode 脚本比例);Layer-1 返回 'en' 时,
85
+ 用 langdetect 做一次二次确认,防止把中文误判为英文。
86
+ """
87
+ if not text:
88
+ return "en"
89
+
90
+ layer1 = _unicode_script_ratio(text)
91
+ if layer1 != "en": # 已明确是非英文,直接返回
92
+ return layer1
93
+
94
+ # Layer-2:langdetect 校验(仅对 Layer-1='en' 的文本)
95
+ try:
96
+ from langdetect import detect, DetectorFactory
97
+ DetectorFactory.seed = 0 # 保证结果���定
98
+ lang_code = detect(text) # e.g. 'zh-cn', 'ar', 'en', 'ja' …
99
+ if lang_code.startswith("zh") or lang_code in ("ja", "ko"):
100
+ return "zh"
101
+ if lang_code == "ar":
102
+ return "ar"
103
+ except Exception:
104
+ pass # langdetect 失败时静默回退
105
+
106
  return "en"
107
 
108
 
 
110
 
111
  def _deduplicate(entities: list[Entity]) -> list[Entity]:
112
  """
113
+ 双语标签或模型合并时可能产生同一 (start, end) 的重复结果,
114
+ 保留置信度最高的那条,并按起始位置排序。
115
  """
116
  best: dict[tuple[int, int], Entity] = {}
117
  for e in entities:
 
121
  return sorted(best.values(), key=lambda x: x.start)
122
 
123
 
124
+ # ── 后端基类 ──────────────────────────────────────────────────────────────────
125
+
126
+ class _Backend(ABC):
127
+ @abstractmethod
128
+ def predict(
129
+ self, text: str, labels: list[str], threshold: float
130
+ ) -> tuple[list[Entity], list[str]]:
131
+ """返回 (entities, labels_used)"""
132
+
133
+
134
+ # ── GLiNER 后端(英文 / 阿拉伯文 / 混合) ─────────────────────────────────────
135
+
136
+ class GLiNERBackend(_Backend):
137
+ """
138
+ 零样本 NER:urchade/gliner_multi-v2.1
139
+ • 支持英文、阿拉伯文及混合文本
140
+ • 自动做双语标签扩展,提升召回率
141
+ """
142
 
 
143
  def __init__(self, model_name: str, cache_dir: str) -> None:
144
  self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
145
 
146
+ def predict(
147
+ self, text: str, labels: list[str], threshold: float
148
+ ) -> tuple[list[Entity], list[str]]:
149
+ eff_labels = expand_bilingual(labels) if labels else DEFAULT_LABELS
150
+ raw = self._model.predict_entities(text, eff_labels, threshold=threshold)
151
+ entities = [
152
+ Entity(
153
+ text=e["text"],
154
+ label=e["label"],
155
+ score=round(e["score"], 4),
156
+ start=e["start"],
157
+ end=e["end"],
158
+ )
159
+ for e in raw
160
+ ]
161
+ return _deduplicate(entities), eff_labels
162
+
163
+
164
+ # ── 中文 BERT 后端 ─────────────────────────────────────────────────────────────
165
+
166
+ class ChineseBERTBackend(_Backend):
167
+ """
168
+ 专用中文 NER:shibing624/bert4ner-base-chinese
169
+ • 模型大小:~400 MB(BERT-base)
170
+ • 推理速度:~100 ms
171
+ • 固定实体类型:PER / LOC / ORG / TIME → 映射为双语标签
172
+ • 用户传入标签时按标签类型过滤;无法映射的自定义标签不过滤(返回全部)
173
+ """
174
+
175
+ def __init__(self, model_name: str, cache_dir: str) -> None:
176
+ # 延迟导入:避免顶层 import 在测试收集阶段触发 torch.__spec__ 检测
177
+ from transformers import pipeline as hf_pipeline
178
+ self._pipe = hf_pipeline(
179
+ "token-classification",
180
+ model=model_name,
181
+ model_kwargs={"cache_dir": cache_dir},
182
+ aggregation_strategy="simple",
183
+ )
184
+
185
+ def predict(
186
+ self, text: str, labels: list[str], threshold: float
187
+ ) -> tuple[list[Entity], list[str]]:
188
+ raw = self._pipe(text)
189
+ allowed_types = labels_to_bert_types(labels) # None = 不过滤
190
+
191
+ entities: list[Entity] = []
192
+ labels_seen: set[str] = set()
193
+
194
+ for r in raw:
195
+ score = float(r["score"])
196
+ if score < threshold:
197
+ continue
198
+
199
+ bert_type = r.get("entity_group", r.get("entity", ""))
200
+ bert_type = bert_type.lstrip("BI-").strip() # 去掉可能的 B-/I- 前缀
201
+
202
+ if allowed_types is not None and bert_type not in allowed_types:
203
+ continue
204
+
205
+ std_label = BERT_TYPE_TO_LABEL.get(bert_type, bert_type)
206
+ labels_seen.add(std_label)
207
+ entities.append(Entity(
208
+ text=r["word"],
209
+ label=std_label,
210
+ score=round(score, 4),
211
+ start=r["start"],
212
+ end=r["end"],
213
+ ))
214
+
215
+ used = list(labels_seen) if labels_seen else list(BERT_TYPE_TO_LABEL.values())
216
+ return entities, used
217
+
218
+
219
+ # ── NER 服务(路由 + 兜底) ─────────────────────────────���──────────────────────
220
+
221
+ class NERService:
222
+ """
223
+ 持有两个后端,按检测到的语言分发请求。
224
+
225
+ 兜底规则(召回为空时):
226
+ zh 主模型 BERT 无结果 → 用 GLiNER 补充
227
+ en/ar 主模型 GLiNER 无结果 → 用 BERT 补充
228
+ mixed 同时运行两个模型,合并去重后返回
229
+ """
230
+
231
+ def __init__(self, en_model_name: str, zh_model_name: str, cache_dir: str) -> None:
232
+ self._en_name = en_model_name
233
+ self._zh_name = zh_model_name
234
+ self._cache_dir = cache_dir
235
+
236
+ self._en_backend: GLiNERBackend | None = None
237
+ self._zh_backend: ChineseBERTBackend | None = None
238
+ self._en_lock = threading.Lock()
239
+ self._zh_lock = threading.Lock()
240
+
241
+ # ── 懒加载 ────────────────────────────────────────────────────────────────
242
+
243
+ def _en(self) -> GLiNERBackend:
244
+ if self._en_backend is None:
245
+ with self._en_lock:
246
+ if self._en_backend is None:
247
+ self._en_backend = GLiNERBackend(self._en_name, self._cache_dir)
248
+ return self._en_backend
249
+
250
+ def _zh(self) -> ChineseBERTBackend:
251
+ if self._zh_backend is None:
252
+ with self._zh_lock:
253
+ if self._zh_backend is None:
254
+ self._zh_backend = ChineseBERTBackend(self._zh_name, self._cache_dir)
255
+ return self._zh_backend
256
+
257
+ # ── 兜底合并 ──────────────────────────────────────────────────────────────
258
+
259
+ def _merge(
260
+ self,
261
+ primary: tuple[list[Entity], list[str]],
262
+ fallback: tuple[list[Entity], list[str]],
263
+ ) -> tuple[list[Entity], list[str]]:
264
+ """合并两个模型的结果,去重后按位置排序。"""
265
+ p_ents, p_labels = primary
266
+ f_ents, f_labels = fallback
267
+ merged = _deduplicate(p_ents + f_ents)
268
+ used = list(dict.fromkeys(p_labels + f_labels)) # 保序去重
269
+ return merged, used
270
+
271
+ # ── 主入口 ────────────────────────────────────────────────────────────────
272
+
273
  def extract(
274
  self,
275
  text: str,
 
280
  """
281
  返回 (entities, labels_used)。
282
 
283
+ 路由逻辑:
284
+ auto → 检测语言路由
285
+ zh → BERT 主,GLiNER 兜底(主模型无结果时补充)
286
+ en/ar → GLiNER 主,BERT 兜底(主模型无结果时补充)
287
+ mixed → 两模型同时运行,结果合并去重
 
 
288
  """
289
  if not text:
290
  return [], labels
291
 
292
+ lang = language if language != "auto" else detect_language(text)
 
293
 
294
+ if lang == "mixed":
295
+ # 同时运行两个模型,合并结果
296
+ en_result = self._en().predict(text, labels, threshold)
297
+ zh_result = self._zh().predict(text, labels, threshold)
298
+ return self._merge(en_result, zh_result)
299
 
300
+ if lang == "zh":
301
+ primary_result = self._zh().predict(text, labels, threshold)
302
+ if not primary_result[0]: # 主模型无结果 → GLiNER 兜底
303
+ fallback_result = self._en().predict(text, labels, threshold)
304
+ if fallback_result[0]:
305
+ return fallback_result
306
+ return primary_result
307
+
308
+ # en / ar / 其他
309
+ primary_result = self._en().predict(text, labels, threshold)
310
+ if not primary_result[0]: # 主模型无结果 → BERT 兜底
311
+ fallback_result = self._zh().predict(text, labels, threshold)
312
+ if fallback_result[0]:
313
+ return fallback_result
314
+ return primary_result
315
+
316
+ def warmup(self) -> None:
317
+ """启动时预热两个模型,首个请求无需等待。"""
318
+ self._en()
319
+ self._zh()
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  fastapi>=0.111.0
2
  uvicorn[standard]>=0.29.0
3
  gliner>=0.2.0
 
 
 
1
  fastapi>=0.111.0
2
  uvicorn[standard]>=0.29.0
3
  gliner>=0.2.0
4
+ transformers>=4.40.0
5
+ langdetect>=1.0.9
tests/test_extract.py CHANGED
@@ -1,27 +1,39 @@
1
  """
2
- Unit tests — no real model loaded (GLiNER/torch stubbed in conftest.py).
3
  Covers:
4
- - API contract (health, validation, threshold forwarding)
5
- - New v2 features: optional labels, bilingual expansion, labels_used echo
6
- - English / Chinese / Arabic / mixed-language text handling
 
 
7
  """
8
- from unittest.mock import MagicMock, patch
9
 
10
  import pytest
11
  from fastapi.testclient import TestClient
12
 
13
- import app.main as main_module
14
  from app.main import app
15
  from app.models import Entity
16
- from app.labels import DEFAULT_LABELS, expand_bilingual
 
 
 
 
 
 
 
 
17
 
18
 
19
  # ── Fixture ───────────────────────────────────────────────────────────────────
20
 
21
  @pytest.fixture()
22
  def client():
 
 
 
 
23
  mock_ner = MagicMock()
24
- # Default: extract() returns ([], [])
25
  mock_ner.extract.return_value = ([], [])
26
  with pytest.MonkeyPatch().context() as mp:
27
  mp.setattr("app.main.NERService", lambda *_: mock_ner)
@@ -29,61 +41,49 @@ def client():
29
  yield c, mock_ner
30
 
31
 
32
- def _ents(*args) -> tuple[list[Entity], list[str]]:
33
- """Helper: wrap Entity list in the (entities, labels_used) tuple."""
34
- entities = list(args)
35
- labels = [e.label for e in entities]
36
- return entities, labels
37
-
38
-
39
  # ── System / API contract ─────────────────────────────────────────────────────
40
 
41
  def test_health(client):
42
  c, _ = client
43
- resp = c.get("/api/v1/health")
44
- assert resp.status_code == 200
45
- assert resp.json() == {"status": "ok"}
46
 
47
 
48
- def test_extract_empty_text(client):
49
- c, mock_ner = client
50
  resp = c.post("/api/v1/extract", json={"text": "", "labels": ["person"]})
51
  assert resp.status_code == 200
52
  assert resp.json()["entities"] == []
53
 
54
 
55
- def test_extract_empty_labels_uses_defaults(client):
56
- """labels 为空时服务端应自动使用默认双语标签集,报错。"""
57
  c, mock_ner = client
58
  mock_ner.extract.return_value = ([], DEFAULT_LABELS)
59
- resp = c.post("/api/v1/extract", json={"text": "Apple Inc. is in Cupertino."})
60
  assert resp.status_code == 200
61
- data = resp.json()
62
- assert "entities" in data
63
- assert "labels_used" in data
64
- assert len(data["labels_used"]) > 0
65
 
66
 
67
- def test_extract_omit_labels_entirely(client):
68
- """labels 字段完全不传也该正常工作。"""
69
  c, mock_ner = client
70
  mock_ner.extract.return_value = ([], DEFAULT_LABELS)
71
- resp = c.post("/api/v1/extract", json={"text": "Some text."})
72
  assert resp.status_code == 200
 
73
 
74
 
75
  def test_extract_threshold_forwarded(client):
76
  c, mock_ner = client
77
  c.post("/api/v1/extract",
78
- json={"text": "Hello world", "labels": ["person"], "threshold": 0.8})
79
- mock_ner.extract.assert_called_once_with("Hello world", ["person"], 0.8, language="auto")
80
 
81
 
82
  def test_extract_invalid_threshold(client):
83
  c, _ = client
84
- resp = c.post("/api/v1/extract",
85
- json={"text": "Hello", "labels": ["person"], "threshold": 1.5})
86
- assert resp.status_code == 422
87
 
88
 
89
  def test_extract_language_field_forwarded(client):
@@ -94,86 +94,117 @@ def test_extract_language_field_forwarded(client):
94
 
95
 
96
  def test_extract_invalid_language(client):
97
- """不支持的 language 值应返回 422。"""
98
  c, _ = client
99
- resp = c.post("/api/v1/extract",
100
- json={"text": "Hello", "language": "jp"})
101
- assert resp.status_code == 422
 
 
 
 
 
 
 
102
 
103
 
104
  def test_entity_response_fields(client):
105
- """每个实体包含全部必填字段且值合法。"""
106
  c, mock_ner = client
107
  mock_ner.extract.return_value = _ents(
108
  Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
109
  )
110
  resp = c.post("/api/v1/extract",
111
  json={"text": "Apple is great.", "labels": ["organization"]})
112
- assert resp.status_code == 200
113
  e = resp.json()["entities"][0]
114
- assert set(e.keys()) >= {"text", "label", "score", "start", "end"}
115
  assert 0.0 <= e["score"] <= 1.0
116
  assert e["start"] < e["end"]
117
 
118
 
119
- def test_labels_used_echoed(client):
120
- """响应中 labels_used 应回传实际使用的标签列表。"""
121
- c, mock_ner = client
122
- used = ["person", "organization"]
123
- mock_ner.extract.return_value = ([], used)
124
- resp = c.post("/api/v1/extract",
125
- json={"text": "Elon Musk works at Tesla.", "labels": ["person"]})
126
- assert resp.status_code == 200
127
- assert resp.json()["labels_used"] == used
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
- # ── Bilingual label expansion (unit-level, no HTTP) ───────────────────────────
131
 
132
  def test_expand_bilingual_adds_english_for_chinese():
133
  result = expand_bilingual(["人名或姓名"])
134
- assert "人名或姓名" in result
135
  assert "full name of a person" in result
136
 
137
 
138
  def test_expand_bilingual_adds_chinese_for_english():
139
  result = expand_bilingual(["company or organization name"])
140
- assert "company or organization name" in result
141
  assert "公司或组织机构名称" in result
142
 
143
 
144
  def test_expand_bilingual_no_duplicate():
145
- labels = ["人名或姓名", "full name of a person"]
146
- result = expand_bilingual(labels)
147
  assert result.count("人名或姓名") == 1
148
  assert result.count("full name of a person") == 1
149
 
150
 
151
  def test_expand_bilingual_custom_label_preserved():
152
- """自定义标签(不在对照表中)原样保留。"""
153
  result = expand_bilingual(["my custom label"])
154
  assert "my custom label" in result
155
 
156
 
157
- def test_default_labels_nonempty():
158
- assert len(DEFAULT_LABELS) > 0
159
- # 必须包含中英文各至少一个
160
  has_en = any(all(ord(c) < 128 for c in lbl) for lbl in DEFAULT_LABELS)
161
  has_zh = any(any('一' <= c <= '鿿' for c in lbl) for lbl in DEFAULT_LABELS)
162
  assert has_en and has_zh
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # ── English ───────────────────────────────────────────────────────────────────
166
 
167
  def test_english_person_org(client):
168
  c, mock_ner = client
169
  mock_ner.extract.return_value = _ents(
170
- Entity(text="Elon Musk", label="person", score=0.98, start=0, end=9),
171
- Entity(text="Tesla", label="organization", score=0.96, start=18, end=23),
172
- Entity(text="SpaceX", label="organization", score=0.97, start=28, end=34),
173
  )
174
  resp = c.post("/api/v1/extract",
175
  json={"text": "Elon Musk is the CEO of Tesla and founded SpaceX.",
176
- "labels": ["full name of a person", "company or organization name"]})
 
177
  assert resp.status_code == 200
178
  texts = {e["text"] for e in resp.json()["entities"]}
179
  assert {"Elon Musk", "Tesla", "SpaceX"} <= texts
@@ -188,28 +219,24 @@ def test_english_location_date(client):
188
  )
189
  resp = c.post("/api/v1/extract",
190
  json={"text": "The summit was held in Paris in 2024, in France.",
191
- "labels": ["geographical location", "date or year"]})
192
- assert resp.status_code == 200
193
  texts = {e["text"] for e in resp.json()["entities"]}
194
  assert {"Paris", "France", "2024"} <= texts
195
 
196
 
197
- def test_english_threshold_filters(client):
198
  c, mock_ner = client
199
- mock_ner.extract.return_value = _ents(
200
- Entity(text="NASA", label="organization", score=0.95, start=0, end=4),
201
- )
202
- resp = c.post("/api/v1/extract",
203
- json={"text": "NASA explored the Moon.",
204
- "labels": ["company or organization name"],
205
- "threshold": 0.8})
206
- assert resp.status_code == 200
207
  mock_ner.extract.assert_called_once_with(
208
- "NASA explored the Moon.", ["company or organization name"], 0.8, language="auto"
209
  )
210
 
211
 
212
- # ── Chinese ───────────────────────────────────────────────────────────────────
213
 
214
  def test_chinese_person_org(client):
215
  c, mock_ner = client
@@ -222,13 +249,12 @@ def test_chinese_person_org(client):
222
  json={"text": "阿里巴巴集团创始人马云卸任,由张勇接任。",
223
  "labels": ["人名或姓名", "公司或组织机构名称"],
224
  "language": "zh"})
225
- assert resp.status_code == 200
226
  texts = {e["text"] for e in resp.json()["entities"]}
227
  assert {"马云", "张勇", "阿里巴巴"} <= texts
228
 
229
 
230
  def test_chinese_entity_boundary(client):
231
- """实体边界不应包含动词 — '尤氏来请' 应只取 '尤氏'。"""
232
  c, mock_ner = client
233
  mock_ner.extract.return_value = _ents(
234
  Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2),
@@ -236,8 +262,7 @@ def test_chinese_entity_boundary(client):
236
  )
237
  resp = c.post("/api/v1/extract",
238
  json={"text": "尤氏来请,王熙凤笑道:'你来了。'",
239
- "labels": ["人名或姓名"]})
240
- assert resp.status_code == 200
241
  texts = {e["text"] for e in resp.json()["entities"]}
242
  assert "尤氏" in texts
243
  assert "王熙凤" in texts
@@ -248,33 +273,42 @@ def test_chinese_entity_boundary(client):
248
  def test_chinese_location_product(client):
249
  c, mock_ner = client
250
  mock_ner.extract.return_value = _ents(
251
- Entity(text="杭州", label="地名或城市", score=0.93, start=17, end=19),
252
- Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=22, end=24),
253
- Entity(text="天猫", label="产品或品牌名称", score=0.92, start=25, end=27),
254
- Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=28, end=31),
255
  )
256
  resp = c.post("/api/v1/extract",
257
  json={"text": "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。",
258
- "labels": ["地名或城市", "产品或品牌名称"]})
259
- assert resp.status_code == 200
260
  texts = {e["text"] for e in resp.json()["entities"]}
261
  assert {"杭州", "淘宝", "天猫", "支付宝"} <= texts
262
 
263
 
 
 
 
 
 
 
 
 
 
 
 
264
  # ── Arabic ────────────────────────────────────────────────────────────────────
265
 
266
  def test_arabic_person_location(client):
267
- """阿拉伯语:识别人名与地名。"""
268
  c, mock_ner = client
269
  mock_ner.extract.return_value = _ents(
270
  Entity(text="محمد بن سلمان", label="full name of a person", score=0.82, start=12, end=26),
271
  Entity(text="المملكة العربية السعودية", label="geographical location", score=0.85, start=44, end=68),
272
  )
273
  resp = c.post("/api/v1/extract",
274
- json={"text": "أعلن الرئيس محمد بن سلمان عن مشروع نيوم في المملكة العربية السعودية.",
275
  "labels": ["full name of a person", "geographical location"],
276
  "language": "ar"})
277
- assert resp.status_code == 200
278
  texts = {e["text"] for e in resp.json()["entities"]}
279
  assert "محمد بن سلمان" in texts
280
  assert "المملكة العربية السعودية" in texts
@@ -295,29 +329,12 @@ def test_mixed_entities_both_scripts(client):
295
  "labels": ["full name of a person", "人名或姓名",
296
  "company or organization name", "公司或组织机构名称",
297
  "geographical location", "地名或城市",
298
- "product or technology name"]})
299
- assert resp.status_code == 200
300
  texts = {e["text"] for e in resp.json()["entities"]}
301
  assert {"张伟", "Google", "北京", "Android"} <= texts
302
 
303
 
304
- def test_mixed_labels_chinese_and_english(client):
305
- c, mock_ner = client
306
- mock_ner.extract.return_value = _ents(
307
- Entity(text="李明", label="人名或姓名", score=0.94, start=0, end=2),
308
- Entity(text="Tesla", label="人名或姓名", score=0.96, start=10, end=15),
309
- Entity(text="上海", label="地名或城市", score=0.92, start=22, end=24),
310
- )
311
- resp = c.post("/api/v1/extract",
312
- json={"text": "李明在上海加入了 Tesla。",
313
- "labels": ["人名或姓名", "full name of a person",
314
- "地名或城市", "geographical location",
315
- "company or organization name"]})
316
- assert resp.status_code == 200
317
- texts = {e["text"] for e in resp.json()["entities"]}
318
- assert {"李明", "Tesla", "上海"} <= texts
319
-
320
-
321
  def test_mixed_no_cross_language_contamination(client):
322
  c, mock_ner = client
323
  mock_ner.extract.return_value = _ents(
@@ -327,7 +344,108 @@ def test_mixed_no_cross_language_contamination(client):
327
  resp = c.post("/api/v1/extract",
328
  json={"text": "他在 OpenAI 工作,同事王芳也在同一部门。",
329
  "labels": ["person", "organization"]})
330
- assert resp.status_code == 200
331
  entities = resp.json()["entities"]
332
  assert any(e["text"] == "OpenAI" and e["label"] == "organization" for e in entities)
333
  assert any(e["text"] == "王芳" and e["label"] == "person" for e in entities)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Unit tests — no real model loaded (GLiNER/torch stubbed via conftest.py).
3
  Covers:
4
+ - API contract (health, validation, threshold, language field)
5
+ - Dual-model routing (ZH BERT backend, EN/AR/mixed GLiNER backend)
6
+ - Optional labels + bilingual auto-expansion
7
+ - English / Chinese / Arabic / mixed-language scenarios
8
+ - labels_used echo in response
9
  """
10
+ from unittest.mock import MagicMock, patch, PropertyMock
11
 
12
  import pytest
13
  from fastapi.testclient import TestClient
14
 
 
15
  from app.main import app
16
  from app.models import Entity
17
+ from app.labels import DEFAULT_LABELS, expand_bilingual, labels_to_bert_types
18
+ from app.ner import detect_language
19
+
20
+
21
+ # ── Helpers ───────────────────────────────────────────────────────────────────
22
+
23
+ def _ents(*args: Entity) -> tuple[list[Entity], list[str]]:
24
+ """Wrap entities in (entities, labels_used) tuple expected by NERService."""
25
+ return list(args), [e.label for e in args]
26
 
27
 
28
  # ── Fixture ───────────────────────────────────────────────────────────────────
29
 
30
  @pytest.fixture()
31
  def client():
32
+ """
33
+ Patch NERService so no model is actually loaded.
34
+ mock_ner.extract() returns ([], []) by default.
35
+ """
36
  mock_ner = MagicMock()
 
37
  mock_ner.extract.return_value = ([], [])
38
  with pytest.MonkeyPatch().context() as mp:
39
  mp.setattr("app.main.NERService", lambda *_: mock_ner)
 
41
  yield c, mock_ner
42
 
43
 
 
 
 
 
 
 
 
44
  # ── System / API contract ─────────────────────────────────────────────────────
45
 
46
  def test_health(client):
47
  c, _ = client
48
+ assert c.get("/api/v1/health").json() == {"status": "ok"}
 
 
49
 
50
 
51
+ def test_extract_empty_text_returns_empty(client):
52
+ c, _ = client
53
  resp = c.post("/api/v1/extract", json={"text": "", "labels": ["person"]})
54
  assert resp.status_code == 200
55
  assert resp.json()["entities"] == []
56
 
57
 
58
+ def test_extract_labels_optional(client):
59
+ """labels 字段完全传应正常返回 200。"""
60
  c, mock_ner = client
61
  mock_ner.extract.return_value = ([], DEFAULT_LABELS)
62
+ resp = c.post("/api/v1/extract", json={"text": "Some text."})
63
  assert resp.status_code == 200
64
+ assert len(resp.json()["labels_used"]) > 0
 
 
 
65
 
66
 
67
+ def test_extract_empty_labels_uses_defaults(client):
68
+ """labels=[] 使用默认双语标签集。"""
69
  c, mock_ner = client
70
  mock_ner.extract.return_value = ([], DEFAULT_LABELS)
71
+ resp = c.post("/api/v1/extract", json={"text": "Hello world.", "labels": []})
72
  assert resp.status_code == 200
73
+ assert resp.json()["labels_used"] == DEFAULT_LABELS
74
 
75
 
76
  def test_extract_threshold_forwarded(client):
77
  c, mock_ner = client
78
  c.post("/api/v1/extract",
79
+ json={"text": "Hello", "labels": ["person"], "threshold": 0.8})
80
+ mock_ner.extract.assert_called_once_with("Hello", ["person"], 0.8, language="auto")
81
 
82
 
83
  def test_extract_invalid_threshold(client):
84
  c, _ = client
85
+ assert c.post("/api/v1/extract",
86
+ json={"text": "x", "threshold": 1.5}).status_code == 422
 
87
 
88
 
89
  def test_extract_language_field_forwarded(client):
 
94
 
95
 
96
  def test_extract_invalid_language(client):
 
97
  c, _ = client
98
+ assert c.post("/api/v1/extract",
99
+ json={"text": "x", "language": "jp"}).status_code == 422
100
+
101
+
102
+ def test_labels_used_echoed(client):
103
+ c, mock_ner = client
104
+ used = ["人名或姓名", "地名或城市"]
105
+ mock_ner.extract.return_value = ([], used)
106
+ resp = c.post("/api/v1/extract", json={"text": "马云在杭州。", "labels": ["人名或姓名"]})
107
+ assert resp.json()["labels_used"] == used
108
 
109
 
110
  def test_entity_response_fields(client):
 
111
  c, mock_ner = client
112
  mock_ner.extract.return_value = _ents(
113
  Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
114
  )
115
  resp = c.post("/api/v1/extract",
116
  json={"text": "Apple is great.", "labels": ["organization"]})
 
117
  e = resp.json()["entities"][0]
118
+ assert {"text", "label", "score", "start", "end"} <= e.keys()
119
  assert 0.0 <= e["score"] <= 1.0
120
  assert e["start"] < e["end"]
121
 
122
 
123
+ # ── Language detection ────────────────────────────────────────────────────────
124
+
125
+ def test_detect_language_english():
126
+ assert detect_language("Elon Musk founded SpaceX in California.") == "en"
127
+
128
+
129
+ def test_detect_language_chinese():
130
+ assert detect_language("阿里巴巴集团创始人马云于杭州卸任。") == "zh"
131
+
132
+
133
+ def test_detect_language_arabic():
134
+ assert detect_language("أعلن الرئيس محمد بن سلمان عن مشروع نيوم.") == "ar"
135
+
136
+
137
+ def test_detect_language_mixed():
138
+ assert detect_language("张伟加入了 Google 北京研发中心,负责 Android 优化。") == "mixed"
139
+
140
+
141
+ def test_detect_language_empty():
142
+ assert detect_language("") == "en"
143
 
144
 
145
+ # ── Bilingual label expansion ─────────────────────────────────────────────────
146
 
147
  def test_expand_bilingual_adds_english_for_chinese():
148
  result = expand_bilingual(["人名或姓名"])
 
149
  assert "full name of a person" in result
150
 
151
 
152
  def test_expand_bilingual_adds_chinese_for_english():
153
  result = expand_bilingual(["company or organization name"])
 
154
  assert "公司或组织机构名称" in result
155
 
156
 
157
  def test_expand_bilingual_no_duplicate():
158
+ result = expand_bilingual(["人名或姓名", "full name of a person"])
 
159
  assert result.count("人名或姓名") == 1
160
  assert result.count("full name of a person") == 1
161
 
162
 
163
  def test_expand_bilingual_custom_label_preserved():
 
164
  result = expand_bilingual(["my custom label"])
165
  assert "my custom label" in result
166
 
167
 
168
+ def test_default_labels_bilingual():
 
 
169
  has_en = any(all(ord(c) < 128 for c in lbl) for lbl in DEFAULT_LABELS)
170
  has_zh = any(any('一' <= c <= '鿿' for c in lbl) for lbl in DEFAULT_LABELS)
171
  assert has_en and has_zh
172
 
173
 
174
+ # ── BERT label mapping ────────────────────────────────────────────────────────
175
+
176
+ def test_labels_to_bert_types_chinese_label():
177
+ types = labels_to_bert_types(["人名或姓名"])
178
+ assert "PER" in types
179
+
180
+
181
+ def test_labels_to_bert_types_english_label():
182
+ types = labels_to_bert_types(["geographical location"])
183
+ assert "LOC" in types
184
+
185
+
186
+ def test_labels_to_bert_types_empty_returns_none():
187
+ assert labels_to_bert_types([]) is None
188
+
189
+
190
+ def test_labels_to_bert_types_unmapped_returns_none():
191
+ # 无法映射的标签 → 不过滤(返回 None)
192
+ assert labels_to_bert_types(["some unknown label"]) is None
193
+
194
+
195
  # ── English ───────────────────────────────────────────────────────────────────
196
 
197
  def test_english_person_org(client):
198
  c, mock_ner = client
199
  mock_ner.extract.return_value = _ents(
200
+ Entity(text="Elon Musk", label="person", score=0.98, start=0, end=9),
201
+ Entity(text="Tesla", label="organization", score=0.96, start=18, end=23),
202
+ Entity(text="SpaceX", label="organization", score=0.97, start=28, end=34),
203
  )
204
  resp = c.post("/api/v1/extract",
205
  json={"text": "Elon Musk is the CEO of Tesla and founded SpaceX.",
206
+ "labels": ["full name of a person", "company or organization name"],
207
+ "language": "en"})
208
  assert resp.status_code == 200
209
  texts = {e["text"] for e in resp.json()["entities"]}
210
  assert {"Elon Musk", "Tesla", "SpaceX"} <= texts
 
219
  )
220
  resp = c.post("/api/v1/extract",
221
  json={"text": "The summit was held in Paris in 2024, in France.",
222
+ "labels": ["geographical location", "date or year"],
223
+ "language": "en"})
224
  texts = {e["text"] for e in resp.json()["entities"]}
225
  assert {"Paris", "France", "2024"} <= texts
226
 
227
 
228
+ def test_english_threshold_forwarded(client):
229
  c, mock_ner = client
230
+ c.post("/api/v1/extract",
231
+ json={"text": "NASA explored the Moon.",
232
+ "labels": ["company or organization name"],
233
+ "threshold": 0.8, "language": "en"})
 
 
 
 
234
  mock_ner.extract.assert_called_once_with(
235
+ "NASA explored the Moon.", ["company or organization name"], 0.8, language="en"
236
  )
237
 
238
 
239
+ # ── Chinese (BERT backend) ────────────────────────────────────────────────────
240
 
241
  def test_chinese_person_org(client):
242
  c, mock_ner = client
 
249
  json={"text": "阿里巴巴集团创始人马云卸任,由张勇接任。",
250
  "labels": ["人名或姓名", "公司或组织机构名称"],
251
  "language": "zh"})
 
252
  texts = {e["text"] for e in resp.json()["entities"]}
253
  assert {"马云", "张勇", "阿里巴巴"} <= texts
254
 
255
 
256
  def test_chinese_entity_boundary(client):
257
+ """BERT NER 应精确截断实体边界不含动词。"""
258
  c, mock_ner = client
259
  mock_ner.extract.return_value = _ents(
260
  Entity(text="尤氏", label="人名或姓名", score=0.82, start=0, end=2),
 
262
  )
263
  resp = c.post("/api/v1/extract",
264
  json={"text": "尤氏来请,王熙凤笑道:'你来了。'",
265
+ "labels": ["人名或姓名"], "language": "zh"})
 
266
  texts = {e["text"] for e in resp.json()["entities"]}
267
  assert "尤氏" in texts
268
  assert "王熙凤" in texts
 
273
  def test_chinese_location_product(client):
274
  c, mock_ner = client
275
  mock_ner.extract.return_value = _ents(
276
+ Entity(text="杭州", label="地名或城市", score=0.93, start=9, end=11),
277
+ Entity(text="淘宝", label="产品或品牌名称", score=0.91, start=14, end=16),
278
+ Entity(text="天猫", label="产品或品牌名称", score=0.92, start=17, end=19),
279
+ Entity(text="支付宝", label="产品或品牌名称", score=0.90, start=20, end=23),
280
  )
281
  resp = c.post("/api/v1/extract",
282
  json={"text": "阿里巴巴总部位于杭州,旗下有淘宝、天猫、支付宝。",
283
+ "labels": ["地名或城市", "产品或品牌名称"],
284
+ "language": "zh"})
285
  texts = {e["text"] for e in resp.json()["entities"]}
286
  assert {"杭州", "淘宝", "天猫", "支付宝"} <= texts
287
 
288
 
289
+ def test_chinese_auto_routes_to_zh(client):
290
+ """auto 检测到中文应路由到 ZH 模型(language 透传为 'auto',内部检测为 zh)。"""
291
+ c, mock_ner = client
292
+ c.post("/api/v1/extract",
293
+ json={"text": "马云创立了阿里巴巴。"})
294
+ # NERService.extract 被调用时 language='auto',路由逻辑在 ner.py 内部处理
295
+ mock_ner.extract.assert_called_once()
296
+ call_kwargs = mock_ner.extract.call_args
297
+ assert call_kwargs[1].get("language") == "auto" or call_kwargs[0][3] == "auto"
298
+
299
+
300
  # ── Arabic ────────────────────────────────────────────────────────────────────
301
 
302
  def test_arabic_person_location(client):
 
303
  c, mock_ner = client
304
  mock_ner.extract.return_value = _ents(
305
  Entity(text="محمد بن سلمان", label="full name of a person", score=0.82, start=12, end=26),
306
  Entity(text="المملكة العربية السعودية", label="geographical location", score=0.85, start=44, end=68),
307
  )
308
  resp = c.post("/api/v1/extract",
309
+ json={"text": "أعلن الرئيس محمد بن سلمان مشروع نيوم في المملكة العربية السعودية.",
310
  "labels": ["full name of a person", "geographical location"],
311
  "language": "ar"})
 
312
  texts = {e["text"] for e in resp.json()["entities"]}
313
  assert "محمد بن سلمان" in texts
314
  assert "المملكة العربية السعودية" in texts
 
329
  "labels": ["full name of a person", "人名或姓名",
330
  "company or organization name", "公司或组织机构名称",
331
  "geographical location", "地名或城市",
332
+ "product or technology name"],
333
+ "language": "mixed"})
334
  texts = {e["text"] for e in resp.json()["entities"]}
335
  assert {"张伟", "Google", "北京", "Android"} <= texts
336
 
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def test_mixed_no_cross_language_contamination(client):
339
  c, mock_ner = client
340
  mock_ner.extract.return_value = _ents(
 
344
  resp = c.post("/api/v1/extract",
345
  json={"text": "他在 OpenAI 工作,同事王芳也在同一部门。",
346
  "labels": ["person", "organization"]})
 
347
  entities = resp.json()["entities"]
348
  assert any(e["text"] == "OpenAI" and e["label"] == "organization" for e in entities)
349
  assert any(e["text"] == "王芳" and e["label"] == "person" for e in entities)
350
+
351
+
352
+ # ── Fallback & merge (NERService unit tests, no HTTP) ────────────────────────
353
+
354
+ def test_fallback_zh_empty_uses_en():
355
+ """ZH 主模型返回空时,应使用 GLiNER 兜底。"""
356
+ from app.ner import NERService
357
+
358
+ svc = NERService.__new__(NERService)
359
+ svc._en_lock = __import__("threading").Lock()
360
+ svc._zh_lock = __import__("threading").Lock()
361
+
362
+ # ZH backend: returns nothing
363
+ zh_mock = MagicMock()
364
+ zh_mock.predict.return_value = ([], [])
365
+ # EN fallback: returns one entity
366
+ en_mock = MagicMock()
367
+ en_mock.predict.return_value = _ents(
368
+ Entity(text="马云", label="person", score=0.75, start=0, end=2)
369
+ )
370
+ svc._zh_backend = zh_mock
371
+ svc._en_backend = en_mock
372
+
373
+ entities, _ = svc.extract("马云", [], 0.4, language="zh")
374
+ assert any(e.text == "马云" for e in entities)
375
+ zh_mock.predict.assert_called_once()
376
+ en_mock.predict.assert_called_once() # 兜底被调用
377
+
378
+
379
+ def test_fallback_zh_has_results_no_en_called():
380
+ """ZH 主模型有结果时,不应调用 GLiNER 兜底。"""
381
+ from app.ner import NERService
382
+
383
+ svc = NERService.__new__(NERService)
384
+ svc._en_lock = __import__("threading").Lock()
385
+ svc._zh_lock = __import__("threading").Lock()
386
+
387
+ zh_mock = MagicMock()
388
+ zh_mock.predict.return_value = _ents(
389
+ Entity(text="马云", label="person", score=0.92, start=0, end=2)
390
+ )
391
+ en_mock = MagicMock()
392
+ svc._zh_backend = zh_mock
393
+ svc._en_backend = en_mock
394
+
395
+ svc.extract("马云", [], 0.4, language="zh")
396
+ en_mock.predict.assert_not_called() # 不应调用兜底
397
+
398
+
399
+ def test_mixed_runs_both_models_and_merges():
400
+ """Mixed 语言应同时运行两个模型并合并结果。"""
401
+ from app.ner import NERService
402
+
403
+ svc = NERService.__new__(NERService)
404
+ svc._en_lock = __import__("threading").Lock()
405
+ svc._zh_lock = __import__("threading").Lock()
406
+
407
+ en_mock = MagicMock()
408
+ en_mock.predict.return_value = _ents(
409
+ Entity(text="Google", label="organization", score=0.95, start=5, end=11)
410
+ )
411
+ zh_mock = MagicMock()
412
+ zh_mock.predict.return_value = _ents(
413
+ Entity(text="张伟", label="person", score=0.91, start=0, end=2)
414
+ )
415
+ svc._en_backend = en_mock
416
+ svc._zh_backend = zh_mock
417
+
418
+ entities, _ = svc.extract("张伟加入 Google。", [], 0.4, language="mixed")
419
+ texts = {e.text for e in entities}
420
+ assert "Google" in texts
421
+ assert "张伟" in texts
422
+ en_mock.predict.assert_called_once()
423
+ zh_mock.predict.assert_called_once()
424
+
425
+
426
+ def test_mixed_deduplicates_overlapping_spans():
427
+ """两个模型对同一 span 都命中时,只保留得分最高的。"""
428
+ from app.ner import NERService
429
+
430
+ svc = NERService.__new__(NERService)
431
+ svc._en_lock = __import__("threading").Lock()
432
+ svc._zh_lock = __import__("threading").Lock()
433
+
434
+ en_mock = MagicMock()
435
+ en_mock.predict.return_value = (
436
+ [Entity(text="张伟", label="person", score=0.70, start=0, end=2)],
437
+ ["person"],
438
+ )
439
+ zh_mock = MagicMock()
440
+ zh_mock.predict.return_value = (
441
+ [Entity(text="张伟", label="人名或姓名", score=0.92, start=0, end=2)],
442
+ ["人名或姓名"],
443
+ )
444
+ svc._en_backend = en_mock
445
+ svc._zh_backend = zh_mock
446
+
447
+ entities, _ = svc.extract("张伟", [], 0.4, language="mixed")
448
+ # 去重后只有 1 个 "张伟",且是得分更高的那条
449
+ zhang_wei = [e for e in entities if e.text == "张伟"]
450
+ assert len(zhang_wei) == 1
451
+ assert zhang_wei[0].score == 0.92