|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Inference script for Vietnamese POS Tagger (TRE-1). |
|
|
|
|
|
Usage: |
|
|
uv run scripts/predict.py "Tôi yêu Việt Nam" |
|
|
uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô" |
|
|
uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test" |
|
|
echo "Học sinh đang học bài" | uv run scripts/predict.py - |
|
|
""" |
|
|
|
|
|
import json |
|
|
import sys |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
|
|
|
|
|
|
# Resolve the project root once, as an absolute path, so the import path
# below and the model lookups in this script always refer to the same
# directory, independent of the current working directory.
PROJECT_ROOT = Path(os.path.abspath(__file__)).parent.parent

# Put the project root first on sys.path so `handler` can be imported when
# this file is executed directly (e.g. `uv run scripts/predict.py`).
sys.path.insert(0, str(PROJECT_ROOT))

from handler import EndpointHandler  # noqa: E402  (requires the sys.path entry above)
|
|
|
|
|
|
|
|
def get_latest_version(task="pos_tagger", models_root=None):
    """Return the name of the newest model version directory for ``task``.

    Version directories are compared with a *natural* sort. Runs of digits
    are compared numerically and everything else as text. This orders
    semantic versions correctly (``v1.10.0`` after ``v1.9.0``) and keeps
    fixed-width timestamp names in chronological order. A plain string sort
    would wrongly pick ``v1.9.0`` over ``v1.10.0``. Plain files and hidden
    directories (names starting with ``.``) are ignored.

    Args:
        task: Sub-directory of the models root to search
            (default ``"pos_tagger"``).
        models_root: Directory containing one sub-directory per task.
            Defaults to ``PROJECT_ROOT / "models"``.

    Returns:
        The directory name of the latest version, or ``None`` when the task
        directory is missing (or is not a directory) or contains no version
        directories.
    """
    import re

    if models_root is None:
        models_root = PROJECT_ROOT / "models"
    models_dir = Path(models_root) / task

    # is_dir() rather than exists(): a stray file here would make iterdir() raise.
    if not models_dir.is_dir():
        return None

    versions = [
        entry.name
        for entry in models_dir.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    ]
    if not versions:
        return None

    def natural_key(name):
        # re.split with a capturing group alternates text / digit-run chunks,
        # so odd positions are always digit runs and can be compared as ints
        # without ever comparing an int against a str.
        parts = re.split(r"(\d+)", name)
        chunks = [int(part) if i % 2 else part for i, part in enumerate(parts)]
        # Fall back to the raw name so ties (e.g. "v01" vs "v1") are deterministic.
        return chunks, name

    return max(versions, key=natural_key)
|
|
|
|
|
|
|
|
@click.command()
@click.argument("text", default="-")
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to use (default: latest)",
)
@click.option(
    "--model", "-m",
    # Let click reject a missing path up front with a clear usage error
    # instead of failing somewhere inside the handler.
    type=click.Path(exists=True),
    default=None,
    help="Custom model directory path (overrides version-based path)",
)
@click.option(
    "--format", "-f",
    "output_format",
    type=click.Choice(["inline", "json", "conll"]),
    default="inline",
    help="Output format",
    show_default=True,
)
def predict(text, version, model, output_format):
    """Tag Vietnamese text with POS tags.

    TEXT is the input text to tag. Use '-' to read from stdin.
    """
    # (The docstring above doubles as click's --help text, so it is kept short.)
    #
    # Model resolution: an explicit --model path wins; otherwise use --version,
    # falling back to the latest version under models/pos_tagger/. Branching on
    # `model` alone avoids building a path from version=None when --model is
    # given as an empty string.
    if model:
        model_path = model
    else:
        if version is None:
            version = get_latest_version("pos_tagger")
            if version is None:
                raise click.ClickException("No models found in models/pos_tagger/")
        version_dir = PROJECT_ROOT / "models" / "pos_tagger" / version
        if not version_dir.is_dir():
            raise click.ClickException(
                f"Model version '{version}' not found in models/pos_tagger/"
            )
        model_path = str(version_dir)

    if text == "-":
        text = sys.stdin.read()
    # Strip both stdin and argument input so whitespace-only text is rejected
    # here rather than being sent to the model.
    text = text.strip()
    if not text:
        raise click.ClickException("No input text provided")

    handler = EndpointHandler(path=model_path)
    # The formatting below reads "token" and "tag" from each item, so the
    # handler is expected to return a list of such dicts.
    result = handler({"inputs": text})

    if output_format == "json":
        click.echo(json.dumps(result, ensure_ascii=False, indent=2))
    elif output_format == "conll":
        # One tab-separated line per token: 1-based index, token, tag.
        for i, item in enumerate(result, 1):
            click.echo(f"{i}\t{item['token']}\t{item['tag']}")
    else:
        # Inline "token/TAG" pairs separated by spaces.
        tagged = " ".join(f"{item['token']}/{item['tag']}" for item in result)
        click.echo(tagged)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses the
    # command-line TEXT argument and --version/--model/--format options.
    predict()
|
|
|