# tre-1 / handler.py
# Author: rain1024 — "Add word segmentation support and underthesea-core integration"
# Commit: 5d8bdc8
"""
Custom handler for Vietnamese POS Tagger inference on Hugging Face.
Supports two model formats:
- CRFsuite format (.crfsuite) - loaded with pycrfsuite
- underthesea-core format (.crf) - loaded with underthesea_core
"""
import os
import re
from typing import Dict, List, Any
# Try importing both taggers
try:
import pycrfsuite
HAS_PYCRFSUITE = True
except ImportError:
HAS_PYCRFSUITE = False
try:
from underthesea_core import CRFModel, CRFTagger
HAS_UNDERTHESEA_CORE = True
except ImportError:
try:
from underthesea_core.underthesea_core import CRFModel, CRFTagger
HAS_UNDERTHESEA_CORE = True
except ImportError:
HAS_UNDERTHESEA_CORE = False
class PythonCRFFeaturizer:
    """
    Python implementation of CRFFeaturizer compatible with underthesea_core API.

    Templates have the form ``T[<offsets>]`` or ``T[<offsets>].<attribute>``
    where ``<offsets>`` is a comma-separated list of token offsets relative to
    the current position (e.g. ``T[0]``, ``T[-1,0]``) and ``<attribute>`` is an
    optional transform such as ``lower``, ``istitle``, ``prefix2`` or
    ``is_in_dict``. Positions outside the sentence yield the boundary markers
    ``__BOS__`` / ``__EOS__``.
    """

    # Matches e.g. "T[-1,0].lower": group(1) = offsets, group(2) = attribute.
    _TEMPLATE_RE = re.compile(r'T\[([^\]]+)\](?:\.(\w+))?')

    def __init__(self, feature_templates, dictionary=None):
        """
        Args:
            feature_templates: list of template strings (see class docstring).
            dictionary: optional set of known words/phrases consulted by the
                ``is_in_dict`` attribute.
        """
        self.feature_templates = feature_templates
        self.dictionary = dictionary or set()
        # Parse every template once up front instead of re-running the regex
        # for each (token, template) pair inside extract_features.
        self._parsed_templates = [self._parse_template(t) for t in feature_templates]

    def _parse_template(self, template):
        """Parse a template string into ``(offsets, attribute, template)``.

        Returns ``(None, None, None)`` for strings that do not match the
        ``T[...]`` syntax so callers can skip them.
        """
        match = self._TEMPLATE_RE.match(template)
        if not match:
            return None, None, None
        offsets = [int(i.strip()) for i in match.group(1).split(',')]
        return offsets, match.group(2), template

    def _get_token_value(self, tokens, position, index):
        """Return the token at ``position + index`` or a boundary marker."""
        actual_pos = position + index
        if actual_pos < 0:
            return '__BOS__'
        elif actual_pos >= len(tokens):
            return '__EOS__'
        return tokens[actual_pos]

    def _apply_attribute(self, value, attribute):
        """Apply a template attribute to a token value.

        Boundary markers pass through untouched; unknown attributes fall back
        to the raw value. Predicate attributes return the strings
        ``"True"``/``"False"``.
        """
        if value in ('__BOS__', '__EOS__'):
            return value
        if attribute is None:
            return value
        elif attribute == 'lower':
            return value.lower()
        elif attribute == 'upper':
            return value.upper()
        elif attribute == 'istitle':
            return str(value.istitle())
        elif attribute == 'isupper':
            return str(value.isupper())
        elif attribute == 'islower':
            return str(value.islower())
        elif attribute == 'isdigit':
            return str(value.isdigit())
        elif attribute == 'isalpha':
            return str(value.isalpha())
        elif attribute == 'is_in_dict':
            return str(value in self.dictionary)
        elif attribute.startswith('prefix'):
            # "prefix3" -> first 3 chars; bare "prefix" defaults to 2.
            n = int(attribute[6:]) if len(attribute) > 6 else 2
            return value[:n] if len(value) >= n else value
        elif attribute.startswith('suffix'):
            # "suffix3" -> last 3 chars; bare "suffix" defaults to 2.
            n = int(attribute[6:]) if len(attribute) > 6 else 2
            return value[-n:] if len(value) >= n else value
        else:
            return value

    def extract_features(self, tokens, position):
        """Return the feature dict ``{template: value}`` for ``tokens[position]``."""
        features = {}
        for offsets, attribute, template_str in self._parsed_templates:
            if offsets is None:
                continue  # malformed template: skip silently
            if len(offsets) == 1:
                value = self._get_token_value(tokens, position, offsets[0])
                features[template_str] = self._apply_attribute(value, attribute)
            else:
                values = [self._get_token_value(tokens, position, off) for off in offsets]
                if attribute == 'is_in_dict':
                    # Multi-token dictionary lookup joins tokens with spaces.
                    features[template_str] = str(' '.join(values) in self.dictionary)
                else:
                    features[template_str] = '|'.join(values)
        return features
