# panoptifi-ner / app.py
# Last change: commit c956255 "Update app.py" by vichter (verified)
import logging
import os
import secrets

from fastapi import Depends, FastAPI, Header, HTTPException, Request
from gliner import GLiNER
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# slowapi limiter keyed on the connecting client's address. Only endpoints
# decorated with ``@limiter.limit(...)`` are actually throttled.
# NOTE(review): behind a reverse proxy the connecting address is the proxy's,
# so every client would share one bucket -- confirm the deployment topology.
limiter = Limiter(key_func=get_remote_address)
app = FastAPI(title="Panoptifi NER API")
app.state.limiter = limiter
# Translate RateLimitExceeded into slowapi's standard 429 response.
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret that clients must send in the X-API-Key header. An unset or
# empty value disables authentication entirely, so announce that at startup
# instead of failing open silently.
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    logger.warning("API_KEY is not set; API key authentication is disabled")
def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
    """FastAPI dependency that rejects requests lacking the configured API key.

    Authentication is enforced only when the module-level ``API_KEY`` is
    non-empty; otherwise every request is accepted.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header, or ``None`` when
            the header is absent.

    Returns:
        ``True`` when the request is allowed through.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match.
    """
    if not API_KEY:
        return True
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Comparing UTF-8 bytes avoids compare_digest's
    # TypeError on non-ASCII str input; a missing header compares as b"".
    supplied = (x_api_key or "").encode("utf-8")
    if not secrets.compare_digest(supplied, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
# Load the GLiNER checkpoint once, at import time, so all requests share the
# same in-memory model rather than reloading weights per call.
logger.info("Loading GLiNER model...")
model = GLiNER.from_pretrained("urchade/gliner_small")
logger.info("Model loaded")

# Entity types requested whenever a client omits ``labels`` or sends an empty list.
DEFAULT_LABELS = [
    "company",
    "stock_ticker",
    "executive",
    "regulator",
    "product",
    "location",
]
class NERInput(BaseModel):
    """Request body for ``/extract``: one text plus optional entity labels."""

    text: str  # text to tag; the endpoint sends only its first 2000 chars to the model
    labels: list[str] | None = None  # entity types to extract; null or [] falls back to DEFAULT_LABELS
class Entity(BaseModel):
    """One entity span as returned by ``model.predict_entities``."""

    text: str  # surface text of the matched span
    label: str  # which requested label the span was assigned
    score: float  # model confidence; the endpoints query the model with threshold=0.5
    start: int  # presumably a character offset into the (truncated) input text -- confirm against GLiNER docs
    end: int  # presumably the matching end offset (exclusive?) -- confirm against GLiNER docs
class NERResult(BaseModel):
    """Entities found in a single text."""

    entities: list[Entity]  # empty when the model returns no spans
class BatchNERInput(BaseModel):
    """Request body for ``/extract/batch``: several texts sharing one label set."""

    texts: list[str]  # the 50-item cap is enforced by the endpoint, not by this model
    labels: list[str] | None = None  # applied to every text; null or [] falls back to DEFAULT_LABELS
@app.get("/health")
# @limiter.limit("60/minute")  # NOTE: rate limiting is disabled in this revision
def health(request: Request):
    """Liveness probe reporting service status and the checkpoint name.

    The model name is a literal mirroring the ``from_pretrained`` call. It is
    not read from the loaded model. ``request`` is unused in the body.
    slowapi's ``@limiter.limit`` decorator requires it if the limit is
    re-enabled.
    """
    payload = dict(status="healthy", model="urchade/gliner_small")
    return payload
@app.post("/extract", response_model=NERResult)
# @limiter.limit("30/minute")  # NOTE: rate limiting is disabled in this revision
def extract_entities(
    request: Request,
    input: NERInput,
    _: bool = Depends(verify_api_key),
):
    """Extract entities from a single text with the shared GLiNER model.

    Only the first 2000 characters are passed to the model, so any offsets it
    reports refer to that prefix. When ``input.labels`` is missing or empty,
    ``DEFAULT_LABELS`` is used instead. Spans are requested with
    ``threshold=0.5``.

    ``request`` is unused in the body. It is kept because slowapi's
    ``@limiter.limit`` requires it if rate limiting is re-enabled.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    text = input.text
    if not text.strip():
        raise HTTPException(400, "Text cannot be empty")

    chosen_labels = input.labels or DEFAULT_LABELS
    predictions = model.predict_entities(text[:2000], chosen_labels, threshold=0.5)

    # Copy exactly these fields from each prediction dict into the response model.
    keys = ("text", "label", "score", "start", "end")
    return NERResult(
        entities=[Entity(**{key: pred[key] for key in keys}) for pred in predictions]
    )
@app.post("/extract/batch", response_model=list[NERResult])
# @limiter.limit("10/minute")  # NOTE: rate limiting is disabled in this revision
def extract_batch(
    request: Request,
    input: BatchNERInput,
    _: bool = Depends(verify_api_key),
):
    """Extract entities from up to 50 texts in one request.

    The response list is aligned with ``input.texts`` by position: result
    ``i`` holds the entities for text ``i``. Empty or whitespace-only texts
    are not sent to the model. Each yields a result with no entities, so it
    never shifts the results of the texts that follow it.

    Each text is truncated to its first 2000 characters before inference.
    All texts share one label set: ``input.labels``, or ``DEFAULT_LABELS``
    when that is missing or empty.

    ``request`` is unused in the body. It is kept because slowapi's
    ``@limiter.limit`` requires it if rate limiting is re-enabled.

    Raises:
        HTTPException: 400 when more than 50 texts are submitted.
    """
    if len(input.texts) > 50:
        raise HTTPException(400, "Max 50 texts per batch")
    labels = input.labels or DEFAULT_LABELS

    results = []
    for text in input.texts:
        if not text.strip():
            # Placeholder keeps results[i] corresponding to texts[i]; skipping
            # the text would misalign every subsequent result.
            results.append(NERResult(entities=[]))
            continue
        entities = model.predict_entities(text[:2000], labels, threshold=0.5)
        results.append(NERResult(entities=[
            Entity(
                text=e["text"],
                label=e["label"],
                score=e["score"],
                start=e["start"],
                end=e["end"],
            )
            for e in entities
        ]))
    return results