File size: 7,132 Bytes
d470d45
 
 
 
b10a07b
d470d45
b10a07b
 
6a5726c
b10a07b
d470d45
b10a07b
d470d45
8b89393
 
 
d470d45
 
b10a07b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d470d45
b10a07b
 
d470d45
 
 
 
 
b10a07b
f988bf5
 
 
 
 
 
b10a07b
 
 
 
d470d45
 
 
 
 
 
 
 
 
 
 
b10a07b
f988bf5
 
 
 
 
 
 
b10a07b
 
 
 
 
d470d45
 
 
 
 
 
f988bf5
b10a07b
 
 
d470d45
 
 
 
 
f988bf5
 
 
 
 
 
 
 
b10a07b
 
 
d470d45
 
 
 
 
f988bf5
b10a07b
 
 
d470d45
 
 
 
b10a07b
f988bf5
 
 
 
 
 
 
b10a07b
 
 
 
d470d45
 
 
 
 
6a5726c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a7f7d0
6a5726c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Integration tests — require the server to be running.

    python run.py   # in another terminal
    pytest tests/test_api_integration.py -v -s
"""
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import pytest
import requests

import os
# Base URL of the API under test. Defaults to a local server and can be
# redirected with the NER_BASE_URL environment variable. Trailing slashes are
# stripped because the helpers append paths that start with "/" (for example
# "/health"); without this an override ending in "/" would yield "//health".
BASE_URL = os.getenv("NER_BASE_URL", "http://localhost:4000/api/v1").rstrip("/")
# URL of a hosted deployment. Not referenced by the tests visible in this file.
# NOTE(review): presumably meant as a ready-made value for NER_BASE_URL — confirm.
HF_BASE_URL = "https://robinwu-nerserver.hf.space/api/v1"


def _call(
    method: str,
    path: str,
    payload: dict | None = None,
    timeout: float = 60.0,
) -> tuple[requests.Response, float]:
    """Send one request to the API under test and measure how long it took.

    Args:
        method: ``"GET"`` issues a GET request; any other value issues a POST.
        path: Path appended verbatim to ``BASE_URL`` (for example ``"/health"``).
        payload: JSON body sent with POST requests; ignored for GET.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout an unresponsive server would hang the test run
            instead of failing it.

    Returns:
        A ``(response, elapsed)`` tuple where ``elapsed`` is the wall-clock
        duration of the request in seconds, taken with ``time.perf_counter``.

    Raises:
        requests.exceptions.Timeout: If the server does not respond in time.
    """
    url = f"{BASE_URL}{path}"
    t0 = time.perf_counter()
    if method == "GET":
        resp = requests.get(url, timeout=timeout)
    else:
        resp = requests.post(url, json=payload, timeout=timeout)
    elapsed = time.perf_counter() - t0
    return resp, elapsed


def _print(label: str, payload: dict | None, resp: requests.Response, elapsed: float):
    print(f"\n{'─' * 60}")
    print(f"[{label}]")
    if payload is not None:
        print(f"  input : {json.dumps(payload, ensure_ascii=False)}")
    print(f"  output: {json.dumps(resp.json(), ensure_ascii=False)}")
    print(f"  status: {resp.status_code}  time: {elapsed * 1000:.1f} ms")


def test_health():
    """GET /health must answer 200 with a JSON body whose "status" is "ok"."""
    response, seconds = _call("GET", "/health")
    _print("health", None, response, seconds)

    # Check the HTTP status first, then the reported service status.
    assert response.status_code == 200
    assert response.json()["status"] == "ok"


def test_extract_person_and_org():
    """POST /extract must find person and organization entities in a news blurb.

    The response has to contain at least one "person" and one "organization"
    label, and the entity texts "Elon Musk" and "SpaceX".
    """
    sentence = (
        "Elon Musk, the CEO of Tesla and founder of SpaceX, announced a new partnership "
        "with NASA last Tuesday. The deal, signed at Kennedy Space Center in Florida, "
        "will see SpaceX supply rockets for upcoming lunar missions planned by NASA over the next decade."
    )
    request_body = {"text": sentence, "labels": ["person", "organization", "location"]}

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract person & org", request_body, response, seconds)

    assert response.status_code == 200

    found = response.json()["entities"]
    found_labels = {entity["label"] for entity in found}
    found_texts = {entity["text"] for entity in found}

    for expected_label in ("person", "organization"):
        assert expected_label in found_labels
    for expected_text in ("Elon Musk", "SpaceX"):
        assert expected_text in found_texts


def test_extract_with_high_threshold():
    """With "threshold" set to 0.9, every returned entity must score at least 0.9."""
    min_score = 0.9
    sentence = (
        "Former US President Barack Obama delivered a keynote speech at the United Nations "
        "headquarters in New York City on Monday, addressing climate change alongside French "
        "President Emmanuel Macron and German Chancellor Olaf Scholz. The event drew leaders "
        "from over fifty countries including Japan, Brazil, and South Africa."
    )
    request_body = {
        "text": sentence,
        "labels": ["person", "location", "organization"],
        "threshold": min_score,
    }

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract high threshold", request_body, response, seconds)

    assert response.status_code == 200
    # An empty entity list passes: the check is only about what was returned.
    assert all(entity["score"] >= min_score for entity in response.json()["entities"])


def test_extract_empty_text_returns_empty():
    """An empty "text" must yield HTTP 200 and an empty entity list."""
    request_body = {
        "text": "",
        "labels": ["person", "organization", "location"],
    }

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract empty text", request_body, response, seconds)

    assert response.status_code == 200
    extracted = response.json()["entities"]
    assert extracted == []


def test_extract_empty_labels_returns_empty():
    """An empty "labels" list must yield HTTP 200 and an empty entity list."""
    sentence = (
        "Apple Inc. reported record quarterly earnings on Thursday, with CEO Tim Cook "
        "crediting strong iPhone sales in markets across Europe and Southeast Asia. "
        "The company also announced plans to expand its research center in Austin, Texas."
    )
    request_body = {"text": sentence, "labels": []}

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract empty labels", request_body, response, seconds)

    assert response.status_code == 200
    extracted = response.json()["entities"]
    assert extracted == []


def test_extract_invalid_threshold_rejected():
    """A "threshold" of 2.0 must be rejected with HTTP 422."""
    request_body = {
        "text": "Hello world, this is a simple test sentence.",
        "labels": ["person"],
        "threshold": 2.0,
    }

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract invalid threshold", request_body, response, seconds)

    assert response.status_code == 422


def test_entity_fields_present():
    """Each returned entity must carry the expected fields with sane values.

    Required keys are text, label, score, start and end; the score has to lie
    in [0.0, 1.0] and start has to be strictly less than end.
    """
    required_fields = {"text", "label", "score", "start", "end"}
    sentence = (
        "Tim Cook, CEO of Apple, met with Sundar Pichai from Google and Satya Nadella "
        "from Microsoft at a technology summit held in San Francisco last week. The three "
        "executives discussed artificial intelligence regulation and data privacy policies "
        "being proposed by the European Union and the US Congress."
    )
    request_body = {"text": sentence, "labels": ["person", "organization", "location"]}

    response, seconds = _call("POST", "/extract", request_body)
    _print("extract field check", request_body, response, seconds)

    assert response.status_code == 200
    for entity in response.json()["entities"]:
        assert required_fields.issubset(entity.keys())
        assert 0.0 <= entity["score"] <= 1.0
        assert entity["start"] < entity["end"]


def test_concurrent_two_requests():
    """Two simultaneous POST /extract calls must both succeed and return entities.

    Both payloads are submitted to a thread pool at once. Per-request timings
    and the total wall time are printed, then each response is checked for
    HTTP 200 and a non-empty entity list.
    """
    payloads = [
        {
            "text": (
                "Jeff Bezos founded Amazon in a garage in Bellevue, Washington in 1994. "
                "The company started as an online bookstore before expanding into cloud computing "
                "through AWS, making Amazon one of the most valuable companies in the world."
            ),
            "labels": ["person", "organization", "location"],
        },
        {
            "text": (
                "The World Health Organization declared a new health advisory after researchers "
                "at Johns Hopkins University and the University of Oxford published findings on "
                "antibiotic resistance. Dr. Maria Chen led the study, which was funded by the "
                "Bill & Melinda Gates Foundation."
            ),
            "labels": ["person", "organization", "location"],
        },
    ]

    # Maps payload index -> (response, elapsed seconds).
    results = {}
    t0 = time.perf_counter()

    with ThreadPoolExecutor(max_workers=len(payloads)) as pool:
        # Go through the shared _call helper so URL construction and timing
        # are identical to the single-request tests.
        pending = {
            pool.submit(_call, "POST", "/extract", payload): idx
            for idx, payload in enumerate(payloads)
        }
        # Futures finish in arbitrary order; the index restores request order.
        for done in as_completed(pending):
            results[pending[done]] = done.result()

    total = time.perf_counter() - t0

    print(f"\n{'─' * 60}")
    print("[concurrent 2 requests]")
    for idx, payload in enumerate(payloads):
        resp, elapsed = results[idx]
        print(f"  --- request {idx + 1} ---")
        print(f"  input : {json.dumps(payload, ensure_ascii=False)}")
        print(f"  output: {json.dumps(resp.json(), ensure_ascii=False)}")
        print(f"  time  : {elapsed * 1000:.1f} ms")
    print(f"  total wall time: {total * 1000:.1f} ms")

    # Assert only after everything was printed so a failure still shows both exchanges.
    for idx in range(len(payloads)):
        resp, _ = results[idx]
        assert resp.status_code == 200
        assert len(resp.json()["entities"]) > 0