File size: 2,553 Bytes
411f23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
HuggingFace Inference Endpoint handler for the Menon nb-bert relevance scorer.

HuggingFace's Endpoints platform expects a class named `EndpointHandler` with:
  - __init__(self, path: str)  → load model from `path`
  - __call__(self, data: dict) → run inference and return results

This handler wraps the same pipeline as `score.py` so the endpoint behaves
identically to local scoring: needs_review gate → tokenize → model → threshold.

REQUEST FORMAT
  POST /predict
  body: {"inputs": "Anskaffelse av samfunnsøkonomisk analyse..."}
   or:  {"inputs": ["text1", "text2", ...]}

RESPONSE
  list of dicts, one per input:
  - {"label": "RELEVANT", "score": 0.83, "threshold": 0.26, "reason": "ok"}
  - {"label": "NOT_RELEVANT", "score": 0.05, "threshold": 0.26, "reason": "ok"}
  - {"label": "needs_review", "score": None, "reason": "non_norwegian(en)"}
"""

import json
from pathlib import Path
from typing import Any, Dict, List

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from inference_rules import needs_review


class EndpointHandler:
    """Relevance scorer exposed through the HuggingFace Endpoints handler contract.

    Each input text first goes through the ``needs_review`` gate. Texts that
    pass are tokenized, run through the sequence-classification model, and
    labelled by comparing the class-1 softmax probability with the decision
    threshold stored next to the model weights.
    """

    # Every text is truncated / padded to exactly this many tokens.
    MAX_LENGTH = 256

    def __init__(self, path: str = ""):
        """Load tokenizer, model and decision threshold from ``path``.

        Args:
            path: Directory holding the saved tokenizer/model files and a
                ``threshold.json`` file containing a ``"threshold"`` key.

        Raises:
            FileNotFoundError: if ``threshold.json`` is missing.
            KeyError: if ``threshold.json`` has no ``"threshold"`` key.
            ValueError, TypeError: if the stored threshold is not numeric.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.eval()
        self.threshold = self._load_threshold(path)

    @staticmethod
    def _load_threshold(path: str) -> float:
        """Read the decision threshold from ``<path>/threshold.json`` as a float."""
        with open(Path(path) / "threshold.json", encoding="utf-8") as f:
            # Cast eagerly so a malformed value fails at startup instead of
            # raising on the first request's ``score >= threshold`` comparison.
            return float(json.load(f)["threshold"])

    @staticmethod
    def _as_text_list(inputs: Any) -> List[str]:
        """Normalize the request's ``inputs`` field into a list of strings.

        A single string becomes a one-element list. A list/tuple is converted
        item by item: ``None`` items become ``""`` and any other non-string
        item goes through ``str()``. The review gate and the tokenizer then
        see exactly the same text.

        Raises:
            ValueError: if ``inputs`` is neither a string nor a list/tuple.
        """
        if isinstance(inputs, str):
            return [inputs]
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "'inputs' must be a string or a list of strings, "
                f"got {type(inputs).__name__}"
            )
        return ["" if item is None else str(item) for item in inputs]

    def _score(self, text: str) -> float:
        """Return the model's class-1 (RELEVANT) softmax probability for ``text``."""
        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = self.model(**enc).logits
        # Assumes a two-class head where column 1 is the "relevant" class.
        return torch.softmax(logits, dim=1)[0, 1].item()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every text in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is a string or a list of
                strings; a missing or null ``"inputs"`` is treated as ``""``.

        Returns:
            One dict per input, in input order:
              - gated texts: ``{"label": "needs_review", "score": None,
                "reason": <reason from needs_review>}``
              - scored texts: ``{"label": "RELEVANT" | "NOT_RELEVANT",
                "score": <float>, "threshold": <float>, "reason": "ok"}``

        Raises:
            ValueError: if ``"inputs"`` is neither a string nor a list/tuple.
        """
        raw = data.get("inputs")
        texts = self._as_text_list("" if raw is None else raw)

        results: List[Dict[str, Any]] = []
        for text in texts:
            flag, reason = needs_review(text)
            if flag:
                results.append({
                    "label": "needs_review",
                    "score": None,
                    "reason": reason,
                })
                continue

            score = self._score(text)
            results.append({
                "label": "RELEVANT" if score >= self.threshold else "NOT_RELEVANT",
                "score": score,
                "threshold": self.threshold,
                "reason": "ok",
            })
        return results