"""
Compare the gliner_multi-v2.1 and gliner-multitask-large-v0.5 models on
NER quality for Chinese, English, Arabic and mixed Chinese/English text.

Improvements:
- All test cases uniformly use bilingual labels (Chinese and English side
  by side) to raise the Chinese recognition rate
- Results are written to a UTF-8 Markdown report to avoid garbled output
  on Windows GBK consoles
- Added an Arabic test case
- Added span de-duplication (bilingual labels may produce duplicate spans;
  the highest-scoring one is kept)

Usage:
    python scripts/compare_models.py
Report:
    reports/comparison_report.md
"""
import io
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path

# Environment switches are set before the third-party imports below
# (ordering kept from the original file, presumably so the libraries see
# these values at import time).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

from huggingface_hub import snapshot_download
from gliner import GLiNER
|
|
|
|
| |
|
|
# Benchmark cases. Each entry provides:
#   name     - display name used in console output and the report
#   lang     - language tag ("en", "zh", "ar" or "mixed")
#   text     - the input passed to the model
#   labels   - entity label prompts passed to the model
#   expected - entity surface strings used to compute recall
#   boundary_check (optional) - {"must_not_contain": [...]}: entity texts
#              that indicate a span-boundary error when predicted
# NOTE(review): the non-ASCII literals were recovered from mis-decoded
# (UTF-8 read as ISO-8859-11) text; a few characters whose bytes were lost
# were filled in from context - verify against the upstream file.
CASES = [
    # English
    {
        "name": "EN-01 英文 · 科技人物",
        "lang": "en",
        "text": (
            "Elon Musk, CEO of Tesla and founder of SpaceX, announced a new "
            "Starship launch from Boca Chica, Texas. NASA has partnered with "
            "SpaceX for the Artemis lunar lander mission planned for 2026."
        ),
        "labels": [
            "full name of a person",
            "company or organization name",
            "geographical location",
            "product or technology name",
            "date or year",
        ],
        "expected": ["Elon Musk", "Tesla", "SpaceX", "NASA", "Boca Chica", "Texas", "2026"],
    },
    {
        "name": "EN-02 英文 · 政治新闻",
        "lang": "en",
        "text": (
            "President Biden signed the Inflation Reduction Act in Washington D.C. "
            "on August 16, 2022. The legislation was championed by Senator Chuck Schumer "
            "and was seen as a major win for the Democratic Party."
        ),
        "labels": [
            "full name of a person",
            "company or organization name",
            "geographical location",
            "legislation or policy name",
            "date or year",
            "political party",
        ],
        "expected": ["Biden", "Chuck Schumer", "Washington D.C.", "August 16, 2022", "Democratic Party"],
    },
    # Chinese
    {
        "name": "ZH-01 中文 · 现代商业（双语标签）",
        "lang": "zh",
        "text": (
            "阿里巴巴集团创始人马云于2019年卸任董事局主席，由张勇接任。"
            "总部位于杭州的阿里巴巴旗下拥有淘宝、天猫、支付宝等业务板块。"
        ),
        "labels": [
            "人名或姓名", "full name of a person",
            "公司或组织机构名称", "company or organization name",
            "地名或城市", "geographical location",
            "产品或品牌名称", "product or brand name",
            "日期或年份", "date or year",
        ],
        "expected": ["马云", "张勇", "阿里巴巴", "杭州", "淘宝", "天猫", "支付宝", "2019"],
    },
    {
        "name": "ZH-02 中文 · 古典文学（边界测试）",
        "lang": "zh",
        "text": (
            "尤氏来请，王熙凤笑道：'你来了。'贾母命人摆酒，"
            "宝玉和黛玉在大观园散步，薛宝钗独坐梨香院。"
        ),
        "labels": [
            "人名或姓名", "full name of a person",
            "地名或场所", "place or location name",
        ],
        "expected": ["尤氏", "王熙凤", "贾母", "宝玉", "黛玉", "薛宝钗", "大观园", "梨香院"],
        "boundary_check": {
            "must_not_contain": ["尤氏来请", "王熙凤笑道"],
        },
    },
    {
        "name": "ZH-03 中文 · 医疗场景（双语标签）",
        "lang": "zh",
        "text": (
            "北京协和医院心内科主任王建国教授团队，于2023年成功完成首例"
            "机器人辅助冠状动脉搭桥手术，患者来自山东省济南市。"
        ),
        "labels": [
            "人名或姓名", "full name of a person",
            "医院或机构名称", "hospital or institution name",
            "地名或城市", "geographical location",
            "医疗手术或技术名称", "medical procedure or technology",
            "日期或年份", "date or year",
        ],
        "expected": ["王建国", "北京协和医院", "济南", "山东", "2023"],
    },
    # Arabic
    {
        "name": "AR-01 阿拉伯语 · 新闻",
        "lang": "ar",
        "text": (
            "أعلن الرئيس محمد بن سلمان عن إطلاق مشروع نيوم في المملكة العربية السعودية "
            "عام 2017، وتبلغ تكلفته 500 مليار دولار."
        ),
        "labels": [
            "full name of a person",
            "company or organization name",
            "geographical location",
            "project or initiative name",
            "date or year",
            "monetary amount",
        ],
        "expected": ["محمد بن سلمان", "نيوم", "المملكة العربية السعودية", "2017"],
    },
    # Mixed Chinese / English
    {
        "name": "MIX-01 中英混合 · 职场场景（双语标签）",
        "lang": "mixed",
        "text": (
            "张伟加入了 Google 北京研发中心，负责 Android 系统优化。"
            "他的同事 Sarah Chen 来自 Meta，两人共同参与了 2024 年的 AI Summit。"
        ),
        "labels": [
            "人名或姓名", "full name of a person",
            "公司或组织机构名称", "company or organization name",
            "地名或城市", "geographical location",
            "产品或技术名称", "product or technology name",
            "日期或年份", "date or year",
        ],
        "expected": ["张伟", "Google", "Sarah Chen", "Meta", "Android", "北京", "2024"],
    },
    {
        "name": "MIX-02 中英混合 · 学术场景（双语标签）",
        "lang": "mixed",
        "text": (
            "清华大学计算机系李明教授在 NeurIPS 2023 发表了关于 Transformer 架构的论文，"
            "合作者来自 MIT 和 Stanford University。"
        ),
        "labels": [
            "人名或姓名", "full name of a person",
            "大学或研究机构", "university or research institution",
            "会议或期刊名称", "conference or journal name",
            "技术或模型名称", "technology or model name",
            "日期或年份", "date or year",
        ],
        "expected": ["李明", "清华大学", "NeurIPS 2023", "Transformer", "MIT", "Stanford University"],
    },
]
|
|
# Minimum score passed to the model's predict_entities call.
THRESHOLD = 0.4
# Directory under which model snapshots are stored, one sub-directory each.
CACHE_DIR = "./model_cache"
# Directory that receives the generated Markdown report.
REPORT_DIR = Path("reports")

