|
|
""" |
|
|
Custom Hugging Face inference-endpoint handler for the Vietnamese part-of-speech (POS) tagger.
|
|
|
|
|
Supports two model formats: |
|
|
- CRFsuite format (.crfsuite) - loaded with pycrfsuite |
|
|
- underthesea-core format (.crf) - loaded with underthesea_core |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
from typing import Dict, List, Any |
|
|
|
|
|
|
|
|
# Optional CRF backends. Each flag records whether its backend could be
# imported; the handler checks the flag before loading a model of that format.
try:
    import pycrfsuite
except ImportError:
    HAS_PYCRFSUITE = False
else:
    HAS_PYCRFSUITE = True

# underthesea_core may expose the CRF classes at the package top level or only
# in its nested extension module, so both import locations are tried.
try:
    from underthesea_core import CRFModel, CRFTagger
except ImportError:
    try:
        from underthesea_core.underthesea_core import CRFModel, CRFTagger
    except ImportError:
        HAS_UNDERTHESEA_CORE = False
    else:
        HAS_UNDERTHESEA_CORE = True
else:
    HAS_UNDERTHESEA_CORE = True
|
|
|
|
|
|
|
|
class PythonCRFFeaturizer:
    """
    Python implementation of CRFFeaturizer compatible with underthesea_core API.

    Each feature template names one or more token offsets relative to the
    current position, optionally followed by an attribute, e.g. ``T[0]``,
    ``T[-1].lower``, ``T[0].suffix3``, ``T[-1,0]`` or ``T[0,1].is_in_dict``.
    ``extract_features`` returns a dict mapping each template string to its
    string value at one token position. Offsets before the first token yield
    ``'__BOS__'``; offsets past the last token yield ``'__EOS__'``.

    NOTE(review): the produced feature strings must match those seen when the
    model was trained; any change to the values below requires retraining.
    """

    # Matches the start of a template: ``T[<indices>]`` plus optional ``.<attr>``.
    # ``match`` anchors only at the beginning, so trailing text is ignored.
    _TEMPLATE_RE = re.compile(r'T\[([^\]]+)\](?:\.(\w+))?')

    # Attributes rendered as the string form of the same-named ``str`` predicate.
    _STR_PREDICATES = frozenset({'istitle', 'isupper', 'islower', 'isdigit', 'isalpha'})

    def __init__(self, feature_templates, dictionary=None):
        """
        Args:
            feature_templates: sequence of template strings (see class docstring);
                it is iterated on every ``extract_features`` call.
            dictionary: optional collection of known words/phrases used by the
                ``is_in_dict`` attribute; a falsy value becomes an empty set.
        """
        self.feature_templates = feature_templates
        self.dictionary = dictionary or set()
        # Parsing a template is pure, so each template is parsed once and the
        # result is reused for every token of every sentence.
        self._template_cache = {}

    def _parse_template(self, template):
        """
        Parse ``template`` into ``(indices, attribute, template)``.

        Returns ``(None, None, None)`` when the template does not start with
        ``T[...]``. Raises ``ValueError`` if an index is not an integer.
        """
        match = self._TEMPLATE_RE.match(template)
        if not match:
            return None, None, None
        indices_str, attribute = match.groups()
        indices = [int(i.strip()) for i in indices_str.split(',')]
        return indices, attribute, template

    def _cached_parse(self, template):
        """Return ``_parse_template(template)``, memoized per instance."""
        parsed = self._template_cache.get(template)
        if parsed is None:
            parsed = self._template_cache[template] = self._parse_template(template)
        return parsed

    def _get_token_value(self, tokens, position, index):
        """Return the token at ``position + index``, or a BOS/EOS marker when out of range."""
        actual_pos = position + index
        if actual_pos < 0:
            return '__BOS__'
        if actual_pos >= len(tokens):
            return '__EOS__'
        return tokens[actual_pos]

    @staticmethod
    def _affix_length(attribute):
        """Length N encoded in ``prefixN`` / ``suffixN``; 2 when N is absent."""
        digits = attribute[6:]  # both 'prefix' and 'suffix' are 6 characters long
        return int(digits) if digits else 2

    def _apply_attribute(self, value, attribute):
        """
        Transform a single token ``value`` according to ``attribute``.

        Boundary markers and attribute-less templates pass through unchanged;
        unknown attributes also return the raw value.
        """
        if attribute is None or value in ('__BOS__', '__EOS__'):
            return value
        if attribute == 'lower':
            return value.lower()
        if attribute == 'upper':
            return value.upper()
        if attribute in self._STR_PREDICATES:
            return str(getattr(value, attribute)())
        if attribute == 'is_in_dict':
            return str(value in self.dictionary)
        if attribute.startswith('prefix'):
            # Slicing already clamps to the whole string when it is shorter than N.
            return value[:self._affix_length(attribute)]
        if attribute.startswith('suffix'):
            return value[-self._affix_length(attribute):]
        return value

    def extract_features(self, tokens, position):
        """
        Compute all template features for ``tokens[position]``.

        Args:
            tokens: list of token strings for one sentence.
            position: index of the token being featurized.

        Returns:
            Dict mapping template string -> feature value string. Templates
            that do not parse as ``T[...]`` are skipped. Multi-offset templates
            join their tokens with ``'|'``, except ``is_in_dict``, which looks up
            the space-joined phrase. Other attributes on multi-offset templates
            are ignored.
        """
        features = {}
        for template in self.feature_templates:
            indices, attribute, template_str = self._cached_parse(template)
            if indices is None:
                continue
            if len(indices) == 1:
                value = self._get_token_value(tokens, position, indices[0])
                features[template_str] = self._apply_attribute(value, attribute)
                continue
            values = [self._get_token_value(tokens, position, idx) for idx in indices]
            if attribute == 'is_in_dict':
                features[template_str] = str(' '.join(values) in self.dictionary)
            else:
                features[template_str] = '|'.join(values)
        return features
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom Hugging Face inference handler for the Vietnamese POS tagger.

    On construction, the first model file found in ``path`` (searched in the
    order of ``_MODEL_CANDIDATES``) is loaded with the matching CRF backend.
    Calling the instance tags whitespace-tokenized text.
    """

    # Feature templates used at inference time.
    # NOTE(review): these must be the same templates used to train the model —
    # confirm against the training script before editing.
    _DEFAULT_FEATURE_TEMPLATES = (
        "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
        "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
        "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
        "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
        "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
        "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
        "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
    )

    # (file name, backend) pairs tried in order; the first existing file wins.
    _MODEL_CANDIDATES = (
        ("model.crfsuite", "pycrfsuite"),
        ("pos_tagger.crfsuite", "pycrfsuite"),
        ("model.crf", "underthesea-core"),
    )

    def __init__(self, path: str = ""):
        """
        Args:
            path: directory containing the model file.

        Raises:
            FileNotFoundError: if no candidate model file exists in ``path``.
            ImportError: if the backend required by the found model is missing.
        """
        self.feature_templates = list(self._DEFAULT_FEATURE_TEMPLATES)
        # NOTE(review): no dictionary is passed, so every ``is_in_dict`` feature
        # evaluates to "False" (or a BOS/EOS marker). If training used a word
        # dictionary, load and pass it here.
        self.featurizer = PythonCRFFeaturizer(self.feature_templates)

        model_path, model_format = self._find_model(path)
        self.model_format = model_format
        self.tagger = self._load_tagger(model_path, model_format)

    @classmethod
    def _find_model(cls, path: str):
        """
        Return ``(model_path, backend)`` for the first candidate file in ``path``.

        Raises:
            FileNotFoundError: listing every path checked, if none exists.
        """
        candidates = [(os.path.join(path, name), fmt) for name, fmt in cls._MODEL_CANDIDATES]
        for candidate, fmt in candidates:
            if os.path.exists(candidate):
                return candidate, fmt
        raise FileNotFoundError(
            f"No model found. Checked: {[c for c, _ in candidates]}"
        )

    @staticmethod
    def _load_tagger(model_path: str, model_format: str):
        """
        Open ``model_path`` with the backend named by ``model_format``.

        Raises:
            ImportError: if that backend could not be imported.
            ValueError: for an unrecognized ``model_format``.
        """
        if model_format == "pycrfsuite":
            if not HAS_PYCRFSUITE:
                raise ImportError("pycrfsuite not installed. Install with: pip install python-crfsuite")
            tagger = pycrfsuite.Tagger()
            tagger.open(model_path)
            return tagger
        if model_format == "underthesea-core":
            if not HAS_UNDERTHESEA_CORE:
                raise ImportError("underthesea-core not installed")
            return CRFTagger.from_model(CRFModel.load(model_path))
        raise ValueError(f"Unsupported model format: {model_format!r}")

    def _tokenize(self, text: str) -> List[str]:
        """Simple whitespace tokenization."""
        return text.strip().split()

    def _extract_features(self, tokens: List[str]) -> List[List[str]]:
        """Extract ``"template=value"`` feature strings for every token in a sentence."""
        return [
            [f"{k}={v}" for k, v in self.featurizer.extract_features(tokens, i).items()]
            for i in range(len(tokens))
        ]

    def _tag_text(self, text: str) -> List[Dict[str, Any]]:
        """Tag one text; returns ``[]`` without calling the tagger when it has no tokens."""
        tokens = self._tokenize(text)
        if not tokens:
            return []
        tags = self.tagger.tag(self._extract_features(tokens))
        return [{"token": token, "tag": tag} for token, tag in zip(tokens, tags)]

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """
        Handle inference requests.

        Args:
            data: Dict with an "inputs" key (falling back to "text") holding a
                string or a list of strings.

        Returns:
            For a single string, or a list of exactly one string, a list of
            ``{"token": ..., "tag": ...}`` dicts. For two or more texts, one
            such list per text. ``[]`` when there is no input at all.
        """
        inputs = data.get("inputs", data.get("text", ""))
        if inputs is None:
            return []
        if isinstance(inputs, str):
            inputs = [inputs]

        results = [self._tag_text(text) for text in inputs]
        if not results:
            # Previously ``results[0]`` raised IndexError for an empty batch.
            return []
        # A one-element batch is unwrapped like a bare string (kept for
        # backward compatibility with existing callers).
        return results if len(results) > 1 else results[0]
|
|
|