# Source: Menon-nb-bert-base-v2 / handler.py
# Uploaded by RozaA with huggingface_hub (commit 411f23a, verified).
"""
HuggingFace Inference Endpoint handler for the Menon nb-bert relevance scorer.
HuggingFace's Endpoints platform expects a class named `EndpointHandler` with:
- __init__(self, path: str) → load model from `path`
- __call__(self, data: dict) → run inference and return results
This handler wraps the same pipeline as `score.py` so the endpoint behaves
identically to local scoring: needs_review gate → tokenize → model → threshold.
REQUEST FORMAT
POST /predict
body: {"inputs": "Anskaffelse av samfunnsøkonomisk analyse..."}
or: {"inputs": ["text1", "text2", ...]}
RESPONSE
list of dicts, one per input:
- {"label": "RELEVANT", "score": 0.83, "threshold": 0.26, "reason": "ok"}
- {"label": "NOT_RELEVANT", "score": 0.05, "threshold": 0.26, "reason": "ok"}
- {"label": "needs_review", "score": None, "reason": "non_norwegian(en)"}
"""
import json
from pathlib import Path
from typing import Any, Dict, List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from inference_rules import needs_review
class EndpointHandler:
    """Relevance-scoring handler for a HuggingFace Inference Endpoint.

    The handler loads a tokenizer, a sequence-classification model and a
    decision threshold from the model directory, and is then called once per
    request with the decoded request body.

    Each input text first passes through ``needs_review``; flagged texts are
    returned with label ``"needs_review"`` and are never sent to the model.
    The remaining texts are tokenized and scored in batches, and the
    probability of class index 1 is compared against the threshold.
    """

    # Maximum number of tokens per text; longer texts are truncated.
    MAX_LENGTH = 256
    # Upper bound on texts per forward pass so that a large request cannot
    # build one arbitrarily large batch.
    BATCH_SIZE = 32

    def __init__(self, path: str = ""):
        """Load tokenizer, model and threshold from ``path``.

        Args:
            path: Directory containing the model files and ``threshold.json``.

        Raises:
            OSError: If ``threshold.json`` cannot be opened.
            KeyError: If ``threshold.json`` has no ``"threshold"`` entry.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Inference only: switch the model to evaluation mode.
        self.model.eval()
        # Explicit encoding so the read does not depend on the host locale.
        with open(Path(path) / "threshold.json", encoding="utf-8") as f:
            self.threshold = json.load(f)["threshold"]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the text(s) found under ``data["inputs"]``.

        Args:
            data: Request body. ``"inputs"`` is either a single string or an
                iterable of texts; a missing key is treated as ``""``.

        Returns:
            One dict per input, in input order. Flagged texts yield
            ``{"label": "needs_review", "score": None, "reason": ...}``;
            scored texts yield ``label`` (``"RELEVANT"`` or
            ``"NOT_RELEVANT"``), ``score``, ``threshold`` and
            ``"reason": "ok"``.

        Raises:
            TypeError: If ``"inputs"`` is neither a string nor iterable.
        """
        inputs = data.get("inputs", "")
        texts = [inputs] if isinstance(inputs, str) else list(inputs)

        # Pre-sized so results can be written back by original position.
        results: List[Dict[str, Any]] = [{} for _ in texts]
        pending_positions: List[int] = []
        pending_texts: List[str] = []

        # Pass 1: apply the review gate and collect the texts to score.
        for position, text in enumerate(texts):
            flag, reason = needs_review(text)
            if flag:
                results[position] = {
                    "label": "needs_review",
                    "score": None,
                    "reason": reason,
                }
            else:
                pending_positions.append(position)
                pending_texts.append(str(text))

        # Pass 2: score the remaining texts in bounded chunks. Nothing is
        # tokenized when the request is empty or every text was flagged.
        for start in range(0, len(pending_texts), self.BATCH_SIZE):
            stop = start + self.BATCH_SIZE
            scores = self._score_batch(pending_texts[start:stop])
            for position, score in zip(pending_positions[start:stop], scores):
                label = "RELEVANT" if score >= self.threshold else "NOT_RELEVANT"
                results[position] = {
                    "label": label,
                    "score": score,
                    "threshold": self.threshold,
                    "reason": "ok",
                }
        return results

    def _score_batch(self, texts: List[str]) -> List[float]:
        """Return the class-1 probability for each text in a non-empty batch."""
        enc = self.tokenizer(
            texts,
            truncation=True,
            # Pad only to the longest text in this batch instead of always
            # padding to MAX_LENGTH; padded positions are masked out through
            # the attention mask the tokenizer returns.
            padding=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = self.model(**enc).logits
        return torch.softmax(logits, dim=1)[:, 1].tolist()