File size: 3,087 Bytes
50f98f3
 
 
 
5d8bdc8
 
50f98f3
 
 
 
 
 
 
5d8bdc8
 
50f98f3
 
 
5d8bdc8
50f98f3
 
5d8bdc8
 
 
50f98f3
 
 
 
5d8bdc8
 
 
50f98f3
 
 
5d8bdc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50f98f3
 
5d8bdc8
50f98f3
 
 
5d8bdc8
50f98f3
 
5d8bdc8
50f98f3
 
 
 
 
5d8bdc8
 
 
50f98f3
5d8bdc8
50f98f3
 
5d8bdc8
50f98f3
 
 
5d8bdc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "python-crfsuite>=0.9.11",
#     "click>=8.0.0",
#     "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
# ]
# ///
"""
Inference script for Vietnamese POS Tagger (TRE-1).

Usage:
    uv run scripts/predict.py "Tôi yêu Việt Nam"
    uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô"
    uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test"
    echo "Học sinh đang học bài" | uv run scripts/predict.py -
"""

import json
import sys
import os
from pathlib import Path

import click

# Put the directory one level above this script's own directory at the front
# of sys.path so the `handler` module imported below can be resolved from
# there (presumably handler.py lives in that directory — confirm).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Project root: the parent of the directory containing this script. Model
# directories are looked up relative to it (PROJECT_ROOT / "models" / ...).
PROJECT_ROOT = Path(__file__).parent.parent

# Must stay below the sys.path insertion above, which is what makes this
# import resolvable when the script is run directly.
from handler import EndpointHandler


def _version_sort_key(name):
    """Return a natural-order sort key for a version directory name.

    Runs of ASCII digits compare as integers and everything else compares as
    text, so "v1.10.0" sorts after "v1.9.0". Equal-width timestamp names keep
    the same relative order they have under plain string sorting.
    """
    import re  # local import so this block does not depend on file-level imports

    # re.split with a capturing group alternates text / digit chunks, so the
    # odd indices are always the digit runs.
    chunks = re.split(r"([0-9]+)", name)
    # Tag every chunk so an int is never compared against a str: numeric
    # chunks sort before text chunks, matching ASCII order of digits/letters.
    return tuple(
        (0, int(chunk), "") if index % 2 else (1, 0, chunk)
        for index, chunk in enumerate(chunks)
        if chunk
    )


def get_latest_version(task: str = "pos_tagger"):
    """Get the latest model version available for ``task``.

    Looks at the sub-directories of ``PROJECT_ROOT / "models" / task`` and
    returns the name that sorts last in natural order (digit runs compared
    numerically). This works both for timestamp-style names and for
    semver-style names such as ``v1.10.0``, which plain string sorting would
    rank below ``v1.9.0``.

    Args:
        task: Name of the task sub-directory under ``models``.

    Returns:
        The directory name of the latest version, or ``None`` when the task
        directory is missing, is not a directory, or contains no visible
        sub-directories.
    """
    models_dir = PROJECT_ROOT / "models" / task
    # is_dir() also covers the "path exists but is a file" case, where
    # iterdir() would raise instead of reporting "no models".
    if not models_dir.is_dir():
        return None
    # Hidden directories (e.g. editor/tool caches) are never model versions.
    versions = [
        entry.name
        for entry in models_dir.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    ]
    if not versions:
        return None
    # The raw name is a tie-breaker so names with equal natural keys
    # (e.g. "v01" and "v1") resolve deterministically.
    return max(versions, key=lambda name: (_version_sort_key(name), name))


@click.command()
@click.argument("text", default="-")
@click.option(
    "--version", "-v",
    default=None,
    help="Model version to use (default: latest)",
)
@click.option(
    "--model", "-m",
    default=None,
    help="Custom model directory path (overrides version-based path)",
)
@click.option(
    "--format", "-f",
    "output_format",
    type=click.Choice(["inline", "json", "conll"]),
    default="inline",
    help="Output format",
    show_default=True,
)
def predict(text, version, model, output_format):
    """Tag Vietnamese text with POS tags.

    TEXT is the input text to tag. Use '-' to read from stdin.
    """
    # NOTE: the docstring above is what click prints for --help, so
    # implementation notes are kept in comments.
    #
    # Parameters (all supplied by click):
    #   text          -- positional TEXT, "-" (the default) means stdin
    #   version       -- --version/-v, a directory name under models/pos_tagger
    #   model         -- --model/-m, explicit model directory; wins over version
    #   output_format -- one of "inline", "json", "conll"
    # Raises click.ClickException when no model version can be found or when
    # the input text is empty.

    # Resolve the model directory. An empty --model / --version value (e.g.
    # from an unset shell variable) is treated like an omitted one; previously
    # `--model ""` skipped the latest-version lookup and then crashed while
    # building the path from version=None.
    if model:
        model_path = model
    else:
        if not version:
            version = get_latest_version("pos_tagger")
            if version is None:
                raise click.ClickException("No models found in models/pos_tagger/")
        model_path = str(PROJECT_ROOT / "models" / "pos_tagger" / version)

    # Read input
    if text == "-":
        text = sys.stdin.read().strip()

    # Whitespace-only text is rejected whether it came from stdin or from the
    # command line; argument text is otherwise passed through unchanged.
    if not text.strip():
        raise click.ClickException("No input text provided")

    # Load model
    handler = EndpointHandler(path=model_path)

    # Predict
    result = handler({"inputs": text})

    # Format output. The "conll" and "inline" branches index each result item
    # with the "token" and "tag" keys (assumes the handler returns an iterable
    # of such mappings — confirm against handler.py).
    if output_format == "json":
        click.echo(json.dumps(result, ensure_ascii=False, indent=2))
    elif output_format == "conll":
        # One token per line: 1-based index, token, tag (tab-separated).
        for i, item in enumerate(result, 1):
            click.echo(f"{i}\t{item['token']}\t{item['tag']}")
    else:  # inline
        tagged = " ".join(f"{item['token']}/{item['tag']}" for item in result)
        click.echo(tagged)


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    predict()