# (short display name, Hugging Face repo id) for every model to compare.
MODELS = [
    ("gliner_multi-v2.1", "urchade/gliner_multi-v2.1"),
    ("gliner-multitask-large-v0.5", "knowledgator/gliner-multitask-large-v0.5"),
]
|
|
|
|
| |
|
|
def deduplicate(entities: list[dict]) -> list[dict]:
    """Keep only the highest-scoring prediction for each ``(start, end)`` span.

    Bilingual label prompts can make the model report the same span twice
    (once per label); only the entry with the greater ``score`` survives.
    On an exact score tie the entry seen first is kept.

    Args:
        entities: Prediction dicts, each with ``"start"``, ``"end"`` and
            ``"score"`` keys.

    Returns:
        The surviving dicts ordered by ``(start, end)``.
    """
    best: dict[tuple[int, int], dict] = {}
    for entity in entities:
        span = (entity["start"], entity["end"])
        current = best.get(span)
        if current is None or entity["score"] > current["score"]:
            best[span] = entity
    # Include ``end`` in the key so spans sharing a start offset come out in
    # a deterministic order rather than in input order.
    return sorted(best.values(), key=lambda e: (e["start"], e["end"]))
|
|
|
|
| |
|
|
def ensure_local(model_name: str) -> str:
    """Return a local directory holding ``model_name``, downloading if needed.

    The directory is ``CACHE_DIR/<model_name with "/" replaced by "__">``.
    A non-empty directory is treated as an existing cache; otherwise the
    snapshot is fetched with ``snapshot_download``.

    Args:
        model_name: Repository id handed to ``snapshot_download``.

    Returns:
        The local directory path as a string.
    """
    local_dir = Path(CACHE_DIR) / model_name.replace("/", "__")
    # is_dir() rather than exists(): iterdir() raises if the path is a
    # regular file, so such a path is routed to the download branch instead.
    # NOTE(review): "non-empty" does not prove the snapshot is complete; an
    # interrupted download would also be reported as cached.
    if local_dir.is_dir() and any(local_dir.iterdir()):
        print(f" [cached] {local_dir}")
    else:
        print(f" [download] {model_name} -> {local_dir}")
        snapshot_download(repo_id=model_name, local_dir=str(local_dir))
        print(" [done]")
    return str(local_dir)
