"""
NER service layer: dual-model routing + fallback merging
=========================================================
Language detection (two layers):
  1. Unicode script ratio: fast; suited to languages with a distinctive
     script such as Chinese / Arabic
  2. langdetect as a fallback: covers pure-English and borderline text

Sufficiency check (replaces the crude ``== 0`` test):
  expected_min = max(length_floor, label_floor)
  length_floor: len(text) < 30 -> 1, < 100 -> 2, < 300 -> 3, >= 300 -> 4
  label_floor : ceil(len(labels) / 3); 1 when no labels are given
  primary-model entity count < expected_min -> the fallback is triggered
  Callers may pass min_entities in the request to override the heuristic

Fallback merging (key point: additive, not a replacement):
  1. The primary model runs first and its results are kept
  2. If they are insufficient, the fallback model runs as well
  3. The two result sets are merged -> de-duplicated by (start, end); for
     an identical span the highest-scoring entity is kept

Routing:
  +----------+---------------+----------------+
  | language | primary model | fallback model |
  +----------+---------------+----------------+
  | zh       | BERT-Chinese  | GLiNER         |
  | en / ar  | GLiNER        | BERT-Chinese   |
  | mixed    | both models are run, then merged |
  | auto     | detect the language, then route  |
  +----------+---------------+----------------+
"""
|
|
| import threading |
| import unicodedata |
| from abc import ABC, abstractmethod |
|
|
| from gliner import GLiNER |
|
|
| from app.labels import ( |
| DEFAULT_LABELS, |
| BERT_TYPE_TO_LABEL, |
| expand_bilingual, |
| labels_to_bert_types, |
| ) |
| from app.models import Entity |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _unicode_script_ratio(text: str) -> str:
    """Layer 1: classify ``text`` by the share of its letters in each script.

    Only characters whose Unicode general category starts with ``"L"`` are
    counted. Every letter is bucketed as CJK, Arabic, or "other" (the
    remainder, treated as Latin by the thresholds below).

    Returns:
        ``"mixed"`` when CJK or Arabic letters make up at least 8% and the
        other letters at least 10%; otherwise ``"zh"`` for at least 20% CJK,
        ``"ar"`` for at least 20% Arabic, and ``"en"`` for everything else,
        including text without any letters.
    """
    cjk_ranges = (
        (0x4E00, 0x9FFF),    # CJK Unified Ideographs
        (0x3400, 0x4DBF),    # CJK Extension A
        (0xF900, 0xFAFF),    # CJK Compatibility Ideographs
        (0x20000, 0x2A6DF),  # CJK Extension B
    )
    arabic_ranges = (
        (0x0600, 0x06FF),    # Arabic
        (0x0750, 0x077F),    # Arabic Supplement
    )

    total = han = arab = 0
    letter_code_points = (
        ord(char) for char in text
        if unicodedata.category(char).startswith("L")
    )
    for code_point in letter_code_points:
        total += 1
        if any(low <= code_point <= high for low, high in cjk_ranges):
            han += 1
        elif any(low <= code_point <= high for low, high in arabic_ranges):
            arab += 1

    if total == 0:
        return "en"

    han_share = han / total
    arab_share = arab / total
    other_share = (total - han - arab) / total

    # The "mixed" test comes first: a noticeable amount of a second script
    # wins over a single dominant script.
    if other_share >= 0.10 and (han_share >= 0.08 or arab_share >= 0.08):
        return "mixed"
    if han_share >= 0.20:
        return "zh"
    if arab_share >= 0.20:
        return "ar"
    return "en"
|
|
|
|
def detect_language(text: str) -> str:
    """Detect the language of ``text`` in two layers.

    Layer 1 is the Unicode script ratio. Any answer other than ``"en"`` is
    returned immediately. When layer 1 says ``"en"``, langdetect is consulted
    as a second opinion so that text misjudged as English can still be routed
    to ``"zh"`` or ``"ar"``.

    Returns:
        One of ``"zh"``, ``"ar"``, ``"mixed"`` or ``"en"``. Empty text yields
        ``"en"``.
    """
    if not text:
        return "en"

    script_guess = _unicode_script_ratio(text)
    if script_guess != "en":
        return script_guess

    # Best-effort second opinion: any failure (langdetect missing, detection
    # error, ...) is swallowed and falls through to "en".
    try:
        from langdetect import DetectorFactory, detect

        # Fixed seed, presumably to make langdetect's answer reproducible.
        DetectorFactory.seed = 0
        code = detect(text)
        # Japanese and Korean are routed to the "zh" path as well.
        if code in ("ja", "ko") or code.startswith("zh"):
            return "zh"
        if code == "ar":
            return "ar"
    except Exception:
        pass

    return "en"
|
|
|
|
| |
|
|
def _deduplicate(entities: list[Entity]) -> list[Entity]:
    """Collapse entities that share the same ``(start, end)`` span.

    Bilingual labels or merging two models can yield several results for one
    span. Only the highest-scoring one is kept (the earliest wins a tie), and
    the survivors are returned ordered by start offset.
    """
    winners: dict[tuple[int, int], Entity] = {}
    for candidate in entities:
        span = (candidate.start, candidate.end)
        incumbent = winners.get(span)
        if incumbent is None or candidate.score > incumbent.score:
            winners[span] = candidate

    ordered = list(winners.values())
    ordered.sort(key=lambda entity: entity.start)
    return ordered
