nerserver / app /ner.py
Robin
fix(zh): slice entity text from original input to avoid BERT tokenizer spaces
f90826c
"""
NER service layer — dual-model routing + additive fallback merging
──────────────────────────────────────────────────────────────────
Language detection (two layers):
  1. Unicode script ratio: fast; suited to languages with a distinctive
     script such as Chinese or Arabic.
  2. langdetect as a second check: covers pure English and borderline text.

Sufficiency check (replaces a crude "== 0" test):
  expected_min = max(length_floor, label_floor)
    length_floor: len(text) < 30 → 1, < 100 → 2, < 300 → 3, ≥ 300 → 4
    label_floor : ⌈len(labels) / 3⌉, or 1 when no labels are given
  Primary-model entity count < expected_min → the fallback is triggered.
  Callers may pass ``min_entities`` in the request to override the heuristic.

Fallback merge (key point: results are added, not replaced):
  1. The primary model runs first and its results are kept.
  2. If they are insufficient, the fallback model runs as well.
  3. Both result sets are merged and deduplicated by (start, end); for a
     given span the highest-scoring entity is kept.

Routing:
  ┌──────────┬───────────────────────────────┐
  │ language │ primary → fallback            │
  ├──────────┼───────────────────────────────┤
  │ zh       │ BERT-Chinese → GLiNER         │
  │ en / ar  │ GLiNER → BERT-Chinese         │
  │ mixed    │ both models run, then merged  │
  │ auto     │ detect language, then route   │
  └──────────┴───────────────────────────────┘
"""
import threading
import unicodedata
from abc import ABC, abstractmethod
from gliner import GLiNER
from app.labels import (
DEFAULT_LABELS,
BERT_TYPE_TO_LABEL,
expand_bilingual,
labels_to_bert_types,
)
from app.models import Entity
# ── Language detection ────────────────────────────────────────────────
#
# Two-layer strategy:
#   Layer 1: Unicode script ratio
#     · Walks every letter character in the text and measures the CJK /
#       Arabic share.
#     · Pros: no dependencies, very fast. Cons: weak on very short or
#       purely Latin-script text.
#
#   Layer 2: langdetect (only used as a check when Layer 1 returns 'en')
#     · Naive-Bayes model over character n-grams (a Python port of Nakatani
#       Shuyo's language-detection library).
#     · Still error-prone on short text (< 20 characters), so Layer 1 stays
#       authoritative.
#     · If langdetect reports Chinese / Japanese / Korean → return 'zh'.
#     · On failure, silently fall back to the Layer 1 result ('en').
def _unicode_script_ratio(text: str) -> str:
"""Layer-1๏ผšๅŸบไบŽ Unicode ่„šๆœฌๆฏ”ไพ‹็š„่ฏญ่จ€ๅˆ†็ฑปใ€‚"""
cjk = arabic = letters = 0
for ch in text:
if not unicodedata.category(ch).startswith("L"):
continue
letters += 1
cp = ord(ch)
if (0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or
0xF900 <= cp <= 0xFAFF or 0x20000 <= cp <= 0x2A6DF):
cjk += 1
elif 0x0600 <= cp <= 0x06FF or 0x0750 <= cp <= 0x077F:
arabic += 1
if not letters:
return "en"
cjk_r = cjk / letters
ar_r = arabic / letters
latin_r = (letters - cjk - arabic) / letters
# ไธญๆ–‡+ๆ‹‰ไธ้ƒฝๆ˜พ่‘— โ†’ mixed๏ผˆไผ˜ๅ…ˆ็บง้ซ˜ไบŽๅ•็บฏ zh ๅˆคๆ–ญ๏ผ‰
if cjk_r >= 0.08 and latin_r >= 0.10:
return "mixed"
# ้˜ฟๆ‹‰ไผฏ+ๆ‹‰ไธ้ƒฝๆ˜พ่‘— โ†’ mixed
if ar_r >= 0.08 and latin_r >= 0.10:
return "mixed"
# ๅ•่„šๆœฌไธปๅฏผ
if cjk_r >= 0.20:
return "zh"
if ar_r >= 0.20:
return "ar"
return "en"
def detect_language(text: str) -> str:
    """Detect the language family of *text*: 'zh', 'ar', 'mixed' or 'en'.

    Layer 1 (``_unicode_script_ratio``) decides first. Only when it answers
    'en' is ``langdetect`` consulted as a second opinion, so that Chinese
    (or Japanese / Korean, which are routed as 'zh') is not mistaken for
    English. Any langdetect failure, including the package being absent,
    silently yields 'en'.
    """
    if not text:
        return "en"
    script_guess = _unicode_script_ratio(text)
    if script_guess != "en":
        # Layer 1 already identified a non-English script; trust it.
        return script_guess
    # Layer 2: langdetect, imported lazily inside the guarded section.
    try:
        from langdetect import DetectorFactory, detect

        DetectorFactory.seed = 0  # pin langdetect's randomness for stable output
        code = detect(text)  # e.g. 'zh-cn', 'zh-tw', 'ar', 'en', 'ja'
        if code == "ar":
            return "ar"
        if code in ("ja", "ko") or code.startswith("zh"):
            return "zh"
    except Exception:
        return "en"
    return "en"