|
|
|
|
| |
|
|
@dataclass
class CaseResult:
    """Outcome of running one model on one test case."""

    case_name: str  # display name of the test case
    lang: str  # language tag of the test case
    text: str  # input text that was analysed
    expected: list[str]  # entity strings counted towards recall
    entities: list[dict]  # predictions; each dict carries a "text" key
    elapsed_ms: float  # inference time for this case, in milliseconds
    # predicted texts flagged as span-boundary errors
    boundary_violations: list[str] = field(default_factory=list)

    @property
    def found_texts(self) -> set[str]:
        """Distinct surface strings among the predicted entities."""
        return {e["text"] for e in self.entities}

    @property
    def hit_count(self) -> int:
        """Number of expected strings that exactly match a predicted text."""
        # Build the set once; the property would otherwise be recomputed for
        # every expected item.
        found = self.found_texts
        return sum(1 for exp in self.expected if exp in found)

    @property
    def recall(self) -> float:
        """``hit_count / len(expected)``; 1.0 when nothing is expected."""
        if not self.expected:
            return 1.0
        return self.hit_count / len(self.expected)
|
|
|
|
@dataclass
class ModelResult:
    """Aggregated results of one model across all test cases."""

    model_name: str  # short display name of the model
    load_ms: float  # model load time, in milliseconds
    # Per-case results, appended in case order. The element type is quoted
    # so this class does not depend on definition order.
    cases: list["CaseResult"] = field(default_factory=list)

    @property
    def avg_recall(self) -> float:
        """Mean of the per-case recall values; 0.0 when there are no cases."""
        if not self.cases:
            return 0.0
        return sum(c.recall for c in self.cases) / len(self.cases)

    @property
    def avg_infer_ms(self) -> float:
        """Mean per-case inference time in ms; 0.0 when there are no cases."""
        if not self.cases:
            return 0.0
        return sum(c.elapsed_ms for c in self.cases) / len(self.cases)