|
|
|
|
| |
|
|
class _Backend(ABC):
    """Abstract interface shared by the NER backends."""

    @abstractmethod
    def predict(
        self,
        text: str,
        labels: list[str],
        threshold: float,
    ) -> tuple[list[Entity], list[str]]:
        """Run NER on ``text`` and return ``(entities, labels_used)``."""
|
|
|
|
| |
|
|
class GLiNERBackend(_Backend):
    """Zero-shot NER backend built on a GLiNER model.

    Notes carried over from the original documentation:
      * intended model: urchade/gliner_multi-v2.1
      * supports English, Arabic and mixed-language text
      * caller labels are expanded into bilingual labels to improve recall
    """

    def __init__(self, model_name: str, cache_dir: str) -> None:
        """Load ``model_name`` through GLiNER, caching it in ``cache_dir``."""
        self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)

    def predict(
        self,
        text: str,
        labels: list[str],
        threshold: float,
    ) -> tuple[list[Entity], list[str]]:
        """Extract entities from ``text`` and return ``(entities, labels_used)``.

        Without caller labels the module's ``DEFAULT_LABELS`` are used;
        otherwise the labels go through ``expand_bilingual`` first. Entities
        are de-duplicated per span before being returned.
        """
        if labels:
            effective_labels = expand_bilingual(labels)
        else:
            effective_labels = DEFAULT_LABELS

        hits = self._model.predict_entities(
            text, effective_labels, threshold=threshold
        )

        found: list[Entity] = []
        for hit in hits:
            found.append(
                Entity(
                    text=hit["text"],
                    label=hit["label"],
                    score=round(hit["score"], 4),
                    start=hit["start"],
                    end=hit["end"],
                )
            )
        return _deduplicate(found), effective_labels
|
|
|
|
| |
|
|
class ChineseBERTBackend(_Backend):
    """Chinese NER backend wrapping a Hugging Face token-classification pipeline.

    Notes carried over from the original documentation (not verifiable from
    this file alone):
      * intended model: shibing624/bert4ner-base-chinese (~400 MB, BERT-base)
      * inference time: ~100 ms
      * fixed entity types PER / LOC / ORG / TIME, mapped to the project's
        labels through ``BERT_TYPE_TO_LABEL``
      * caller-supplied labels filter predictions by type; custom labels that
        cannot be mapped disable the filter, so everything is returned
    """

    def __init__(self, model_name: str, cache_dir: str) -> None:
        """Build the pipeline for ``model_name``, caching it in ``cache_dir``."""
        # Imported here rather than at module level, so transformers is only
        # loaded when this backend is actually constructed.
        from transformers import pipeline as hf_pipeline

        self._pipe = hf_pipeline(
            "token-classification",
            model=model_name,
            model_kwargs={"cache_dir": cache_dir},
            aggregation_strategy="simple",
        )

    @staticmethod
    def _strip_bio_prefix(raw_type: str) -> str:
        """Return ``raw_type`` without a leading ``B-`` / ``I-`` tag prefix.

        ``str.lstrip("BI-")`` must not be used for this: it removes any run of
        the characters ``B``, ``I`` and ``-``, so a type such as ``BRAND``
        would be turned into ``RAND``.
        """
        tag = raw_type.strip()
        if tag.startswith(("B-", "I-")):
            tag = tag[2:]
        return tag

    def predict(
        self,
        text: str,
        labels: list[str],
        threshold: float,
    ) -> tuple[list[Entity], list[str]]:
        """Extract entities from ``text`` and return ``(entities, labels_used)``.

        Predictions scoring below ``threshold`` are dropped. When
        ``labels_to_bert_types(labels)`` yields a collection of types,
        predictions of any other type are dropped too; ``None`` means no
        type filter.

        ``labels_used`` lists the mapped labels of the kept entities in
        first-seen order, or every value of ``BERT_TYPE_TO_LABEL`` when no
        entity was kept.
        """
        raw = self._pipe(text)
        allowed_types = labels_to_bert_types(labels)

        entities: list[Entity] = []
        # A dict is used as an ordered set so labels_used is deterministic.
        labels_seen: dict[str, None] = {}

        for r in raw:
            score = float(r["score"])
            if score < threshold:
                continue

            bert_type = self._strip_bio_prefix(
                r.get("entity_group", r.get("entity", ""))
            )
            if allowed_types is not None and bert_type not in allowed_types:
                continue

            # Unknown types keep their raw name as the label.
            std_label = BERT_TYPE_TO_LABEL.get(bert_type, bert_type)
            labels_seen.setdefault(std_label)

            # The surface text is sliced from the input by offsets.
            entities.append(
                Entity(
                    text=text[r["start"]:r["end"]],
                    label=std_label,
                    score=round(score, 4),
                    start=r["start"],
                    end=r["end"],
                )
            )

        used = list(labels_seen) if labels_seen else list(BERT_TYPE_TO_LABEL.values())
        return entities, used
