Token Classification · GLiNER · PyTorch · multilingual
alfonsovelp committed · verified
Commit 45bd82d · 1 Parent(s): 6477b30

Update handler.py

Files changed (1)
  1. handler.py +48 -120
handler.py CHANGED
@@ -1,120 +1,48 @@
- import re
-
-
- class TokenSplitterBase():
-     def __init__(self):
-         pass
-
-     def __call__(self, text) -> (str, int, int):
-         pass
-
-
- class WhitespaceTokenSplitter(TokenSplitterBase):
-     def __init__(self):
-         self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S')
-
-     def __call__(self, text):
-         for match in self.whitespace_pattern.finditer(text):
-             yield match.group(), match.start(), match.end()
-
-
- class SpaCyTokenSplitter(TokenSplitterBase):
-     def __init__(self, lang=None):
-         try:
-             import spacy  # noqa
-         except ModuleNotFoundError as error:
-             raise error.__class__(
-                 "Please install spacy with: `pip install spacy`"
-             )
-         if lang is None:
-             lang = 'en'  # Default to English if no language is specified
-         self.nlp = spacy.blank(lang)
-
-     def __call__(self, text):
-         doc = self.nlp(text)
-         for token in doc:
-             yield token.text, token.idx, token.idx + len(token.text)
-
-
- class MecabKoTokenSplitter(TokenSplitterBase):
-     def __init__(self):
-         try:
-             import mecab  # noqa
-         except ModuleNotFoundError as error:
-             raise error.__class__(
-                 "Please install python-mecab-ko with: `pip install python-mecab-ko`"
-             )
-         self.tagger = mecab.MeCab()
-
-     def __call__(self, text):
-         tokens = self.tagger.morphs(text)
-
-         last_idx = 0
-         for morph in tokens:
-             start_idx = text.find(morph, last_idx)
-             end_idx = start_idx + len(morph)
-             last_idx = end_idx
-             yield morph, start_idx, end_idx
-
- class JiebaTokenSplitter(TokenSplitterBase):
-     def __init__(self):
-         try:
-             import jieba  # noqa
-         except ModuleNotFoundError as error:
-             raise error.__class__(
-                 "Please install jieba with: `pip install jieba`"
-             )
-         self.tagger = jieba
-
-     def __call__(self, text):
-         tokens = self.tagger.cut(text)
-         last_idx = 0
-         for token in tokens:
-             start_idx = text.find(token, last_idx)
-             end_idx = start_idx + len(token)
-             last_idx = end_idx
-             yield token, start_idx, end_idx
-
- class HanLPTokenSplitter(TokenSplitterBase):
-     def __init__(self, model_name="FINE_ELECTRA_SMALL_ZH"):
-         try:
-             import hanlp  # noqa
-             import hanlp.pretrained
-         except ModuleNotFoundError as error:
-             raise error.__class__(
-                 "Please install hanlp with: `pip install hanlp`"
-             )
-
-         models = hanlp.pretrained.tok.ALL
-         if model_name not in models:
-             raise ValueError(f"HanLP: {model_name} is not available, choose between {models.keys()}")
-         url = models[model_name]
-         self.tagger = hanlp.load(url)
-
-     def __call__(self, text):
-         tokens = self.tagger(text)
-         last_idx = 0
-         for token in tokens:
-             start_idx = text.find(token, last_idx)
-             end_idx = start_idx + len(token)
-             last_idx = end_idx
-             yield token, start_idx, end_idx
-
- class WordsSplitter(TokenSplitterBase):
-     def __init__(self, splitter_type='whitespace'):
-         if splitter_type == 'whitespace':
-             self.splitter = WhitespaceTokenSplitter()
-         elif splitter_type == 'spacy':
-             self.splitter = SpaCyTokenSplitter()
-         elif splitter_type == 'mecab':
-             self.splitter = MecabKoTokenSplitter()
-         elif splitter_type == 'jieba':
-             self.splitter = JiebaTokenSplitter()
-         elif splitter_type == 'hanlp':
-             self.splitter = HanLPTokenSplitter()
-         else:
-             raise ValueError(f"{splitter_type} is not implemented, choose between 'whitespace', 'spacy', 'jieba', 'hanlp' and 'mecab'")
-
-     def __call__(self, text):
-         for token in self.splitter(text):
-             yield token
 
+ from transformers import AutoTokenizer
+ from gliner import GLiNER
+ from huggingface_inference_toolkit.base import BaseHandler
+
+ class EndpointHandler(BaseHandler):
+     def __init__(self, path=""):
+         self.model = GLiNER.from_pretrained(path)
+         self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
+         self.initialized = True
+
+     def __call__(self, data):
+         """
+         Args:
+             data: Dictionary with:
+                 - text (str): Input text
+                 - labels (str): Comma-separated labels
+                 - threshold (float, optional): Confidence threshold
+                 - nested_ner (bool, optional): Enable nested NER
+         Returns:
+             Dictionary with predicted entities
+         """
+         # Get inputs
+         text = data.pop("inputs", data.get("text", ""))
+         labels = [label.strip() for label in data.get("labels", "").split(",")]
+         threshold = float(data.get("threshold", 0.3))
+         nested_ner = bool(data.get("nested_ner", True))
+
+         # Run prediction
+         entities = self.model.predict_entities(
+             text,
+             labels,
+             flat_ner=not nested_ner,
+             threshold=threshold
+         )
+
+         # Format output
+         return {
+             "entities": [
+                 {
+                     "entity": entity["label"],
+                     "word": entity["text"],
+                     "start": entity["start"],
+                     "end": entity["end"],
+                     "score": entity["score"]
+                 }
+                 for entity in entities
+             ]
+         }
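
For a quick sanity check of the updated handler, a local smoke test along these lines should exercise it end to end. This is a sketch rather than part of the commit: the model path, example text, and labels are placeholder assumptions, and it assumes the script runs from the repository root where handler.py and the GLiNER weights live.

# Hypothetical smoke test for the new EndpointHandler (not part of the commit).
from handler import EndpointHandler

# Placeholder path: the directory containing this repository's GLiNER weights.
handler = EndpointHandler(path=".")

payload = {
    "inputs": "Marie Curie won the Nobel Prize in Physics in 1903.",
    "labels": "person, award, date",  # comma-separated string, as documented above
    "threshold": 0.3,
    "nested_ner": False,
}

result = handler(payload)
# Each predicted entity carries the label, surface form, character offsets, and score.
for entity in result["entities"]:
    print(entity["entity"], entity["word"], entity["start"], entity["end"], entity["score"])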