File size: 2,845 Bytes
be29f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c956255
 
 
be29f4f
 
c956255
be29f4f
 
 
 
 
c956255
be29f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c956255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import hmac
import logging
import os

from fastapi import FastAPI, HTTPException, Header, Depends, Request
from pydantic import BaseModel
from gliner import GLiNER
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Root logging at INFO so the model-loading messages below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate-limit buckets are keyed by the client's remote address.
# NOTE(review): every @limiter.limit(...) decorator on the routes below is
# commented out, so no rate limits are currently enforced even though the
# limiter and its 429 handler are registered.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI(title="Panoptifi NER API")
# slowapi looks the limiter up on app.state.limiter.
app.state.limiter = limiter
# Converts RateLimitExceeded into an HTTP error response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret expected in the X-API-Key header. Defaults to "" when the
# environment variable is unset; an empty value disables the key check in
# verify_api_key (every request is accepted).
API_KEY = os.environ.get("API_KEY", "")


def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
  """FastAPI dependency that rejects requests lacking the configured API key.

  Authentication is disabled when the module-level ``API_KEY`` is empty
  (the ``API_KEY`` environment variable is unset or blank); in that case
  every request is accepted.

  Args:
    x_api_key: Value of the ``X-API-Key`` request header, or ``None`` when
      the header is absent.

  Returns:
    True when the request is authorised.

  Raises:
    HTTPException: 401 when a key is configured and the header is missing
      or does not match.
  """
  if not API_KEY:
    return True
  # A missing header can never match a configured key.
  if x_api_key is None:
    raise HTTPException(status_code=401, detail="Invalid API key")
  # Constant-time comparison: a plain ``!=`` can short-circuit on the first
  # differing character, letting response timing leak the key prefix.
  # Compare bytes, because compare_digest rejects non-ASCII ``str`` inputs
  # with TypeError.
  if not hmac.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8")):
    raise HTTPException(status_code=401, detail="Invalid API key")
  return True


# Load the GLiNER model once at import time; the single `model` instance is
# shared by every request handler in this module.
# NOTE(review): from_pretrained presumably fetches weights from the Hugging
# Face Hub on first use (network + disk cache) — confirm for deployments.
logger.info("Loading GLiNER model...")
model = GLiNER.from_pretrained("urchade/gliner_small")
logger.info("Model loaded")

# Entity labels used when a request omits `labels` or sends an empty list
# (handlers use `input.labels or DEFAULT_LABELS`).
DEFAULT_LABELS = ["company", "stock_ticker", "executive", "regulator", "product", "location"]


# Request body for POST /extract.
class NERInput(BaseModel):
  # Text to analyse; /extract rejects blank text with 400 and only passes
  # the first 2000 characters to the model.
  text: str
  # Entity labels to look for; None or [] falls back to DEFAULT_LABELS.
  labels: list[str] | None = None


# One predicted entity, copied field-by-field from a GLiNER prediction dict.
class Entity(BaseModel):
  # Surface text of the matched span.
  text: str
  # Which of the requested labels was assigned.
  label: str
  # Model confidence; the handlers request threshold=0.5.
  score: float
  # NOTE(review): presumably character offsets into the (possibly
  # truncated) input text, end-exclusive — confirm against GLiNER.
  start: int
  end: int


# Response for a single text: the list of entities found in it.
class NERResult(BaseModel):
  entities: list[Entity]

# Request body for POST /extract/batch.
class BatchNERInput(BaseModel):
  # Texts to analyse; the handler rejects more than 50 with HTTP 400.
  texts: list[str]
  # Labels applied to every text; None or [] falls back to DEFAULT_LABELS.
  labels: list[str] | None = None

@app.get("/health")
# Rate limiting for this route is currently disabled; uncomment to enable.
# @limiter.limit("60/minute")
def health(request: Request):
  """Liveness probe reporting service status and the served model id.

  ``request`` is not read by the body; it is kept so the commented-out
  slowapi decorator (which expects a ``request`` argument) can be
  re-enabled without changing the signature.
  """
  payload = dict(status="healthy", model="urchade/gliner_small")
  return payload


@app.post("/extract", response_model=NERResult)
# Rate limiting for this route is currently disabled; uncomment to enable.
# @limiter.limit("30/minute")
def extract_entities(request: Request, input: NERInput, _: bool = Depends(verify_api_key)):
  """Extract named entities from a single text.

  Only the first 2000 characters of ``input.text`` are sent to the model,
  using a 0.5 confidence threshold. When ``input.labels`` is None or empty,
  ``DEFAULT_LABELS`` is used.

  Raises:
    HTTPException: 400 if the text is empty or whitespace-only.
  """
  # Blank input is treated as a client error rather than an empty result.
  is_blank = input.text.strip() == ""
  if is_blank:
    raise HTTPException(400, "Text cannot be empty")

  chosen_labels = input.labels or DEFAULT_LABELS
  snippet = input.text[:2000]
  predictions = model.predict_entities(snippet, chosen_labels, threshold=0.5)

  found = []
  for pred in predictions:
    found.append(
      Entity(
        text=pred["text"],
        label=pred["label"],
        score=pred["score"],
        start=pred["start"],
        end=pred["end"],
      )
    )
  return NERResult(entities=found)

@app.post("/extract/batch", response_model=list[NERResult])
# Rate limiting for this route is currently disabled; uncomment to enable.
# @limiter.limit("10/minute")
def extract_batch(request: Request, input: BatchNERInput, _: bool = Depends(verify_api_key)):
  """Extract named entities from up to 50 texts in one request.

  Each text is truncated to its first 2000 characters and run through the
  model with a 0.5 confidence threshold, using ``input.labels`` (or
  ``DEFAULT_LABELS`` when None or empty) for every text.

  The response is positionally aligned with ``input.texts``: a blank
  (empty or whitespace-only) text yields ``NERResult(entities=[])`` instead
  of being dropped, so ``results[i]`` always describes ``input.texts[i]``.

  Raises:
    HTTPException: 400 if more than 50 texts are submitted.
  """
  if len(input.texts) > 50:
    raise HTTPException(400, "Max 50 texts per batch")

  labels = input.labels or DEFAULT_LABELS
  results = []
  for text in input.texts:
    if not text.strip():
      # Keep a placeholder so result indices stay aligned with the inputs;
      # skipping the entry would shift every later result by one.
      results.append(NERResult(entities=[]))
      continue
    entities = model.predict_entities(text[:2000], labels, threshold=0.5)
    results.append(NERResult(entities=[
      Entity(
        text=e["text"],
        label=e["label"],
        score=e["score"],
        start=e["start"],
        end=e["end"],
      )
      for e in entities
    ]))
  return results