# ── Span deduplication ────────────────────────────────────────────────
def _deduplicate(entities: list[Entity]) -> list[Entity]:
"""
ๅŒ่ฏญๆ ‡็ญพๆˆ–ๆจกๅž‹ๅˆๅนถๆ—ถๅฏ่ƒฝไบง็”ŸๅŒไธ€ (start, end) ็š„้‡ๅค็ป“ๆžœ๏ผŒ
ไฟ็•™็ฝฎไฟกๅบฆๆœ€้ซ˜็š„้‚ฃๆก๏ผŒๅนถๆŒ‰่ตทๅง‹ไฝ็ฝฎๆŽ’ๅบใ€‚
"""
best: dict[tuple[int, int], Entity] = {}
for e in entities:
key = (e.start, e.end)
if key not in best or e.score > best[key].score:
best[key] = e
return sorted(best.values(), key=lambda x: x.start)
# ── Backend base class ────────────────────────────────────────────────
class _Backend(ABC):
@abstractmethod
def predict(
self, text: str, labels: list[str], threshold: float
) -> tuple[list[Entity], list[str]]:
"""่ฟ”ๅ›ž (entities, labels_used)"""
# ── GLiNER backend (English / Arabic / mixed) ─────────────────────────
class GLiNERBackend(_Backend):
    """Zero-shot NER backend built on GLiNER (e.g. ``urchade/gliner_multi-v2.1``).

    Used for English, Arabic and mixed text. Caller-supplied labels are
    passed through ``expand_bilingual`` (bilingual label expansion, meant to
    raise recall); without labels ``DEFAULT_LABELS`` is used.
    """

    def __init__(self, model_name: str, cache_dir: str) -> None:
        self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)

    def predict(
        self, text: str, labels: list[str], threshold: float
    ) -> tuple[list[Entity], list[str]]:
        """Return ``(entities, labels_used)`` for *text*.

        ``entities`` are deduplicated by span (highest score wins) and sorted
        by start offset; ``labels_used`` is the label list sent to the model.
        """
        if labels:
            query_labels = expand_bilingual(labels)
        else:
            query_labels = DEFAULT_LABELS
        predictions = self._model.predict_entities(
            text, query_labels, threshold=threshold
        )
        found = [
            Entity(
                text=pred["text"],
                label=pred["label"],
                score=round(pred["score"], 4),
                start=pred["start"],
                end=pred["end"],
            )
            for pred in predictions
        ]
        return _deduplicate(found), query_labels
