|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Prediction script for Vietnamese Word Segmentation. |
|
|
|
|
|
Uses underthesea regex_tokenize to split text into syllables, |
|
|
then applies CRF model at syllable level to decide word boundaries. |
|
|
|
|
|
Usage: |
|
|
uv run scripts/predict_word_segmentation.py "Trên thế giới, giá vàng đang giao dịch" |
|
|
echo "Text here" | uv run scripts/predict_word_segmentation.py - |
|
|
""" |
|
|
|
|
|
import sys |
|
|
|
|
|
import click |
|
|
import pycrfsuite |
|
|
from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize |
|
|
|
|
|
|
|
|
def get_syllable_at(syllables, position, offset):
    """Return the syllable at position + offset.

    Indices falling before the start of the sequence yield the sentinel
    "__BOS__"; indices past the end yield "__EOS__".
    """
    target = position + offset
    if target < 0:
        return "__BOS__"
    if target < len(syllables):
        return syllables[target]
    return "__EOS__"
|
|
|
|
|
|
|
|
def is_punct(s):
    """Return True when *s* is a single non-alphanumeric character."""
    if len(s) != 1:
        return False
    return not s.isalnum()
|
|
|
|
|
|
|
|
def extract_syllable_features(syllables, position):
    """Build the CRF feature dict for the syllable at *position*.

    Combines surface features of the current syllable with a context
    window of two syllables on each side, plus bigram and trigram
    context features.  Out-of-sentence positions are represented by the
    "__BOS__"/"__EOS__" sentinels, for which the character-class
    features degrade to fixed values.
    """
    sentinels = ("__BOS__", "__EOS__")

    # Context window: two syllables left and right of the current one.
    s_2, s_1, s0, s1, s2 = (
        get_syllable_at(syllables, position, off) for off in (-2, -1, 0, 1, 2)
    )
    at_edge = s0 in sentinels

    feats = {"S[0]": s0}
    if at_edge:
        # Sentinel tokens carry no surface information.
        feats["S[0].lower"] = s0
        feats["S[0].istitle"] = "False"
        feats["S[0].isupper"] = "False"
        feats["S[0].isdigit"] = "False"
        feats["S[0].ispunct"] = "False"
        feats["S[0].len"] = "0"
        feats["S[0].prefix2"] = s0
        feats["S[0].suffix2"] = s0
    else:
        feats["S[0].lower"] = s0.lower()
        feats["S[0].istitle"] = str(s0.istitle())
        feats["S[0].isupper"] = str(s0.isupper())
        feats["S[0].isdigit"] = str(s0.isdigit())
        feats["S[0].ispunct"] = str(is_punct(s0))
        feats["S[0].len"] = str(len(s0))
        feats["S[0].prefix2"] = s0[:2] if len(s0) >= 2 else s0
        feats["S[0].suffix2"] = s0[-2:] if len(s0) >= 2 else s0

    # Neighbouring syllables, lowercased unless they are sentinels.
    # NOTE: insertion order matters downstream (feature lists are built
    # from dict iteration order), so it mirrors the historical order.
    for key, syl in (("S[-1]", s_1), ("S[-2]", s_2), ("S[1]", s1), ("S[2]", s2)):
        feats[key] = syl
        feats[key + ".lower"] = syl if syl in sentinels else syl.lower()

    # Bigram and trigram context features.
    feats["S[-1,0]"] = f"{s_1}|{s0}"
    feats["S[0,1]"] = f"{s0}|{s1}"
    feats["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"

    return feats
|
|
|
|
|
|
|
|
def sentence_to_syllable_features(syllables):
    """Turn a syllable sequence into per-syllable "key=value" feature lists."""
    feature_rows = []
    for idx in range(len(syllables)):
        row = extract_syllable_features(syllables, idx)
        feature_rows.append([f"{name}={value}" for name, value in row.items()])
    return feature_rows
|
|
|
|
|
|
|
|
def labels_to_words(syllables, labels):
    """Reassemble words from a syllable sequence and its BIO labels.

    A "B" label opens a new word; any other label continues the current
    one (a leading "I" simply starts the first word).  Each returned
    word is its syllables joined by single spaces.
    """
    words = []
    buffer = []
    for syllable, tag in zip(syllables, labels):
        if tag == "B" and buffer:
            # Flush the completed word before starting the next one.
            words.append(" ".join(buffer))
            buffer = []
        buffer.append(syllable)
    if buffer:
        words.append(" ".join(buffer))
    return words
|
|
|
|
|
|
|
|
def segment_text(text, tagger):
    """Segment *text* into words using the CRF *tagger*.

    Pipeline: regex tokenize into syllables -> extract syllable-level
    CRF features -> tag each syllable with a B/I boundary label ->
    merge labelled syllables back into words.

    Returns the segmented sentence with compound-word syllables joined
    by underscores and words separated by single spaces (the same
    format as ``segment_text_formatted`` with ``use_underscore=True``),
    or an empty string when the tokenizer yields no syllables.
    """
    syllables = regex_tokenize(text)
    if not syllables:
        return ""

    X = sentence_to_syllable_features(syllables)
    labels = tagger.tag(X)
    words = labels_to_words(syllables, labels)

    # BUG FIX: the previous replace() chain
    # ("_".join -> " "->"_" -> "_"->" " -> " "->" _ ") cancelled itself
    # out and inserted " _ " between *every* syllable, discarding the
    # predicted word boundaries.  Join each word's syllables with "_"
    # instead, preserving the boundaries the CRF produced.
    return " ".join(word.replace(" ", "_") for word in words)
|
|
|
|
|
|
|
|
def segment_text_formatted(text, tagger, use_underscore=True):
    """Run the full segmentation pipeline and format the result.

    With ``use_underscore=True`` the syllables inside each compound
    word are joined by "_"; otherwise words keep their internal spaces.
    Returns an empty string when the tokenizer yields no syllables.
    """
    syllables = regex_tokenize(text)
    if not syllables:
        return ""

    features = sentence_to_syllable_features(syllables)
    predicted = tagger.tag(features)
    words = labels_to_words(syllables, predicted)

    if not use_underscore:
        return " ".join(words)
    # Mark intra-word syllable boundaries with underscores.
    return " ".join(word.replace(" ", "_") for word in words)
|
|
|
|
|
|
|
|
@click.command()
@click.argument("text", required=False)
@click.option(
    "--model", "-m",
    default="word_segmenter.crfsuite",
    help="Path to CRF model file",
    show_default=True,
)
@click.option(
    "--underscore/--no-underscore",
    default=True,
    help="Use underscore to join compound word syllables",
)
def main(text, model, underscore):
    """Segment Vietnamese text into words."""
    # Read from stdin when no argument (or the conventional "-") is given.
    if text is None or text == "-":
        text = sys.stdin.read().strip()

    if not text:
        click.echo("No input text provided", err=True)
        return

    # ".crf" models go through the underthesea_core tagger; anything
    # else is treated as a pycrfsuite model file.
    if model.endswith(".crf"):
        try:
            from underthesea_core import CRFModel, CRFTagger
        except ImportError:
            # Fall back to the non-flattened module layout.
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
        crf_model = CRFModel.load(model)
        tagger = CRFTagger.from_model(crf_model)
    else:
        tagger = pycrfsuite.Tagger()
        tagger.open(model)

    # Segment line by line, skipping blank lines.
    for raw_line in text.split("\n"):
        if not raw_line.strip():
            continue
        result = segment_text_formatted(raw_line, tagger, use_underscore=underscore)
        click.echo(result)
|
|
|
|
|
|
|
|
# Script entry point: delegates to the click command.
if __name__ == "__main__":


    main()
|
|
|