File size: 3,087 Bytes
50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 50f98f3 5d8bdc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "python-crfsuite>=0.9.11",
# "click>=8.0.0",
# "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
# ]
# ///
"""
Inference script for Vietnamese POS Tagger (TRE-1).
Usage:
uv run scripts/predict.py "Tôi yêu Việt Nam"
uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô"
uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test"
echo "Học sinh đang học bài" | uv run scripts/predict.py -
"""
import json
import os
import re
import sys
from pathlib import Path

import click
# Project root is the parent of this script's directory (scripts/ -> root).
# Compute it once, as an absolute path, and reuse it both for the import path
# and for locating model directories so the two can never disagree.
PROJECT_ROOT = Path(os.path.abspath(__file__)).parent.parent
# Put the project root first on the import path so the top-level `handler`
# module imported below is found regardless of the current working directory.
sys.path.insert(0, str(PROJECT_ROOT))
from handler import EndpointHandler
def _version_sort_key(name):
    """Return a natural-order sort key for a model version directory name.

    Digit runs compare numerically, so ``v1.10.0`` sorts after ``v1.9.0``,
    while zero-padded timestamp names keep their chronological order.
    """
    # re.split with a capturing group alternates text and digit runs, always
    # starting with a (possibly empty) text run. Every key therefore holds str
    # at even positions and int at odd positions, so two keys always compare
    # element-by-element without mixing types.
    return [
        int(part) if index % 2 else part
        for index, part in enumerate(re.split(r"([0-9]+)", name))
    ]


def get_latest_version(task="pos_tagger", models_root=None):
    """Return the name of the newest model version directory for *task*.

    Version directories live in ``<models_root>/<task>/``. The newest one is
    the last in natural sort order, which handles both semantic-version names
    (``v1.10.0`` > ``v1.9.0``) and timestamp names.

    Args:
        task: Task subdirectory under the models root.
        models_root: Directory holding the per-task model folders. Defaults
            to ``PROJECT_ROOT / "models"``.

    Returns:
        The latest version directory name, or ``None`` if the task directory
        is missing (or is not a directory) or has no subdirectories.
    """
    root = PROJECT_ROOT / "models" if models_root is None else Path(models_root)
    models_dir = root / task
    # is_dir() rather than exists(): a stray file here would make iterdir() raise.
    if not models_dir.is_dir():
        return None
    versions = [entry.name for entry in models_dir.iterdir() if entry.is_dir()]
    if not versions:
        return None
    # Tie-break on the raw name (e.g. "v01" vs "v1") so the result does not
    # depend on the filesystem's iteration order.
    return max(versions, key=lambda name: (_version_sort_key(name), name))
@click.command()
@click.argument("text", default="-")
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to use (default: latest)",
)
@click.option(
    "--model", "-m",
    default=None,
    help="Custom model directory path (overrides version-based path)",
)
@click.option(
    "--format", "-f",
    "output_format",
    type=click.Choice(["inline", "json", "conll"]),
    default="inline",
    help="Output format",
    show_default=True,
)
def predict(text, version, model, output_format):
    """Tag Vietnamese text with POS tags.

    TEXT is the input text to tag. Use '-' to read from stdin.
    """
    # Resolve the model directory: an explicit --model wins; otherwise use
    # --version (or the latest version found) under models/pos_tagger/.
    # Note: this docstring doubles as click's --help text, so developer notes
    # live in comments instead.
    if model:
        model_path = model
    else:
        if version is None:
            version = get_latest_version("pos_tagger")
            if version is None:
                raise click.ClickException("No models found in models/pos_tagger/")
        model_path = str(PROJECT_ROOT / "models" / "pos_tagger" / version)

    # Fail early with a clear message instead of an opaque error from the
    # handler when the requested version/model directory does not exist.
    if not Path(model_path).exists():
        raise click.ClickException(f"Model path not found: {model_path}")

    # Read input; strip both the argument and stdin so whitespace-only input
    # is rejected the same way regardless of where it came from.
    if text == "-":
        text = sys.stdin.read()
    text = text.strip()
    if not text:
        raise click.ClickException("No input text provided")

    # Load model and predict. The formatting below reads `token` and `tag`
    # from each item of the handler result (assumed to be a list of dicts).
    handler = EndpointHandler(path=model_path)
    result = handler({"inputs": text})

    # Format output
    if output_format == "json":
        click.echo(json.dumps(result, ensure_ascii=False, indent=2))
    elif output_format == "conll":
        # One token per line: 1-based index, token, tag (tab-separated).
        for i, item in enumerate(result, 1):
            click.echo(f"{i}\t{item['token']}\t{item['tag']}")
    else:  # inline
        tagged = " ".join(f"{item['token']}/{item['tag']}" for item in result)
        click.echo(tagged)
# Script entry point: invoke the click `predict` command when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    predict()
|