# scripts/predict_word_segmentation.py
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "python-crfsuite>=0.9.11",
# "click>=8.0.0",
# "underthesea>=6.8.0",
# "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
# ]
# ///
"""
Prediction script for Vietnamese Word Segmentation.
Uses underthesea regex_tokenize to split text into syllables,
then applies CRF model at syllable level to decide word boundaries.
Usage:
uv run scripts/predict_word_segmentation.py "Trên thế giới, giá vàng đang giao dịch"
echo "Text here" | uv run scripts/predict_word_segmentation.py -
"""
import sys
import click
import pycrfsuite
from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize
def get_syllable_at(syllables, position, offset):
    """Return the syllable at position + offset, or a boundary marker.

    Indices before the start of the sequence yield "__BOS__"; indices at
    or past the end yield "__EOS__".
    """
    target = position + offset
    if target < 0:
        return "__BOS__"
    if target >= len(syllables):
        return "__EOS__"
    return syllables[target]
def is_punct(s):
    """Return True when *s* is exactly one non-alphanumeric character."""
    if len(s) != 1:
        return False
    return not s.isalnum()
def extract_syllable_features(syllables, position):
    """Build the CRF feature dict for the syllable at *position*.

    Features cover the current syllable (identity, lowercase, case/digit/
    punctuation flags, length, 2-char affixes), the raw and lowercased
    forms of a two-syllable window on each side, and bigram/trigram
    conjunctions around the current position.
    """
    sentinels = ("__BOS__", "__EOS__")
    # Materialize the context window once; boundary markers included.
    window = {
        off: get_syllable_at(syllables, position, off)
        for off in (-2, -1, 0, 1, 2)
    }
    cur = window[0]
    at_edge = cur in sentinels

    feats = {
        "S[0]": cur,
        "S[0].lower": cur if at_edge else cur.lower(),
        "S[0].istitle": "False" if at_edge else str(cur.istitle()),
        "S[0].isupper": "False" if at_edge else str(cur.isupper()),
        "S[0].isdigit": "False" if at_edge else str(cur.isdigit()),
        "S[0].ispunct": "False" if at_edge else str(is_punct(cur)),
        "S[0].len": "0" if at_edge else str(len(cur)),
        "S[0].prefix2": cur[:2] if not at_edge and len(cur) >= 2 else cur,
        "S[0].suffix2": cur[-2:] if not at_edge and len(cur) >= 2 else cur,
    }

    # Context syllables: raw form plus lowercased form (markers pass
    # through unchanged). Order matches the original insertion order.
    for off in (-1, -2, 1, 2):
        syl = window[off]
        feats[f"S[{off}]"] = syl
        feats[f"S[{off}].lower"] = syl if syl in sentinels else syl.lower()

    # Bigram and trigram conjunctions.
    feats["S[-1,0]"] = f"{window[-1]}|{cur}"
    feats["S[0,1]"] = f"{cur}|{window[1]}"
    feats["S[-1,0,1]"] = f"{window[-1]}|{cur}|{window[1]}"
    return feats
def sentence_to_syllable_features(syllables):
    """Encode each syllable position as a list of "key=value" feature strings."""
    encoded = []
    for idx in range(len(syllables)):
        feats = extract_syllable_features(syllables, idx)
        encoded.append([f"{name}={value}" for name, value in feats.items()])
    return encoded
def labels_to_words(syllables, labels):
    """Group a syllable sequence into words using BIO-style labels.

    A "B" label starts a new word; anything else continues the current
    one. Each returned word is its syllables joined with a single space.
    """
    groups = []
    for syl, tag in zip(syllables, labels):
        if tag == "B" or not groups:
            groups.append([syl])
        else:
            groups[-1].append(syl)
    return [" ".join(group) for group in groups]
def segment_text(text, tagger):
    """Full pipeline: regex tokenize -> CRF segment -> formatted words.

    Returns a single string where words are separated by spaces and the
    syllables of a compound word are joined with underscores, e.g.
    "Trên thế_giới". Returns "" when the input tokenizes to nothing.
    """
    # Step 1: Regex tokenize into syllables
    syllables = regex_tokenize(text)
    if not syllables:
        return ""
    # Step 2: Extract syllable features
    X = sentence_to_syllable_features(syllables)
    # Step 3: Predict BIO labels
    labels = tagger.tag(X)
    # Step 4: Convert to words (syllables joined with underscore for compound words)
    words = labels_to_words(syllables, labels)
    # BUG FIX: the previous replace() chain ("_".join(...).replace(" ", "_")
    # .replace("_", " ").replace(" ", " _ ")) first converted EVERY
    # underscore back to a space and then inserted " _ " between every
    # syllable, destroying all word-boundary information. Join compound
    # words with "_" and separate words with spaces instead, matching
    # segment_text_formatted(use_underscore=True).
    return " ".join(word.replace(" ", "_") for word in words)
def segment_text_formatted(text, tagger, use_underscore=True):
    """Run the full segmentation pipeline and format the output.

    With use_underscore=True the syllables of each compound word are
    glued together with "_"; otherwise words keep internal spaces.
    Returns "" when the input tokenizes to nothing.
    """
    syllables = regex_tokenize(text)
    if not syllables:
        return ""
    labels = tagger.tag(sentence_to_syllable_features(syllables))
    words = labels_to_words(syllables, labels)
    if not use_underscore:
        return " ".join(words)
    # Join compound word syllables with underscore
    return " ".join(word.replace(" ", "_") for word in words)
@click.command()
@click.argument("text", required=False)
@click.option(
    "--model", "-m",
    default="word_segmenter.crfsuite",
    help="Path to CRF model file",
    show_default=True,
)
@click.option(
    "--underscore/--no-underscore",
    default=True,
    help="Use underscore to join compound word syllables",
)
def main(text, model, underscore):
    """Segment Vietnamese text into words."""
    # A "-" argument (or none at all) means: read the text from stdin.
    if text in (None, "-"):
        text = sys.stdin.read().strip()
    if not text:
        click.echo("No input text provided", err=True)
        return
    # Load the model. A ".crf" extension selects the underthesea-core
    # loader; anything else is opened as a pycrfsuite model file.
    if model.endswith(".crf"):
        try:
            from underthesea_core import CRFModel, CRFTagger
        except ImportError:
            # Fall back to the inner extension module layout.
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
        tagger = CRFTagger.from_model(CRFModel.load(model))
    else:
        tagger = pycrfsuite.Tagger()
        tagger.open(model)
    # Segment each non-blank input line independently.
    for line in text.split("\n"):
        if not line.strip():
            continue
        click.echo(segment_text_formatted(line, tagger, use_underscore=underscore))

if __name__ == "__main__":
    main()