Robin Claude Sonnet 4.6 commited on
Commit
d470d45
·
0 Parent(s):

feat: GLiNER NER HTTP API

Browse files

- POST /extract 接收文本和实体类型,返回抽取结果
- 模型启动时加载一次,多次请求复用
- 通过 HF_ENDPOINT 使用国内镜像,MODEL_CACHE_DIR 本地缓存
- 默认端口 4000,支持环境变量配置
- 包含单元测试(mock)和集成测试(真实 API)
- start.bat 从 conda ai 环境启动

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

.env.example ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ MODEL_NAME=urchade/gliner_medium-v2.1
2
+ HOST=0.0.0.0
3
+ PORT=4000
4
+ HF_ENDPOINT=https://hf-mirror.com
5
+ MODEL_CACHE_DIR=./model_cache
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ .env
4
+ model_cache/
5
+ .pytest_cache/
6
+ *.egg-info/
7
+ dist/
8
+ build/
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ MODEL_NAME: str = os.getenv("MODEL_NAME", "urchade/gliner_medium-v2.1")
4
+ HOST: str = os.getenv("HOST", "0.0.0.0")
5
+ PORT: int = int(os.getenv("PORT", "4000"))
6
+ MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
7
+
8
+ # Must be set before huggingface_hub / transformers are imported
9
+ _hf_endpoint = os.getenv("HF_ENDPOINT", "https://hf-mirror.com")
10
+ os.environ["HF_ENDPOINT"] = _hf_endpoint
app/main.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+
3
+ from fastapi import FastAPI
4
+
5
+ from app.config import MODEL_CACHE_DIR, MODEL_NAME
6
+ from app.models import ExtractRequest, ExtractResponse
7
+ from app.ner import NERService
8
+
9
+ ner_service: NERService | None = None
10
+
11
+
12
+ @asynccontextmanager
13
+ async def lifespan(app: FastAPI):
14
+ global ner_service
15
+ ner_service = NERService(MODEL_NAME, MODEL_CACHE_DIR)
16
+ yield
17
+ ner_service = None
18
+
19
+
20
+ app = FastAPI(title="NER API", lifespan=lifespan)
21
+
22
+
23
+ @app.get("/health")
24
+ def health():
25
+ return {"status": "ok"}
26
+
27
+
28
+ @app.post("/extract", response_model=ExtractResponse)
29
+ def extract(req: ExtractRequest):
30
+ entities = ner_service.extract(req.text, req.labels, req.threshold)
31
+ return ExtractResponse(entities=entities)
app/models.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
+ class ExtractRequest(BaseModel):
5
+ text: str
6
+ labels: list[str]
7
+ threshold: float = Field(default=0.5, ge=0.0, le=1.0)
8
+
9
+
10
+ class Entity(BaseModel):
11
+ text: str
12
+ label: str
13
+ score: float
14
+ start: int
15
+ end: int
16
+
17
+
18
+ class ExtractResponse(BaseModel):
19
+ entities: list[Entity]
app/ner.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gliner import GLiNER
2
+ from app.models import Entity
3
+
4
+
5
+ class NERService:
6
+ def __init__(self, model_name: str, cache_dir: str) -> None:
7
+ self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
8
+
9
+ def extract(self, text: str, labels: list[str], threshold: float) -> list[Entity]:
10
+ if not text or not labels:
11
+ return []
12
+ raw = self._model.predict_entities(text, labels, threshold=threshold)
13
+ return [
14
+ Entity(
15
+ text=e["text"],
16
+ label=e["label"],
17
+ score=round(e["score"], 4),
18
+ start=e["start"],
19
+ end=e["end"],
20
+ )
21
+ for e in raw
22
+ ]
docs/requirements/PRD-ner-api.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PRD-ner-api
2
+
3
+ 状态:已确认
4
+ 创建日期:2026-04-28
5
+
6
+ ## 1. 功能目标
7
+
8
+ 提供一个 HTTP API 服务,接收文本和实体类型列表,返回从文本中抽取到的命名实体。
9
+
10
+ ## 2. 用户故事
11
+
12
+ 作为 API 调用方,我可以传入一段文本和期望识别的实体类型(如 "person"、"organization"、"location"),得到每个实体的文字、类型和在原文中的位置。
13
+
14
+ ## 3. 验收标准
15
+
16
+ - [ ] POST /extract 接口可正常调用
17
+ - [ ] 支持传入任意实体类型列表(zero-shot)
18
+ - [ ] 返回实体文字、实体类型、置信度分数
19
+ - [ ] 模型加载一次,多次请求复用
20
+ - [ ] 支持通过环境变量配置模型名称和服务端口
21
+
22
+ ## 4. 约束
23
+
24
+ - 基于 GLiNER 库实现(Python)
25
+ - 尽量简单,不引入数据库、认证等复杂机制
26
+ - 默认使用 `urchade/gliner_medium-v2.1` 模型
docs/technical/TDD-ner-api.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TDD-ner-api
2
+
3
+ 状态:已实现
4
+ 关联需求:docs/requirements/PRD-ner-api.md
5
+ 创建日期:2026-04-28
6
+
7
+ ## 1. 需求摘要
8
+
9
+ 用 FastAPI 包装 GLiNER 模型,提供 POST /extract 接口,接收文本与实体类型列表,返回抽取结果。
10
+
11
+ ## 2. 方案设计
12
+
13
+ ### 方案选型
14
+
15
+ | 方案 | 优点 | 缺点 | 结论 |
16
+ |------|------|------|------|
17
+ | FastAPI + GLiNER | 轻量、async 支持好、自动生成文档 | — | ✅ 采用 |
18
+ | Flask + GLiNER | 更简单 | 无 async,性能差 | ❌ |
19
+
20
+ ### 目录结构
21
+
22
+ ```
23
+ ner-server/
24
+ ├── app/
25
+ │ ├── main.py # FastAPI app 入口,lifespan 加载模型
26
+ │ ├── config.py # 环境变量配置
27
+ │ ├── models.py # Pydantic 请求/响应模型
28
+ │ └── ner.py # GLiNER 封装(NERService)
29
+ ├── tests/
30
+ │ └── test_extract.py
31
+ ├── requirements.txt
32
+ └── .env.example
33
+ ```
34
+
35
+ ### 核心接口
36
+
37
+ ```
38
+ POST /extract
39
+ Request: { "text": str, "labels": list[str], "threshold": float = 0.5 }
40
+ Response: { "entities": [{ "text": str, "label": str, "score": float, "start": int, "end": int }] }
41
+
42
+ GET /health
43
+ Response: { "status": "ok" }
44
+ ```
45
+
46
+ ### 配置项(环境变量)
47
+
48
+ | 变量 | 默认值 | 说明 |
49
+ |------|--------|------|
50
+ | MODEL_NAME | urchade/gliner_medium-v2.1 | GLiNER 模型名称 |
51
+ | PORT | 8000 | 服务端口 |
52
+ | HOST | 0.0.0.0 | 监听地址 |
53
+
54
+ ## 3. 测试策略
55
+
56
+ - 正常路径:传入文本和标签,返回实体列表
57
+ - 空文本:返回空实体列表
58
+ - 空标签列表:返回空实体列表
59
+ - threshold 过滤:高阈值时过滤低置信度实体
60
+
61
+ ---
62
+ 确认记录:2026-04-28 用户确认(口头需求)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.29.0
3
+ gliner>=0.2.0
run.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import uvicorn
2
+ from app.config import HOST, PORT
3
+
4
+ if __name__ == "__main__":
5
+ uvicorn.run("app.main:app", host=HOST, port=PORT, reload=False)
start.bat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ @echo off
2
+ call D:\ProgramData\anaconda3\Scripts\activate.bat D:\ProgramData\coda_envs\ai
3
+ python run.py
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mock gliner and torch before any app module is imported.
3
+ This prevents torch's BLAS FPE check from crashing on Windows during tests.
4
+ """
5
+ import sys
6
+ from unittest.mock import MagicMock
7
+
8
+ # Stub out gliner and its torch dependency so the app can be imported safely
9
+ for mod in ("torch", "gliner", "gliner.model"):
10
+ sys.modules.setdefault(mod, MagicMock())
11
+
12
+ _gliner_stub = sys.modules["gliner"]
13
+ _gliner_stub.GLiNER = MagicMock()
tests/test_api_integration.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests — require the server to be running.
3
+
4
+ python run.py # in another terminal
5
+ pytest tests/test_api_integration.py -v
6
+ """
7
+ import requests
8
+ import pytest
9
+
10
+ BASE_URL = "http://localhost:4000"
11
+
12
+
13
+ def test_health():
14
+ resp = requests.get(f"{BASE_URL}/health")
15
+ assert resp.status_code == 200
16
+ assert resp.json()["status"] == "ok"
17
+
18
+
19
+ def test_extract_person_and_org():
20
+ resp = requests.post(
21
+ f"{BASE_URL}/extract",
22
+ json={
23
+ "text": "Elon Musk founded SpaceX in 2002.",
24
+ "labels": ["person", "organization"],
25
+ },
26
+ )
27
+ assert resp.status_code == 200
28
+ entities = resp.json()["entities"]
29
+ labels = {e["label"] for e in entities}
30
+ texts = {e["text"] for e in entities}
31
+ assert "person" in labels
32
+ assert "organization" in labels
33
+ assert "Elon Musk" in texts
34
+ assert "SpaceX" in texts
35
+
36
+
37
+ def test_extract_with_high_threshold():
38
+ resp = requests.post(
39
+ f"{BASE_URL}/extract",
40
+ json={
41
+ "text": "Barack Obama visited Paris.",
42
+ "labels": ["person", "location"],
43
+ "threshold": 0.9,
44
+ },
45
+ )
46
+ assert resp.status_code == 200
47
+ for e in resp.json()["entities"]:
48
+ assert e["score"] >= 0.9
49
+
50
+
51
+ def test_extract_empty_text_returns_empty():
52
+ resp = requests.post(
53
+ f"{BASE_URL}/extract",
54
+ json={"text": "", "labels": ["person"]},
55
+ )
56
+ assert resp.status_code == 200
57
+ assert resp.json()["entities"] == []
58
+
59
+
60
+ def test_extract_empty_labels_returns_empty():
61
+ resp = requests.post(
62
+ f"{BASE_URL}/extract",
63
+ json={"text": "Apple is great.", "labels": []},
64
+ )
65
+ assert resp.status_code == 200
66
+ assert resp.json()["entities"] == []
67
+
68
+
69
+ def test_extract_invalid_threshold_rejected():
70
+ resp = requests.post(
71
+ f"{BASE_URL}/extract",
72
+ json={"text": "Hello", "labels": ["person"], "threshold": 2.0},
73
+ )
74
+ assert resp.status_code == 422
75
+
76
+
77
+ def test_entity_fields_present():
78
+ resp = requests.post(
79
+ f"{BASE_URL}/extract",
80
+ json={
81
+ "text": "Tim Cook leads Apple.",
82
+ "labels": ["person", "organization"],
83
+ },
84
+ )
85
+ assert resp.status_code == 200
86
+ for e in resp.json()["entities"]:
87
+ assert {"text", "label", "score", "start", "end"} <= e.keys()
88
+ assert 0.0 <= e["score"] <= 1.0
89
+ assert e["start"] < e["end"]
tests/test_extract.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ conftest.py stubs gliner/torch before these imports run,
3
+ so no real model is loaded during tests.
4
+ """
5
+ from unittest.mock import MagicMock
6
+
7
+ import pytest
8
+ from fastapi.testclient import TestClient
9
+
10
+ import app.main as main_module
11
+ from app.main import app
12
+ from app.models import Entity
13
+
14
+
15
+ @pytest.fixture()
16
+ def client():
17
+ mock_ner = MagicMock()
18
+ # Patch NERService so lifespan assigns our mock instead of a real model
19
+ with pytest.MonkeyPatch().context() as mp:
20
+ mp.setattr("app.main.NERService", lambda *_: mock_ner)
21
+ with TestClient(app) as c:
22
+ yield c, mock_ner
23
+
24
+
25
+ def test_health(client):
26
+ c, _ = client
27
+ resp = c.get("/health")
28
+ assert resp.status_code == 200
29
+ assert resp.json() == {"status": "ok"}
30
+
31
+
32
+ def test_extract_returns_entities(client):
33
+ c, mock_ner = client
34
+ mock_ner.extract.return_value = [
35
+ Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
36
+ ]
37
+
38
+ resp = c.post(
39
+ "/extract",
40
+ json={"text": "Apple is a tech company.", "labels": ["organization", "person"]},
41
+ )
42
+
43
+ assert resp.status_code == 200
44
+ data = resp.json()
45
+ assert len(data["entities"]) == 1
46
+ assert data["entities"][0]["text"] == "Apple"
47
+ assert data["entities"][0]["label"] == "organization"
48
+
49
+
50
+ def test_extract_empty_text(client):
51
+ c, mock_ner = client
52
+ mock_ner.extract.return_value = []
53
+
54
+ resp = c.post("/extract", json={"text": "", "labels": ["person"]})
55
+
56
+ assert resp.status_code == 200
57
+ assert resp.json()["entities"] == []
58
+
59
+
60
+ def test_extract_empty_labels(client):
61
+ c, mock_ner = client
62
+ mock_ner.extract.return_value = []
63
+
64
+ resp = c.post("/extract", json={"text": "Some text.", "labels": []})
65
+
66
+ assert resp.status_code == 200
67
+ assert resp.json()["entities"] == []
68
+
69
+
70
+ def test_extract_threshold_forwarded(client):
71
+ c, mock_ner = client
72
+ mock_ner.extract.return_value = []
73
+
74
+ c.post(
75
+ "/extract",
76
+ json={"text": "Hello world", "labels": ["person"], "threshold": 0.8},
77
+ )
78
+
79
+ mock_ner.extract.assert_called_once_with("Hello world", ["person"], 0.8)
80
+
81
+
82
+ def test_extract_invalid_threshold(client):
83
+ c, _ = client
84
+ resp = c.post(
85
+ "/extract",
86
+ json={"text": "Hello", "labels": ["person"], "threshold": 1.5},
87
+ )
88
+ assert resp.status_code == 422