# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "underthesea[deep]>=6.8.0",
#     "click>=8.0.0",
#     "torch>=2.0.0",
#     "transformers>=4.30.0",
# ]
# ///
"""
Prediction script for Bamboo-1 Vietnamese Dependency Parser.

Usage:
    # Interactive mode
    uv run scripts/predict.py --model models/bamboo-1

    # File input (written as CoNLL-U unless --format says otherwise)
    uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu

    # Single sentence
    uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam"

Parse results go to stdout (or to --output); status and progress messages go
to stderr, so redirecting stdout captures only the parse output.
"""

import sys
from pathlib import Path

import click


def format_tree_ascii(tokens, heads, deprels):
    """Format a dependency parse as a plain-text listing of arcs.

    Args:
        tokens: Word forms of the sentence, in order.
        heads: Head of each token; 0 means ROOT, otherwise a 1-based index
            into ``tokens``.
        deprels: Dependency relation label of each token.

    Returns:
        A multi-line string (no trailing newline): a two-line header with the
        token indices and the first three characters of each token, then one
        line per token naming its head and relation.
    """
    n = len(tokens)
    lines = []

    # Header
    lines.append(" " + " ".join(f"{i+1:>3}" for i in range(n)))
    lines.append(" " + " ".join(f"{t[:3]:>3}" for t in tokens))

    # Draw arcs
    for i in range(n):
        head = heads[i]
        if head == 0:
            lines.append(f" {tokens[i]} <- ROOT ({deprels[i]})")
        else:
            # The arrow shows the arc's direction in sentence order:
            # "<-" when the head lies to the right of the token, "->" when
            # it lies to the left.
            arrow = "<-" if head > i + 1 else "->"
            lines.append(f" {tokens[i]} {arrow} {tokens[head-1]} ({deprels[i]})")

    return "\n".join(lines)


def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Format a dependency parse as a CoNLL-U sentence block.

    Only the ID, FORM, HEAD and DEPREL columns are filled; every other
    column is ``_``.

    Args:
        tokens: Word forms of the sentence, in order.
        heads: Head of each token (0 for ROOT).
        deprels: Dependency relation label of each token.
        sent_id: Value for the ``# sent_id`` comment; omitted when falsy.
        text: Raw sentence for the ``# text`` comment; omitted when falsy.

    Returns:
        The sentence block, ending with a newline (it does not include the
        blank separator line; the caller adds that).
    """
    lines = []
    if sent_id:
        lines.append(f"# sent_id = {sent_id}")
    if text:
        lines.append(f"# text = {text}")

    for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels)):
        lines.append(f"{i+1}\t{token}\t_\t_\t_\t_\t{head}\t{deprel}\t_\t_")

    # The trailing empty element makes the joined block end with "\n".
    lines.append("")
    return "\n".join(lines)


@click.command()
@click.option(
    "--model",
    "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input",
    "-i",
    "input_file",
    type=click.Path(exists=True, dir_okay=False),
    help="Input file (one sentence per line)",
)
@click.option(
    "--output",
    "-o",
    "output_file",
    type=click.Path(dir_okay=False),
    help="Output file (CoNLL-U format unless --format is given)",
)
@click.option(
    "--text",
    "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default=None,
    help="Output format",
    show_default="simple; conllu when writing to --output",
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
    from underthesea.models.dependency_parser import DependencyParser

    # Only a --text or --input run writes to --output.
    writes_file = bool(output_file) and bool(text or input_file)
    if output_format is None:
        # --output is documented as producing CoNLL-U, so default to it when
        # a file will actually be written; terminal output stays readable.
        output_format = "conllu" if writes_file else "simple"

    # Status messages go to stderr so stdout carries only parse results
    # (e.g. `--format conllu > out.conllu` yields a clean file).
    click.echo(f"Loading model from {model}...", err=True)
    parser = DependencyParser.load(model)
    click.echo("Model loaded.\n", err=True)

    def parse_and_format(sentence, sent_id=None):
        """Parse one sentence and render it in the selected output format."""
        result = parser.predict(sentence)
        # NOTE(review): assumes predict() yields (token, head, deprel) items;
        # the code below treats head as a 1-based index with 0 meaning ROOT.
        tokens = [r[0] for r in result]
        heads = [r[1] for r in result]
        deprels = [r[2] for r in result]

        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)
        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)

        # simple
        lines = [f"Input: {sentence}", "Output:"]
        for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels), 1):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            lines.append(f" {i}. {token} -> {head_word} ({deprel})")
        return "\n".join(lines) + "\n"

    def echo_block(block):
        """Print one formatted result to stdout, followed by one blank line."""
        click.echo(block.rstrip("\n"))
        click.echo()

    def emit(results):
        """Write formatted results to --output if given, else to stdout."""
        if output_file:
            # Terminate every block with exactly one blank line: CoNLL-U
            # requires a blank line after each sentence, including the last.
            with open(output_file, "w", encoding="utf-8") as f:
                f.write("".join(block.rstrip("\n") + "\n\n" for block in results))
            click.echo(f"Results saved to {output_file}", err=True)
        else:
            for block in results:
                echo_block(block)

    # Single text mode
    if text:
        emit([parse_and_format(text, sent_id=1)])
        return

    # File mode
    if input_file:
        click.echo(f"Reading from {input_file}...", err=True)
        # utf-8-sig drops a leading BOM (common in files saved on Windows)
        # so it does not leak into the first sentence; otherwise same as utf-8.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            sentences = [line.strip() for line in f if line.strip()]

        total = len(sentences)
        click.echo(f"Parsing {total} sentences...", err=True)
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(parse_and_format(sentence, sent_id=i))
            if i % 100 == 0:
                click.echo(f" Processed {i}/{total}...", err=True)

        emit(results)
        return

    # Interactive mode
    if output_file:
        click.echo("Warning: --output is ignored in interactive mode.", err=True)
    click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n", err=True)
    sent_id = 1
    while True:
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            echo_block(parse_and_format(sentence, sent_id=sent_id))
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!", err=True)
            break
        except EOFError:
            # Finish the ">>> " prompt line so the shell prompt starts cleanly.
            click.echo(err=True)
            break


if __name__ == "__main__":
    predict()