class EndpointHandler:
    """Hugging Face custom inference handler for Vietnamese POS tagging.

    Locates a CRF model under ``path`` in one of two formats (pycrfsuite
    ``.crfsuite`` preferred over underthesea-core ``.crf``) and tags
    whitespace-separated tokens via :meth:`__call__`.
    """

    def __init__(self, path: str = ""):
        """Locate and load the CRF model under ``path``.

        Raises:
            FileNotFoundError: no model file found at any known location.
            ImportError: the backend matching the found model is not installed.
        """
        # NOTE(review): these templates presumably match the ones the model
        # was trained with — confirm against the training pipeline.
        self.feature_templates = [
            "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
            "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
            "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
            "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
            "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
            "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
            "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
        ]
        self.featurizer = PythonCRFFeaturizer(self.feature_templates)
        # Load CRF model - check multiple possible locations and formats.
        # Priority: .crfsuite (pycrfsuite) > .crf (underthesea-core)
        # (the redundant function-local `import os` was removed; `os` is
        # already imported at module level).
        model_candidates = [
            (os.path.join(path, "model.crfsuite"), "pycrfsuite"),
            (os.path.join(path, "pos_tagger.crfsuite"), "pycrfsuite"),
            (os.path.join(path, "model.crf"), "underthesea-core"),
        ]
        model_path = None
        model_format = None
        for candidate, fmt in model_candidates:
            if os.path.exists(candidate):
                model_path = candidate
                model_format = fmt
                break
        if model_path is None:
            raise FileNotFoundError(
                f"No model found. Checked: {[c for c, _ in model_candidates]}"
            )
        # Load the model with whichever backend matches the found format.
        self.model_format = model_format
        if model_format == "pycrfsuite":
            if not HAS_PYCRFSUITE:
                raise ImportError("pycrfsuite not installed. Install with: pip install python-crfsuite")
            self.tagger = pycrfsuite.Tagger()
            self.tagger.open(model_path)
        elif model_format == "underthesea-core":
            if not HAS_UNDERTHESEA_CORE:
                raise ImportError("underthesea-core not installed")
            model = CRFModel.load(model_path)
            self.tagger = CRFTagger.from_model(model)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` on whitespace (no Vietnamese word segmentation here)."""
        return text.strip().split()

    def _extract_features(self, tokens: List[str]) -> List[List[str]]:
        """Build the ``"key=value"`` feature list for every token position."""
        return [
            [f"{k}={v}" for k, v in self.featurizer.extract_features(tokens, i).items()]
            for i in range(len(tokens))
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle an inference request.

        Args:
            data: dict with an "inputs" (or legacy "text") key holding a
                single string or a list of strings.

        Returns:
            For a single string input: a list of {"token", "tag"} dicts.
            For a list input: one such list per text — even when the list
            has zero or one element.
        """
        inputs = data.get("inputs", data.get("text", ""))
        # Key the unwrapping on the ORIGINAL input shape: a bare string maps
        # to a single result, a list always maps to a list of results. The
        # previous `len(results) > 1` check wrongly unwrapped one-element
        # list inputs and raised IndexError on an empty list input.
        single_input = isinstance(inputs, str)
        if single_input:
            inputs = [inputs]
        results = []
        for text in inputs:
            tokens = self._tokenize(text)
            if not tokens:
                results.append([])  # empty / whitespace-only text -> no tags
                continue
            features = self._extract_features(tokens)
            tags = self.tagger.tag(features)
            results.append(
                [{"token": token, "tag": tag} for token, tag in zip(tokens, tags)]
            )
        return results[0] if single_input else results