| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Prediction script for Bamboo-1 Vietnamese Dependency Parser. |
| |
| Usage: |
| # Interactive mode |
| uv run scripts/predict.py --model models/bamboo-1 |
| |
| # File input |
| uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu |
| |
| # Single sentence |
| uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam" |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| import click |
|
|
|
|
def format_tree_ascii(tokens, heads, deprels):
    """Format a dependency parse as a plain-text listing.

    Args:
        tokens: Word forms in sentence order.
        heads: 1-based position of each token's head; 0 marks the root.
            Must be at least as long as ``tokens``.
        deprels: Relation label for each token; at least as long as ``tokens``.

    Returns:
        A newline-joined string. The first row holds the 1-based positions and
        the second row the first three characters of each token, both
        right-aligned in 3-character columns. Then comes one line per token
        naming its head word (or ROOT) and the relation. The arrow is "<-" when
        the head lies to the right of the token and "->" when it lies to the
        left (or is the token itself).
    """
    count = len(tokens)
    position_row = " ".join(f"{pos:>3}" for pos in range(1, count + 1))
    token_row = " ".join(f"{word[:3]:>3}" for word in tokens)
    rows = [" " + position_row, " " + token_row]

    for idx, word in enumerate(tokens):
        # Index explicitly (rather than zip) so short heads/deprels still raise.
        head = heads[idx]
        label = deprels[idx]
        if head == 0:
            rows.append(f" {word} <- ROOT ({label})")
            continue
        direction = "<-" if head > idx + 1 else "->"
        rows.append(f" {word} {direction} {tokens[head - 1]} ({label})")

    return "\n".join(rows)
|
|
|
|
def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Render one parsed sentence as a CoNLL-U block.

    Only ID, FORM, HEAD and DEPREL are filled; every other column is "_".

    Args:
        tokens: Word forms in sentence order (become the FORM column).
        heads: Head index for each token (HEAD column).
        deprels: Relation label for each token (DEPREL column).
        sent_id: Emitted as a ``# sent_id`` comment when truthy.
        text: Emitted as a ``# text`` comment when truthy.

    Returns:
        The comment lines followed by one tab-separated row per token, joined
        by newlines and ending with a single trailing newline. Rows stop at the
        shortest of the three input sequences.
    """
    header = []
    if sent_id:
        header.append(f"# sent_id = {sent_id}")
    if text:
        header.append(f"# text = {text}")

    rows = [
        f"{idx}\t{form}\t_\t_\t_\t_\t{head}\t{label}\t_\t_"
        for idx, (form, head, label) in enumerate(zip(tokens, heads, deprels), start=1)
    ]

    # The trailing empty element makes the joined text end with "\n".
    return "\n".join(header + rows + [""])
|
|
|
|
def _render(results):
    """Join formatted parses so that each is followed by exactly one blank line."""
    # format_conllu / "simple" output already ends in "\n" while the tree format
    # does not; strip first so every block gets one terminating blank line.
    # CoNLL-U requires each sentence, including the last, to end with one.
    return "".join(result.rstrip("\n") + "\n\n" for result in results)


def _emit(results, output_file):
    """Write rendered *results* to *output_file*, or to stdout when not given."""
    rendered = _render(results)
    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(rendered)
        click.echo(f"Results saved to {output_file}", err=True)
    else:
        click.echo(rendered, nl=False)


@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input", "-i",
    "input_file",
    type=click.Path(exists=True, dir_okay=False),
    help="Input file (one sentence per line)",
)
@click.option(
    "--output", "-o",
    "output_file",
    help="Output file (CoNLL-U format unless --format is given)",
)
@click.option(
    "--text", "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default=None,
    help="Output format  [default: conllu with --output, otherwise simple]",
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
    # Imported lazily so `--help` does not pay for loading the parser stack.
    from underthesea.models.dependency_parser import DependencyParser

    # --output is documented as producing CoNLL-U, so that is the default
    # there; interactive/terminal use keeps the human-readable format.
    if output_format is None:
        output_format = "conllu" if output_file else "simple"

    # Status messages go to stderr so stdout carries only parse output
    # (e.g. `--format conllu > out.conllu` yields a clean file).
    click.echo(f"Loading model from {model}...", err=True)
    parser = DependencyParser.load(model)
    click.echo("Model loaded.\n", err=True)

    def parse_and_format(sentence, sent_id=None):
        """Parse *sentence* and return it rendered in the selected output format."""
        result = parser.predict(sentence)
        # NOTE(review): assumes each item is (token, head, deprel) with 1-based
        # integer heads and 0 = ROOT, as the head lookups below require.
        tokens = [r[0] for r in result]
        heads = [r[1] for r in result]
        deprels = [r[2] for r in result]

        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)
        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)

        lines = [f"Input: {sentence}", "Output:"]
        for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels), 1):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            lines.append(f" {i}. {token} -> {head_word} ({deprel})")
        return "\n".join(lines) + "\n"

    # Single sentence from the command line.
    if text:
        _emit([parse_and_format(text, sent_id=1)], output_file)
        return

    # Batch mode: one sentence per non-blank line.
    if input_file:
        click.echo(f"Reading from {input_file}...", err=True)
        # utf-8-sig drops a leading BOM (common for files saved on Windows),
        # which would otherwise be glued onto the first sentence.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            sentences = [line.strip() for line in f if line.strip()]

        total = len(sentences)
        click.echo(f"Parsing {total} sentences...", err=True)
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(parse_and_format(sentence, sent_id=i))
            if i % 100 == 0:
                click.echo(f" Processed {i}/{total}...", err=True)

        _emit(results, output_file)
        return

    # Interactive mode: read sentences from stdin until Ctrl+C / EOF.
    if output_file:
        click.echo("Note: --output is ignored in interactive mode.", err=True)
    click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n")
    sent_id = 1
    while True:
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            click.echo(_render([parse_and_format(sentence, sent_id=sent_id)]), nl=False)
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!")
            break
        except EOFError:
            # Ctrl+D: finish the prompt line so the shell prompt starts cleanly.
            click.echo()
            break
|
|
|
if __name__ == "__main__":
    # Run as a script: click reads sys.argv and dispatches to the command.
    predict()
|
|