Robin Claude Sonnet 4.6 commited on
Commit ·
d470d45
0
Parent(s):
feat: GLiNER NER HTTP API
Browse files- POST /extract 接收文本和实体类型,返回抽取结果
- 模型启动时加载一次,多次请求复用
- 通过 HF_ENDPOINT 使用国内镜像,MODEL_CACHE_DIR 本地缓存
- 默认端口 4000,支持环境变量配置
- 包含单元测试(mock)和集成测试(真实 API)
- start.bat 从 conda ai 环境启动
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- .env.example +5 -0
- .gitignore +8 -0
- app/__init__.py +0 -0
- app/config.py +10 -0
- app/main.py +31 -0
- app/models.py +19 -0
- app/ner.py +22 -0
- docs/requirements/PRD-ner-api.md +26 -0
- docs/technical/TDD-ner-api.md +62 -0
- requirements.txt +3 -0
- run.py +5 -0
- start.bat +3 -0
- tests/__init__.py +0 -0
- tests/conftest.py +13 -0
- tests/test_api_integration.py +89 -0
- tests/test_extract.py +88 -0
.env.example
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_NAME=urchade/gliner_medium-v2.1
|
| 2 |
+
HOST=0.0.0.0
|
| 3 |
+
PORT=4000
|
| 4 |
+
HF_ENDPOINT=https://hf-mirror.com
|
| 5 |
+
MODEL_CACHE_DIR=./model_cache
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.env
|
| 4 |
+
model_cache/
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
*.egg-info/
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
app/__init__.py
ADDED
|
File without changes
|
app/config.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
MODEL_NAME: str = os.getenv("MODEL_NAME", "urchade/gliner_medium-v2.1")
|
| 4 |
+
HOST: str = os.getenv("HOST", "0.0.0.0")
|
| 5 |
+
PORT: int = int(os.getenv("PORT", "4000"))
|
| 6 |
+
MODEL_CACHE_DIR: str = os.getenv("MODEL_CACHE_DIR", "./model_cache")
|
| 7 |
+
|
| 8 |
+
# Must be set before huggingface_hub / transformers are imported
|
| 9 |
+
_hf_endpoint = os.getenv("HF_ENDPOINT", "https://hf-mirror.com")
|
| 10 |
+
os.environ["HF_ENDPOINT"] = _hf_endpoint
|
app/main.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
+
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
|
| 5 |
+
from app.config import MODEL_CACHE_DIR, MODEL_NAME
|
| 6 |
+
from app.models import ExtractRequest, ExtractResponse
|
| 7 |
+
from app.ner import NERService
|
| 8 |
+
|
| 9 |
+
ner_service: NERService | None = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@asynccontextmanager
|
| 13 |
+
async def lifespan(app: FastAPI):
|
| 14 |
+
global ner_service
|
| 15 |
+
ner_service = NERService(MODEL_NAME, MODEL_CACHE_DIR)
|
| 16 |
+
yield
|
| 17 |
+
ner_service = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
app = FastAPI(title="NER API", lifespan=lifespan)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@app.get("/health")
|
| 24 |
+
def health():
|
| 25 |
+
return {"status": "ok"}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@app.post("/extract", response_model=ExtractResponse)
|
| 29 |
+
def extract(req: ExtractRequest):
|
| 30 |
+
entities = ner_service.extract(req.text, req.labels, req.threshold)
|
| 31 |
+
return ExtractResponse(entities=entities)
|
app/models.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExtractRequest(BaseModel):
|
| 5 |
+
text: str
|
| 6 |
+
labels: list[str]
|
| 7 |
+
threshold: float = Field(default=0.5, ge=0.0, le=1.0)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Entity(BaseModel):
|
| 11 |
+
text: str
|
| 12 |
+
label: str
|
| 13 |
+
score: float
|
| 14 |
+
start: int
|
| 15 |
+
end: int
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ExtractResponse(BaseModel):
|
| 19 |
+
entities: list[Entity]
|
app/ner.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gliner import GLiNER
|
| 2 |
+
from app.models import Entity
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class NERService:
|
| 6 |
+
def __init__(self, model_name: str, cache_dir: str) -> None:
|
| 7 |
+
self._model = GLiNER.from_pretrained(model_name, cache_dir=cache_dir)
|
| 8 |
+
|
| 9 |
+
def extract(self, text: str, labels: list[str], threshold: float) -> list[Entity]:
|
| 10 |
+
if not text or not labels:
|
| 11 |
+
return []
|
| 12 |
+
raw = self._model.predict_entities(text, labels, threshold=threshold)
|
| 13 |
+
return [
|
| 14 |
+
Entity(
|
| 15 |
+
text=e["text"],
|
| 16 |
+
label=e["label"],
|
| 17 |
+
score=round(e["score"], 4),
|
| 18 |
+
start=e["start"],
|
| 19 |
+
end=e["end"],
|
| 20 |
+
)
|
| 21 |
+
for e in raw
|
| 22 |
+
]
|
docs/requirements/PRD-ner-api.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PRD-ner-api
|
| 2 |
+
|
| 3 |
+
状态:已确认
|
| 4 |
+
创建日期:2026-04-28
|
| 5 |
+
|
| 6 |
+
## 1. 功能目标
|
| 7 |
+
|
| 8 |
+
提供一个 HTTP API 服务,接收文本和实体类型列表,返回从文本中抽取到的命名实体。
|
| 9 |
+
|
| 10 |
+
## 2. 用户故事
|
| 11 |
+
|
| 12 |
+
作为 API 调用方,我可以传入一段文本和期望识别的实体类型(如 "person"、"organization"、"location"),得到每个实体的文字、类型和在原文中的位置。
|
| 13 |
+
|
| 14 |
+
## 3. 验收标准
|
| 15 |
+
|
| 16 |
+
- [ ] POST /extract 接口可正常调用
|
| 17 |
+
- [ ] 支持传入任意实体类型列表(zero-shot)
|
| 18 |
+
- [ ] 返回实体文字、实体类型、置信度分数
|
| 19 |
+
- [ ] 模型加载一次,多次请求复用
|
| 20 |
+
- [ ] 支持通过环境变量配置模型名称和服务端口
|
| 21 |
+
|
| 22 |
+
## 4. 约束
|
| 23 |
+
|
| 24 |
+
- 基于 GLiNER 库实现(Python)
|
| 25 |
+
- 尽量简单,不引入数据库、认证等复杂机制
|
| 26 |
+
- 默认使用 `urchade/gliner_medium-v2.1` 模型
|
docs/technical/TDD-ner-api.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TDD-ner-api
|
| 2 |
+
|
| 3 |
+
状态:已实现
|
| 4 |
+
关联需求:docs/requirements/PRD-ner-api.md
|
| 5 |
+
创建日期:2026-04-28
|
| 6 |
+
|
| 7 |
+
## 1. 需求摘要
|
| 8 |
+
|
| 9 |
+
用 FastAPI 包装 GLiNER 模型,提供 POST /extract 接口,接收文本与实体类型列表,返回抽取结果。
|
| 10 |
+
|
| 11 |
+
## 2. 方案设计
|
| 12 |
+
|
| 13 |
+
### 方案选型
|
| 14 |
+
|
| 15 |
+
| 方案 | 优点 | 缺点 | 结论 |
|
| 16 |
+
|------|------|------|------|
|
| 17 |
+
| FastAPI + GLiNER | 轻量、async 支持好、自动生成文档 | — | ✅ 采用 |
|
| 18 |
+
| Flask + GLiNER | 更简单 | 无 async,性能差 | ❌ |
|
| 19 |
+
|
| 20 |
+
### 目录结构
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
ner-server/
|
| 24 |
+
├── app/
|
| 25 |
+
│ ├── main.py # FastAPI app 入口,lifespan 加载模型
|
| 26 |
+
│ ├── config.py # 环境变量配置
|
| 27 |
+
│ ├── models.py # Pydantic 请求/响应模型
|
| 28 |
+
│ └── ner.py # GLiNER 封装(NERService)
|
| 29 |
+
├── tests/
|
| 30 |
+
│ └── test_extract.py
|
| 31 |
+
├── requirements.txt
|
| 32 |
+
└── .env.example
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### 核心接口
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
POST /extract
|
| 39 |
+
Request: { "text": str, "labels": list[str], "threshold": float = 0.5 }
|
| 40 |
+
Response: { "entities": [{ "text": str, "label": str, "score": float, "start": int, "end": int }] }
|
| 41 |
+
|
| 42 |
+
GET /health
|
| 43 |
+
Response: { "status": "ok" }
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### 配置项(环境变量)
|
| 47 |
+
|
| 48 |
+
| 变量 | 默认值 | 说明 |
|
| 49 |
+
|------|--------|------|
|
| 50 |
+
| MODEL_NAME | urchade/gliner_medium-v2.1 | GLiNER 模型名称 |
|
| 51 |
+
| PORT | 8000 | 服务端口 |
|
| 52 |
+
| HOST | 0.0.0.0 | 监听地址 |
|
| 53 |
+
|
| 54 |
+
## 3. 测试策略
|
| 55 |
+
|
| 56 |
+
- 正常路径:传入文本和标签,返回实体列表
|
| 57 |
+
- 空文本:返回空实体列表
|
| 58 |
+
- 空标签列表:返回空实体列表
|
| 59 |
+
- threshold 过滤:高阈值时过滤低置信度实体
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
确认记录:2026-04-28 用户确认(口头需求)
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.111.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
gliner>=0.2.0
|
run.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uvicorn
|
| 2 |
+
from app.config import HOST, PORT
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
uvicorn.run("app.main:app", host=HOST, port=PORT, reload=False)
|
start.bat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
call D:\ProgramData\anaconda3\Scripts\activate.bat D:\ProgramData\coda_envs\ai
|
| 3 |
+
python run.py
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mock gliner and torch before any app module is imported.
|
| 3 |
+
This prevents torch's BLAS FPE check from crashing on Windows during tests.
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
from unittest.mock import MagicMock
|
| 7 |
+
|
| 8 |
+
# Stub out gliner and its torch dependency so the app can be imported safely
|
| 9 |
+
for mod in ("torch", "gliner", "gliner.model"):
|
| 10 |
+
sys.modules.setdefault(mod, MagicMock())
|
| 11 |
+
|
| 12 |
+
_gliner_stub = sys.modules["gliner"]
|
| 13 |
+
_gliner_stub.GLiNER = MagicMock()
|
tests/test_api_integration.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests — require the server to be running.
|
| 3 |
+
|
| 4 |
+
python run.py # in another terminal
|
| 5 |
+
pytest tests/test_api_integration.py -v
|
| 6 |
+
"""
|
| 7 |
+
import requests
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
BASE_URL = "http://localhost:4000"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_health():
|
| 14 |
+
resp = requests.get(f"{BASE_URL}/health")
|
| 15 |
+
assert resp.status_code == 200
|
| 16 |
+
assert resp.json()["status"] == "ok"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_extract_person_and_org():
|
| 20 |
+
resp = requests.post(
|
| 21 |
+
f"{BASE_URL}/extract",
|
| 22 |
+
json={
|
| 23 |
+
"text": "Elon Musk founded SpaceX in 2002.",
|
| 24 |
+
"labels": ["person", "organization"],
|
| 25 |
+
},
|
| 26 |
+
)
|
| 27 |
+
assert resp.status_code == 200
|
| 28 |
+
entities = resp.json()["entities"]
|
| 29 |
+
labels = {e["label"] for e in entities}
|
| 30 |
+
texts = {e["text"] for e in entities}
|
| 31 |
+
assert "person" in labels
|
| 32 |
+
assert "organization" in labels
|
| 33 |
+
assert "Elon Musk" in texts
|
| 34 |
+
assert "SpaceX" in texts
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_extract_with_high_threshold():
|
| 38 |
+
resp = requests.post(
|
| 39 |
+
f"{BASE_URL}/extract",
|
| 40 |
+
json={
|
| 41 |
+
"text": "Barack Obama visited Paris.",
|
| 42 |
+
"labels": ["person", "location"],
|
| 43 |
+
"threshold": 0.9,
|
| 44 |
+
},
|
| 45 |
+
)
|
| 46 |
+
assert resp.status_code == 200
|
| 47 |
+
for e in resp.json()["entities"]:
|
| 48 |
+
assert e["score"] >= 0.9
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_extract_empty_text_returns_empty():
|
| 52 |
+
resp = requests.post(
|
| 53 |
+
f"{BASE_URL}/extract",
|
| 54 |
+
json={"text": "", "labels": ["person"]},
|
| 55 |
+
)
|
| 56 |
+
assert resp.status_code == 200
|
| 57 |
+
assert resp.json()["entities"] == []
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_extract_empty_labels_returns_empty():
|
| 61 |
+
resp = requests.post(
|
| 62 |
+
f"{BASE_URL}/extract",
|
| 63 |
+
json={"text": "Apple is great.", "labels": []},
|
| 64 |
+
)
|
| 65 |
+
assert resp.status_code == 200
|
| 66 |
+
assert resp.json()["entities"] == []
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_extract_invalid_threshold_rejected():
|
| 70 |
+
resp = requests.post(
|
| 71 |
+
f"{BASE_URL}/extract",
|
| 72 |
+
json={"text": "Hello", "labels": ["person"], "threshold": 2.0},
|
| 73 |
+
)
|
| 74 |
+
assert resp.status_code == 422
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_entity_fields_present():
|
| 78 |
+
resp = requests.post(
|
| 79 |
+
f"{BASE_URL}/extract",
|
| 80 |
+
json={
|
| 81 |
+
"text": "Tim Cook leads Apple.",
|
| 82 |
+
"labels": ["person", "organization"],
|
| 83 |
+
},
|
| 84 |
+
)
|
| 85 |
+
assert resp.status_code == 200
|
| 86 |
+
for e in resp.json()["entities"]:
|
| 87 |
+
assert {"text", "label", "score", "start", "end"} <= e.keys()
|
| 88 |
+
assert 0.0 <= e["score"] <= 1.0
|
| 89 |
+
assert e["start"] < e["end"]
|
tests/test_extract.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
conftest.py stubs gliner/torch before these imports run,
|
| 3 |
+
so no real model is loaded during tests.
|
| 4 |
+
"""
|
| 5 |
+
from unittest.mock import MagicMock
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from fastapi.testclient import TestClient
|
| 9 |
+
|
| 10 |
+
import app.main as main_module
|
| 11 |
+
from app.main import app
|
| 12 |
+
from app.models import Entity
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@pytest.fixture()
|
| 16 |
+
def client():
|
| 17 |
+
mock_ner = MagicMock()
|
| 18 |
+
# Patch NERService so lifespan assigns our mock instead of a real model
|
| 19 |
+
with pytest.MonkeyPatch().context() as mp:
|
| 20 |
+
mp.setattr("app.main.NERService", lambda *_: mock_ner)
|
| 21 |
+
with TestClient(app) as c:
|
| 22 |
+
yield c, mock_ner
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_health(client):
|
| 26 |
+
c, _ = client
|
| 27 |
+
resp = c.get("/health")
|
| 28 |
+
assert resp.status_code == 200
|
| 29 |
+
assert resp.json() == {"status": "ok"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_extract_returns_entities(client):
|
| 33 |
+
c, mock_ner = client
|
| 34 |
+
mock_ner.extract.return_value = [
|
| 35 |
+
Entity(text="Apple", label="organization", score=0.95, start=0, end=5)
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
resp = c.post(
|
| 39 |
+
"/extract",
|
| 40 |
+
json={"text": "Apple is a tech company.", "labels": ["organization", "person"]},
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
assert resp.status_code == 200
|
| 44 |
+
data = resp.json()
|
| 45 |
+
assert len(data["entities"]) == 1
|
| 46 |
+
assert data["entities"][0]["text"] == "Apple"
|
| 47 |
+
assert data["entities"][0]["label"] == "organization"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_extract_empty_text(client):
|
| 51 |
+
c, mock_ner = client
|
| 52 |
+
mock_ner.extract.return_value = []
|
| 53 |
+
|
| 54 |
+
resp = c.post("/extract", json={"text": "", "labels": ["person"]})
|
| 55 |
+
|
| 56 |
+
assert resp.status_code == 200
|
| 57 |
+
assert resp.json()["entities"] == []
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_extract_empty_labels(client):
|
| 61 |
+
c, mock_ner = client
|
| 62 |
+
mock_ner.extract.return_value = []
|
| 63 |
+
|
| 64 |
+
resp = c.post("/extract", json={"text": "Some text.", "labels": []})
|
| 65 |
+
|
| 66 |
+
assert resp.status_code == 200
|
| 67 |
+
assert resp.json()["entities"] == []
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_extract_threshold_forwarded(client):
|
| 71 |
+
c, mock_ner = client
|
| 72 |
+
mock_ner.extract.return_value = []
|
| 73 |
+
|
| 74 |
+
c.post(
|
| 75 |
+
"/extract",
|
| 76 |
+
json={"text": "Hello world", "labels": ["person"], "threshold": 0.8},
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
mock_ner.extract.assert_called_once_with("Hello world", ["person"], 0.8)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_extract_invalid_threshold(client):
|
| 83 |
+
c, _ = client
|
| 84 |
+
resp = c.post(
|
| 85 |
+
"/extract",
|
| 86 |
+
json={"text": "Hello", "labels": ["person"], "threshold": 1.5},
|
| 87 |
+
)
|
| 88 |
+
assert resp.status_code == 422
|