Token Classification
GLiNER
PyTorch
multilingual
alfonsovelp commited on
Commit
7d5296f
·
verified ·
1 Parent(s): 45bd82d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -9
handler.py CHANGED
@@ -1,25 +1,27 @@
 
1
  from transformers import AutoTokenizer
2
  from gliner import GLiNER
3
- from huggingface_inference_toolkit.base import BaseHandler
4
 
5
- class EndpointHandler(BaseHandler):
6
- def __init__(self, path=""):
7
- self.model = GLiNER.from_pretrained(path)
 
 
8
  self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
9
  self.initialized = True
10
 
11
- def __call__(self, data):
12
  """
13
  Args:
14
- data: Dictionary with:
15
- - text (str): Input text
16
  - labels (str): Comma-separated labels
17
  - threshold (float, optional): Confidence threshold
18
  - nested_ner (bool, optional): Enable nested NER
19
  Returns:
20
- Dictionary with predicted entities
21
  """
22
- # Get inputs
23
  text = data.pop("inputs", data.get("text", ""))
24
  labels = data.get("labels", "").split(",")
25
  threshold = float(data.get("threshold", 0.3))
 
1
+ from typing import Dict, Any, List
2
  from transformers import AutoTokenizer
3
  from gliner import GLiNER
 
4
 
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path: str = ""):
8
+ """Initialize the model and tokenizer"""
9
+ self.model = GLiNER.from_pretrained(path if path else "urchade/gliner_multi-v2.1")
10
  self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
11
  self.initialized = True
12
 
13
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
14
  """
15
  Args:
16
+ data (Dict[str, Any]): Dictionary containing:
17
+ - inputs/text (str): Input text
18
  - labels (str): Comma-separated labels
19
  - threshold (float, optional): Confidence threshold
20
  - nested_ner (bool, optional): Enable nested NER
21
  Returns:
22
+ Dict[str, List[Dict[str, Any]]]: Dictionary with predicted entities
23
  """
24
+ # Get inputs - handle both "inputs" and "text" keys for flexibility
25
  text = data.pop("inputs", data.get("text", ""))
26
  labels = data.get("labels", "").split(",")
27
  threshold = float(data.get("threshold", 0.3))