# ── Chinese BERT backend ──────────────────────────────────────────────
class ChineseBERTBackend(_Backend):
    """Dedicated Chinese NER backend using a Hugging Face token-classification pipeline.

    Intended model: shibing624/bert4ner-base-chinese. The original notes give
    its size as about 400 MB (BERT-base) and about 100 ms per inference.

    * The model emits a fixed set of entity types (PER / LOC / ORG / TIME per
      the original notes). Each type is mapped to a standard label through
      ``BERT_TYPE_TO_LABEL``; types missing from that map are kept as-is.
    * When the caller passes labels, ``labels_to_bert_types`` turns them into
      the allowed BERT types. A ``None`` result means no filtering; the
      original notes say this happens for custom labels that cannot be mapped.
    """

    # Tag-scheme prefixes that may precede the entity type in a raw label.
    _TAG_PREFIXES = ("B-", "I-")

    def __init__(self, model_name: str, cache_dir: str) -> None:
        # Deferred import: a top-level import would trigger the
        # torch.__spec__ check during test collection.
        from transformers import pipeline as hf_pipeline

        self._pipe = hf_pipeline(
            "token-classification",
            model=model_name,
            model_kwargs={"cache_dir": cache_dir},
            aggregation_strategy="simple",
        )

    @classmethod
    def _entity_type(cls, result: dict) -> str:
        """Return the bare entity type of one pipeline result.

        Removes one leading ``B-``/``I-`` prefix. This replaces
        ``str.lstrip("BI-")``, which strips any run of the *characters*
        'B', 'I' and '-' and so corrupts types beginning with B or I.
        """
        tag = result.get("entity_group", result.get("entity", "")).strip()
        for prefix in cls._TAG_PREFIXES:
            if tag.startswith(prefix):
                return tag[len(prefix):].strip()
        return tag

    def predict(
        self, text: str, labels: list[str], threshold: float
    ) -> tuple[list[Entity], list[str]]:
        """Run the pipeline on *text* and return ``(entities, labels_used)``.

        Results scoring below *threshold*, or whose type is not allowed by
        *labels*, are dropped. ``labels_used`` lists the standard labels that
        were kept, in first-seen order. If nothing was kept, it lists every
        value of ``BERT_TYPE_TO_LABEL``.
        """
        raw = self._pipe(text)
        allowed_types = labels_to_bert_types(labels)  # None = no filtering
        entities: list[Entity] = []
        # dict used as an insertion-ordered set, so labels_used is stable
        # across runs (list(set(...)) depends on randomized string hashing).
        labels_seen: dict[str, None] = {}
        for r in raw:
            score = float(r["score"])
            if score < threshold:
                continue
            bert_type = self._entity_type(r)
            if allowed_types is not None and bert_type not in allowed_types:
                continue
            std_label = BERT_TYPE_TO_LABEL.get(bert_type, bert_type)
            labels_seen[std_label] = None
            start, end = r["start"], r["end"]
            # The Chinese BERT tokenizer re-joins word pieces with spaces
            # (a two-character name comes back as two space-separated
            # characters), so take the surface form from the original text
            # via the character offsets.
            entities.append(Entity(
                text=text[start:end],
                label=std_label,
                score=round(score, 4),
                start=start,
                end=end,
            ))
        used = list(labels_seen) if labels_seen else list(BERT_TYPE_TO_LABEL.values())
        return entities, used