|
|
|
|
| |
|
|
def run_model(short_name: str, model_name: str) -> ModelResult:
    """Load one model, run it over every entry in ``CASES`` and collect results.

    Args:
        short_name: Display name stored in the returned ``ModelResult``.
        model_name: Repository id resolved to a local path by ``ensure_local``.

    Returns:
        A ``ModelResult`` holding the load time and one ``CaseResult`` per
        entry of ``CASES``, in the same order.
    """
    print(f"\n{'─'*60}")
    print(f"Loading model: {model_name}")
    # Resolve (and possibly download) the snapshot before starting the
    # timer, so load_ms measures model loading rather than network transfer.
    local_path = ensure_local(model_name)
    t0 = time.perf_counter()
    model = GLiNER.from_pretrained(local_path, local_files_only=True)
    load_ms = (time.perf_counter() - t0) * 1000
    print(f"[loaded] {load_ms:.0f}ms")

    result = ModelResult(model_name=short_name, load_ms=load_ms)
    for case in CASES:
        # Only the prediction call is timed; de-duplication is excluded.
        t1 = time.perf_counter()
        raw = model.predict_entities(case["text"], case["labels"], threshold=THRESHOLD)
        elapsed_ms = (time.perf_counter() - t1) * 1000
        entities = deduplicate(raw)

        # An entity whose text exactly equals a forbidden string marks a
        # span-boundary error for this case.
        forbidden = case.get("boundary_check", {}).get("must_not_contain", [])
        violations = [e["text"] for e in entities if e["text"] in forbidden]

        result.cases.append(CaseResult(
            case_name=case["name"],
            lang=case["lang"],
            text=case["text"],
            expected=case.get("expected", []),
            entities=entities,
            elapsed_ms=elapsed_ms,
            boundary_violations=violations,
        ))
        status = "OK" if not violations else f"BOUNDARY ERR: {violations}"
        print(f" {case['name'][:30]:30s} {len(entities):2d} entities {elapsed_ms:.0f}ms {status}")

    return result