|
|
|
|
| |
|
|
class NERService:
    """Own both NER backends and dispatch requests by language.

    Routing:
      * ``zh``      - Chinese BERT is the primary model, GLiNER the fallback
      * ``en``/``ar`` (and any other value) - GLiNER is the primary model,
        Chinese BERT the fallback
      * ``mixed``   - both models are run and their results merged

    The fallback model runs only when the primary model returns fewer
    entities than required (see ``extract``). Its results are then added to
    the primary results and de-duplicated per span.

    Both backends are created lazily on first use; ``warmup`` loads them
    eagerly.
    """

    def __init__(self, en_model_name: str, zh_model_name: str, cache_dir: str) -> None:
        """Store the model names and cache directory; no model is loaded yet."""
        self._en_name = en_model_name
        self._zh_name = zh_model_name
        self._cache_dir = cache_dir

        self._en_backend: GLiNERBackend | None = None
        self._zh_backend: ChineseBERTBackend | None = None
        # One lock per backend guards its one-time construction.
        self._en_lock = threading.Lock()
        self._zh_lock = threading.Lock()

    def _en(self) -> GLiNERBackend:
        """Return the GLiNER backend, creating it on first use."""
        if self._en_backend is None:
            with self._en_lock:
                # Re-check under the lock: another thread may have built it.
                if self._en_backend is None:
                    self._en_backend = GLiNERBackend(self._en_name, self._cache_dir)
        return self._en_backend

    def _zh(self) -> ChineseBERTBackend:
        """Return the Chinese BERT backend, creating it on first use."""
        if self._zh_backend is None:
            with self._zh_lock:
                # Re-check under the lock: another thread may have built it.
                if self._zh_backend is None:
                    self._zh_backend = ChineseBERTBackend(self._zh_name, self._cache_dir)
        return self._zh_backend

    @staticmethod
    def _expected_min(text: str, labels: list[str]) -> int:
        """Heuristic minimum number of entities expected from ``text``.

        Returns the larger of:
          * a length floor: 1 for fewer than 30 characters, 2 below 100,
            3 below 300, otherwise 4;
          * a label floor: ``ceil(len(labels) / 3)``, or 1 without labels.
        """
        n = len(text)
        if n < 30:
            length_floor = 1
        elif n < 100:
            length_floor = 2
        elif n < 300:
            length_floor = 3
        else:
            length_floor = 4

        # (len + 2) // 3 is an integer ceil(len / 3).
        label_floor = max(1, (len(labels) + 2) // 3) if labels else 1
        return max(length_floor, label_floor)

    @staticmethod
    def _merge(
        primary: tuple[list[Entity], list[str]],
        fallback: tuple[list[Entity], list[str]],
    ) -> tuple[list[Entity], list[str]]:
        """Additively merge two ``(entities, labels_used)`` results.

        Entities of both results are combined and de-duplicated per
        ``(start, end)`` span (highest score wins), ordered by position. The
        label lists are concatenated with duplicates removed, keeping
        first-seen order.
        """
        p_ents, p_labels = primary
        f_ents, f_labels = fallback
        merged = _deduplicate(p_ents + f_ents)
        used = list(dict.fromkeys(p_labels + f_labels))
        return merged, used

    def extract(
        self,
        text: str,
        labels: list[str],
        threshold: float,
        language: str = "auto",
        min_entities: int | None = None,
    ) -> tuple[list[Entity], list[str]]:
        """Extract entities from ``text`` and return ``(entities, labels_used)``.

        Args:
            text: Input text; empty text returns ``([], labels)`` at once.
            labels: Requested labels, forwarded to the backends.
            threshold: Minimum score, forwarded to the backends.
            language: ``"auto"`` detects the language; ``"zh"`` and
                ``"mixed"`` select their routes; any other value is routed
                like ``"en"`` / ``"ar"``.
            min_entities: Overrides the ``_expected_min`` heuristic when not
                ``None``.

        For the non-mixed routes the fallback model runs only if the primary
        model returned fewer entities than required; both results are then
        merged and de-duplicated per span.
        """
        if not text:
            return [], labels

        lang = detect_language(text) if language == "auto" else language

        if lang == "mixed":
            return self._merge(
                self._en().predict(text, labels, threshold),
                self._zh().predict(text, labels, threshold),
            )

        # Keep the loaders, not the backends: the fallback model must not be
        # loaded unless the fallback is actually triggered.
        if lang == "zh":
            get_primary, get_fallback = self._zh, self._en
        else:
            get_primary, get_fallback = self._en, self._zh

        primary_result = get_primary().predict(text, labels, threshold)

        required = (
            min_entities if min_entities is not None
            else self._expected_min(text, labels)
        )
        if len(primary_result[0]) >= required:
            return primary_result

        fallback_result = get_fallback().predict(text, labels, threshold)
        return self._merge(primary_result, fallback_result)

    def warmup(self) -> None:
        """Load both models up front so the first request need not wait."""
        self._en()
        self._zh()
|
|