Update handler.py
Browse files- handler.py +11 -9
handler.py
CHANGED
|
@@ -1,25 +1,27 @@
|
|
|
|
|
| 1 |
from transformers import AutoTokenizer
|
| 2 |
from gliner import GLiNER
|
| 3 |
-
from huggingface_inference_toolkit.base import BaseHandler
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
|
| 9 |
self.initialized = True
|
| 10 |
|
| 11 |
-
def __call__(self, data):
|
| 12 |
"""
|
| 13 |
Args:
|
| 14 |
-
data: Dictionary
|
| 15 |
-
- text (str): Input text
|
| 16 |
- labels (str): Comma-separated labels
|
| 17 |
- threshold (float, optional): Confidence threshold
|
| 18 |
- nested_ner (bool, optional): Enable nested NER
|
| 19 |
Returns:
|
| 20 |
-
Dictionary with predicted entities
|
| 21 |
"""
|
| 22 |
-
# Get inputs
|
| 23 |
text = data.pop("inputs", data.get("text", ""))
|
| 24 |
labels = data.get("labels", "").split(",")
|
| 25 |
threshold = float(data.get("threshold", 0.3))
|
|
|
|
| 1 |
+
from typing import Dict, Any, List
|
| 2 |
from transformers import AutoTokenizer
|
| 3 |
from gliner import GLiNER
|
|
|
|
| 4 |
|
| 5 |
+
|
| 6 |
+
class EndpointHandler:
|
| 7 |
+
def __init__(self, path: str = ""):
|
| 8 |
+
"""Load the GLiNER model from `path` (falling back to "urchade/gliner_multi-v2.1" when `path` is empty), load the "microsoft/deberta-v3-large" tokenizer, and set `self.initialized` to True."""
|
| 9 |
+
self.model = GLiNER.from_pretrained(path if path else "urchade/gliner_multi-v2.1")
|
| 10 |
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
|
| 11 |
self.initialized = True
|
| 12 |
|
| 13 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
|
| 14 |
"""
|
| 15 |
Args:
|
| 16 |
+
data (Dict[str, Any]): Dictionary containing:
|
| 17 |
+
- inputs/text (str): Input text
|
| 18 |
- labels (str): Comma-separated labels
|
| 19 |
- threshold (float, optional): Confidence threshold
|
| 20 |
- nested_ner (bool, optional): Enable nested NER
|
| 21 |
Returns:
|
| 22 |
+
Dict[str, List[Dict[str, Any]]]: Dictionary with predicted entities
|
| 23 |
"""
|
| 24 |
+
# Get inputs - handle both "inputs" and "text" keys for flexibility
|
| 25 |
text = data.pop("inputs", data.get("text", ""))
|
| 26 |
labels = data.get("labels", "").split(",")
|
| 27 |
threshold = float(data.get("threshold", 0.3))
|