|
|
|
|
| |
|
|
def write_report(all_results: list[ModelResult], out_path: Path) -> None:
    """Write the model comparison as a UTF-8 Markdown report.

    The report has four sections: a per-case summary table, model load
    times, detailed per-case entity tables, and conclusions with tuning
    advice. Each model's ``cases`` list is indexed in parallel with
    ``CASES``.

    Args:
        all_results: One ``ModelResult`` per model. May be empty, in which
            case the best/fastest conclusion lines are omitted.
        out_path: Destination file; missing parent directories are created.
    """
    buf = io.StringIO()
    w = buf.write

    w("# NER 模型对比测试报告\n\n")
    # Two trailing spaces are a Markdown hard line break; with a single
    # space these metadata lines would be merged into one paragraph.
    w(f"生成时间：{time.strftime('%Y-%m-%d %H:%M:%S')}  \n")
    w(f"阈值（threshold）：`{THRESHOLD}`  \n\n")

    # Section 1: summary table, two columns (recall, latency) per model.
    w("## 一、汇总对比\n\n")
    header = "| 测试用例 | 语言 |"
    sep = "|---|---|"
    for r in all_results:
        header += f" {r.model_name} 召回 | {r.model_name} 耗时 |"
        sep += "---|---|"
    w(header + "\n")
    w(sep + "\n")

    for i, case in enumerate(CASES):
        row = f"| {case['name']} | `{case['lang']}` |"
        for r in all_results:
            cr = r.cases[i]
            pct = f"{cr.recall*100:.0f}%"
            row += f" {cr.hit_count}/{len(cr.expected)} ({pct}) | {cr.elapsed_ms:.0f}ms |"
        w(row + "\n")

    # Averages row.
    avg_row = "| **平均** | — |"
    for r in all_results:
        avg_row += f" **{r.avg_recall*100:.0f}%** | **{r.avg_infer_ms:.0f}ms** |"
    w(avg_row + "\n\n")

    # Section 2: model load times.
    w("## 二、模型加载时间\n\n")
    w("| 模型 | 加载耗时 |\n|---|---|\n")
    for r in all_results:
        w(f"| {r.model_name} | {r.load_ms/1000:.1f}s |\n")
    w("\n")

    # Section 3: per-case details for every model.
    w("## 三、逐用例详细结果\n\n")
    for i, case in enumerate(CASES):
        expected_md = ", ".join(f"`{e}`" for e in case.get("expected", []))
        w(f"### {case['name']}\n\n")
        w(f"**文本**\n```\n{case['text']}\n```\n\n")
        w(f"**期望实体**：{expected_md}\n\n")

        for r in all_results:
            cr = r.cases[i]
            found = cr.found_texts
            expected_set = set(cr.expected)
            misses = [e for e in cr.expected if e not in found]

            w(f"#### {r.model_name} （{cr.elapsed_ms:.0f}ms，{len(cr.entities)} 个实体，召回 {cr.recall*100:.0f}%）\n\n")

            if cr.entities:
                w("| 文本 | 标签 | 置信度 | 命中期望 |\n|---|---|---|---|\n")
                for e in cr.entities:
                    hit_mark = "✓" if e["text"] in expected_set else ""
                    w(f"| `{e['text']}` | {e['label']} | {e['score']:.2f} | {hit_mark} |\n")
            else:
                w("_未识别到实体_\n")

            if misses:
                missed_md = ", ".join(f"`{m}`" for m in misses)
                w(f"\n**未命中**：{missed_md}\n")
            if cr.boundary_violations:
                w(f"\n> ⚠️ **边界错误**：{cr.boundary_violations}\n")
            w("\n")

    # Section 4: conclusions and recommendations.
    w("## 四、结论与建议\n\n")
    if all_results:  # max()/min() raise ValueError on an empty sequence
        best = max(all_results, key=lambda r: r.avg_recall)
        fast = min(all_results, key=lambda r: r.avg_infer_ms)
        w(f"- **综合召回最高**：`{best.model_name}`（平均召回 {best.avg_recall*100:.0f}%）\n")
        w(f"- **推理最快**：`{fast.model_name}`（平均 {fast.avg_infer_ms:.0f}ms/次）\n\n")
    w("### 优化建议\n\n")
    w("1. **双语标签策略**：对中文或混合文本，同时提供中英文标签描述（如 `\"人名或姓名\"` + `\"full name of a person\"`），可显著提升中文实体召回率。GLiNER 是零样本模型，标签描述越具体、越接近训练语料的表达方式，识别效果越好。\n")
    w("2. **Span 去重**：使用双语标签时同一文本跨度可能被打上两个标签，建议在服务层按 `(start, end)` 去重，保留得分最高的结果（已在 `app/ner.py` 实现）。\n")
    w("3. **阈值调优**：英文建议 `threshold=0.5`，中文建议 `threshold=0.35~0.4`（模型对中文置信度普遍偏低）。\n")
    w("4. **古典/文言文**：两个模型对文言文支持均弱，建议结合规则或专用模型（如 `BERT-CRF` 在古汉语语料上微调）处理此类文本。\n")
    w("5. **阿拉伯语**：`gliner-multitask-large-v0.5` 在多语言上训练，对阿拉伯语有基础支持；`gliner_multi-v2.1` 阿拉伯语效果有限。\n")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the file does not depend on the platform default.
    out_path.write_text(buf.getvalue(), encoding="utf-8")
    print(f"\n[report] {out_path.resolve()}")
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Run every configured model over the full case list.
    all_results: list[ModelResult] = [
        run_model(short_name, model_name) for short_name, model_name in MODELS
    ]

    # Console summary: one row per case, one column group per model.
    print(f"\n{'='*70}")
    print(f"{'Case':<42} " + " ".join(f"{r.model_name[:20]:<20}" for r in all_results))
    print(f"{'─'*70}")
    for i, case in enumerate(CASES):
        row = f"{case['name'][:40]:<42}"
        for r in all_results:
            cr = r.cases[i]
            row += f" {cr.hit_count}/{len(cr.expected)} {cr.recall*100:3.0f}% {cr.elapsed_ms:5.0f}ms "
        print(row)
    print(f"{'─'*70}")
    avg_row = f"{'Average':<42}"
    for r in all_results:
        avg_row += f" avg {r.avg_recall*100:.0f}% / {r.avg_infer_ms:.0f}ms "
    print(avg_row)

    # Full details go to the Markdown report.
    report_path = REPORT_DIR / "comparison_report.md"
    write_report(all_results, report_path)
|
|