# ── NER service (routing + fallback) ──────────────────────────────────
class NERService:
    """Language-aware NER front end holding two lazily loaded backends.

    Routing by the detected (or caller-given) language:
      * ``zh``: Chinese BERT is primary, GLiNER is the fallback.
      * ``en`` / ``ar`` (and any other value): GLiNER is primary, BERT is the
        fallback.
      * ``mixed``: both backends run, one after the other, and their results
        are merged.

    Fallback for single-language text is additive. When the primary backend
    returns fewer entities than required (``min_entities`` if given, else the
    ``_expected_min`` heuristic), the fallback backend also runs and both
    result sets are merged and deduplicated by span. The fallback model is
    only loaded when it is actually needed (or via ``warmup``).
    """

    def __init__(self, en_model_name: str, zh_model_name: str, cache_dir: str) -> None:
        self._en_name = en_model_name
        self._zh_name = zh_model_name
        self._cache_dir = cache_dir
        # Backends start unloaded; _en() / _zh() create them on first use.
        self._en_backend: GLiNERBackend | None = None
        self._zh_backend: ChineseBERTBackend | None = None
        # One lock per backend, so loading one never blocks the other.
        self._en_lock = threading.Lock()
        self._zh_lock = threading.Lock()

    # ── Lazy loading ──────────────────────────────────────────────────

    def _en(self) -> GLiNERBackend:
        """Return the GLiNER backend, loading it on first use (double-checked lock)."""
        if self._en_backend is None:
            with self._en_lock:
                # Re-check under the lock: another thread may have loaded it meanwhile.
                if self._en_backend is None:
                    self._en_backend = GLiNERBackend(self._en_name, self._cache_dir)
        return self._en_backend

    def _zh(self) -> ChineseBERTBackend:
        """Return the Chinese BERT backend, loading it on first use (double-checked lock)."""
        if self._zh_backend is None:
            with self._zh_lock:
                # Re-check under the lock: another thread may have loaded it meanwhile.
                if self._zh_backend is None:
                    self._zh_backend = ChineseBERTBackend(self._zh_name, self._cache_dir)
        return self._zh_backend

    # ── Sufficiency check ─────────────────────────────────────────────

    @staticmethod
    def _expected_min(text: str, labels: list[str]) -> int:
        """Heuristic minimum number of entities a primary backend should find.

        Returns ``max(length_floor, label_floor)`` where length_floor is 1 / 2
        / 3 / 4 for texts shorter than 30 / 100 / 300 characters / longer, and
        label_floor is ``ceil(len(labels) / 3)`` (1 when no labels are given).
        """
        n = len(text)
        if n < 30:
            length_floor = 1
        elif n < 100:
            length_floor = 2
        elif n < 300:
            length_floor = 3
        else:
            length_floor = 4
        # ceil(len(labels) / 3); already >= 1 for any non-empty label list.
        label_floor = (len(labels) + 2) // 3 if labels else 1
        return max(length_floor, label_floor)

    # ── Fallback merge ────────────────────────────────────────────────

    @staticmethod
    def _merge(
        primary: tuple[list[Entity], list[str]],
        fallback: tuple[list[Entity], list[str]],
    ) -> tuple[list[Entity], list[str]]:
        """Additively merge two ``(entities, labels_used)`` results.

        Entities are deduplicated by ``(start, end)`` (highest score wins;
        primary entities come first, so they win ties) and sorted by
        position. Label lists are concatenated with duplicates removed,
        preserving order.
        """
        p_ents, p_labels = primary
        f_ents, f_labels = fallback
        merged = _deduplicate(p_ents + f_ents)
        used = list(dict.fromkeys(p_labels + f_labels))  # order-preserving dedupe
        return merged, used

    # ── Main entry point ──────────────────────────────────────────────

    def extract(
        self,
        text: str,
        labels: list[str],
        threshold: float,
        language: str = "auto",
        min_entities: int | None = None,
    ) -> tuple[list[Entity], list[str]]:
        """Extract entities from *text* and return ``(entities, labels_used)``.

        Args:
            text: Input text. Empty text short-circuits to ``([], labels)``.
            labels: Requested entity labels, forwarded to the backends.
            threshold: Minimum score, forwarded to the backends.
            language: ``"auto"`` (detect via ``detect_language``), ``"zh"``,
                ``"en"``, ``"ar"`` or ``"mixed"``; any other value is routed
                like ``"en"``.
            min_entities: Overrides the ``_expected_min`` heuristic for the
                fallback decision.
        """
        if not text:
            return [], labels

        lang = language if language != "auto" else detect_language(text)

        # Mixed text always runs both backends (sequentially) and merges.
        if lang == "mixed":
            return self._merge(
                self._en().predict(text, labels, threshold),
                self._zh().predict(text, labels, threshold),
            )

        # Choose backend *getters* rather than instances, so the fallback
        # model is only loaded if the primary result turns out insufficient.
        if lang == "zh":
            get_primary, get_fallback = self._zh, self._en
        else:  # en / ar
            get_primary, get_fallback = self._en, self._zh

        primary_result = get_primary().predict(text, labels, threshold)

        required = (
            min_entities if min_entities is not None
            else self._expected_min(text, labels)
        )
        if len(primary_result[0]) >= required:
            return primary_result

        # Insufficient: add the fallback's results (never replace).
        fallback_result = get_fallback().predict(text, labels, threshold)
        return self._merge(primary_result, fallback_result)

    def warmup(self) -> None:
        """Load both backends now (e.g. at startup) so the first request does not wait."""
        self._en()
        self._zh()