# /// script
# requires-python = ">=3.9"
# dependencies = [
#   "python-crfsuite>=0.9.11",
#   "click>=8.0.0",
#   "underthesea-core @ file:///home/claude-user/projects/workspace_underthesea/underthesea-core-dev/extensions/underthesea_core/target/wheels/underthesea_core-1.0.7-cp312-cp312-manylinux_2_34_x86_64.whl",
# ]
# ///
# NOTE(review): the underthesea-core dependency above points at a local,
# absolute-path cp312 wheel, so it only resolves on that one machine under
# CPython 3.12, despite requires-python >= 3.9.
"""
Inference script for Vietnamese POS Tagger (TRE-1).

Usage:
    uv run scripts/predict.py "Tôi yêu Việt Nam"
    uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô"
    uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test"
    uv run scripts/predict.py --format conll "Tôi yêu Việt Nam"
    echo "Học sinh đang học bài" | uv run scripts/predict.py -
"""

import json
import os
import re
import sys
from pathlib import Path

import click

# Project root = parent of the directory holding this script. It is computed
# once from the absolute script path, so the import path and the models
# lookup always agree.
PROJECT_ROOT = Path(os.path.abspath(__file__)).parent.parent

# Make the project root importable so `handler` (which lives there) can be found.
sys.path.insert(0, str(PROJECT_ROOT))

from handler import EndpointHandler  # noqa: E402  (needs the sys.path tweak above)


def _version_sort_key(name):
    """Return a natural-order sort key for a version directory name.

    Digit runs compare numerically, so ``v1.10.0`` sorts after ``v1.9.0``.
    Equal-width timestamps such as ``20240101_120000`` order exactly as they
    would lexicographically. ``re.split`` with a capturing group always puts
    text at even indices and digit runs at odd indices, so two keys never
    compare an ``int`` against a ``str``. The raw name is appended as a
    tie-breaker (e.g. ``v01`` vs ``v1``) to keep the result deterministic.
    """
    chunks = re.split(r"(\d+)", name)
    parts = [int(chunk) if index % 2 else chunk for index, chunk in enumerate(chunks)]
    return parts, name


def get_latest_version(task="pos_tagger"):
    """Return the name of the newest model version directory for *task*.

    Looks in ``PROJECT_ROOT/models/<task>/``. Every non-hidden subdirectory
    is a candidate. The "latest" is the greatest name under natural
    (digit-aware) ordering.

    Args:
        task: Sub-folder of ``models/`` to search.

    Returns:
        The directory name of the latest version, or ``None`` when the task
        folder is missing or contains no version directories.
    """
    models_dir = PROJECT_ROOT / "models" / task
    if not models_dir.is_dir():
        return None

    # Skip hidden entries (e.g. ".cache") so they are never treated as versions.
    versions = [
        entry.name
        for entry in models_dir.iterdir()
        if entry.is_dir() and not entry.name.startswith(".")
    ]
    if not versions:
        return None
    return max(versions, key=_version_sort_key)


def _resolve_model_path(version, model):
    """Pick the model directory to load and make sure it exists.

    An explicit ``model`` path wins over ``version``. When neither is given,
    the latest version under ``models/pos_tagger/`` is used.

    Returns:
        The model path as a string. A user-supplied ``--model`` value is
        passed through unchanged.

    Raises:
        click.ClickException: if no version can be found, or the resolved
            path does not exist.
    """
    if model:
        model_path = model
    else:
        if version is None:
            version = get_latest_version("pos_tagger")
            if version is None:
                raise click.ClickException("No models found in models/pos_tagger/")
        model_path = str(PROJECT_ROOT / "models" / "pos_tagger" / version)

    # Fail with a clean CLI error instead of whatever EndpointHandler raises
    # for a missing path (e.g. a mistyped --version).
    if not os.path.exists(model_path):
        raise click.ClickException(f"Model path not found: {model_path}")
    return model_path


@click.command()
@click.argument("text", default="-")
@click.option(
    "--version",
    "-v",
    default=None,
    help="Model version to use (default: latest)",
)
@click.option(
    "--model",
    "-m",
    default=None,
    help="Custom model directory path (overrides version-based path)",
)
@click.option(
    "--format",
    "-f",
    "output_format",
    type=click.Choice(["inline", "json", "conll"]),
    default="inline",
    help="Output format",
    show_default=True,
)
def predict(text, version, model, output_format):
    """Tag Vietnamese text with POS tags.

    TEXT is the input text to tag. Use '-' to read from stdin.
    """
    # NOTE: the docstring above is rendered by click as the --help text.
    model_path = _resolve_model_path(version, model)

    # Read input; '-' (also the default) means read all of stdin.
    if text == "-":
        text = sys.stdin.read().strip()
    # Whitespace-only input is rejected too. The text itself is passed on
    # unchanged.
    if not text.strip():
        raise click.ClickException("No input text provided")

    # Load model and predict. Each result item is read below via its
    # 'token' and 'tag' keys.
    handler = EndpointHandler(path=model_path)
    result = handler({"inputs": text})

    # Format output
    if output_format == "json":
        click.echo(json.dumps(result, ensure_ascii=False, indent=2))
    elif output_format == "conll":
        # One "index<TAB>token<TAB>tag" line per token, 1-based index.
        for i, item in enumerate(result, 1):
            click.echo(f"{i}\t{item['token']}\t{item['tag']}")
    else:  # inline
        tagged = " ".join(f"{item['token']}/{item['tag']}" for item in result)
        click.echo(tagged)


if __name__ == "__main__":
    predict()