|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Prediction script for Bamboo-1 Vietnamese Dependency Parser. |
|
|
|
|
|
Usage: |
|
|
# Interactive mode |
|
|
uv run scripts/predict.py --model models/bamboo-1 |
|
|
|
|
|
# File input |
|
|
uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu --format conllu
|
|
|
|
|
# Single sentence |
|
|
uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam" |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import click |
|
|
|
|
|
|
|
|
def format_tree_ascii(tokens, heads, deprels):
    """Render a parsed sentence as a plain-text dependency listing.

    Two header rows come first: the 1-based token positions, then each token
    cut to its first three characters, both right-aligned in 3-character
    columns. Then one row per token names its head word (``ROOT`` when the
    head index is 0) and its relation label. For non-root tokens the arrow is
    ``<-`` when the head lies to the right of the token and ``->`` when it
    lies at or before the token's own position.

    Args:
        tokens: Token strings in sentence order.
        heads: Head index per token; 0 marks the root, otherwise 1-based.
        deprels: Relation label per token.

    Returns:
        The rows joined with newlines, without a trailing newline.
    """
    index_row = " " + " ".join(f"{pos:>3}" for pos in range(1, len(tokens) + 1))
    token_row = " " + " ".join(f"{word[:3]:>3}" for word in tokens)
    rows = [index_row, token_row]

    for pos, word in enumerate(tokens, start=1):
        head = heads[pos - 1]
        label = deprels[pos - 1]
        if head == 0:
            rows.append(f" {word} <- ROOT ({label})")
            continue
        direction = "->" if head <= pos else "<-"
        rows.append(f" {word} {direction} {tokens[head - 1]} ({label})")

    return "\n".join(rows)
|
|
|
|
|
|
|
|
def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Serialize one parsed sentence as a CoNLL-U block.

    Optional ``# sent_id`` and ``# text`` comment lines are emitted first, and
    only when the corresponding argument is truthy. Each token then gets one
    10-column, tab-separated row: ID (1-based), FORM, four ``_`` placeholders
    (LEMMA, UPOS, XPOS, FEATS), HEAD, DEPREL, and ``_`` for DEPS and MISC.
    Rows stop at the shortest of ``tokens``/``heads``/``deprels``.

    Returns:
        The block's lines joined with newlines plus one trailing newline (an
        empty string when there are no comments and no tokens).
    """
    comments = []
    if sent_id:
        comments.append(f"# sent_id = {sent_id}")
    if text:
        comments.append(f"# text = {text}")

    rows = [
        f"{idx}\t{form}\t_\t_\t_\t_\t{head}\t{rel}\t_\t_"
        for idx, (form, head, rel) in enumerate(zip(tokens, heads, deprels), start=1)
    ]

    # The trailing empty element gives the block its final newline.
    return "\n".join(comments + rows + [""])
|
|
|
|
|
|
|
|
@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input", "-i",
    "input_file",
    # Let click validate the path at argument-parsing time: a missing file
    # becomes a clean usage error instead of a traceback after the model load.
    type=click.Path(exists=True, dir_okay=False),
    help="Input file (one sentence per line)",
)
@click.option(
    "--output", "-o",
    "output_file",
    help="Output file (written in the format selected by --format)",
)
@click.option(
    "--text", "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default="simple",
    help="Output format",
    show_default=True,
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
    # Modes, in priority order: --text (one sentence), --input (a file with
    # one sentence per line), otherwise an interactive prompt.
    # Parse results go to stdout (or to --output). Status messages go to
    # stderr, so redirecting stdout yields clean output, e.g. a CoNLL-U stream.

    # Imported inside the command body, so the dependency is only loaded when
    # the command actually runs.
    from underthesea.models.dependency_parser import DependencyParser

    click.echo(f"Loading model from {model}...", err=True)
    parser = DependencyParser.load(model)
    click.echo("Model loaded.\n", err=True)

    def parse_and_render(sentence, sent_id=None):
        """Parse *sentence* and return it rendered in the selected format."""
        # NOTE(review): assumes parser.predict returns (token, head, deprel)
        # items with 1-based integer heads and 0 for ROOT, which is how all
        # the formatters consume them -- confirm against underthesea.
        result = parser.predict(sentence)
        tokens = [r[0] for r in result]
        heads = [r[1] for r in result]
        deprels = [r[2] for r in result]

        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)
        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)

        lines = [f"Input: {sentence}", "Output:"]
        for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels), 1):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            lines.append(f" {i}. {token} -> {head_word} ({deprel})")
        return "\n".join(lines) + "\n"

    def finish_block(rendered):
        """Return *rendered* with its trailing newlines collapsed to exactly one.

        The renderings end differently: CoNLL-U and "simple" end with a
        newline, "tree" does not. Normalizing first lets every writer append
        exactly one blank separator line. That line is also the blank line
        CoNLL-U requires after each sentence, including the last one.
        """
        return rendered.rstrip("\n") + "\n"

    def emit(rendered):
        """Echo one rendered sentence to stdout, followed by one blank line."""
        click.echo(finish_block(rendered))

    if text:
        emit(parse_and_render(text, sent_id=1))
        return

    if input_file:
        click.echo(f"Reading from {input_file}...", err=True)
        with open(input_file, "r", encoding="utf-8") as f:
            sentences = [s for s in (line.strip() for line in f) if s]

        click.echo(f"Parsing {len(sentences)} sentences...", err=True)
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(parse_and_render(sentence, sent_id=i))
            if i % 100 == 0:
                click.echo(f" Processed {i}/{len(sentences)}...", err=True)

        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                # Each sentence block is followed by one blank line. For
                # CoNLL-U this also terminates the final sentence, which the
                # previous "\n".join(...) left unterminated.
                f.write("".join(finish_block(r) + "\n" for r in results))
            click.echo(f"Results saved to {output_file}", err=True)
        else:
            for result in results:
                emit(result)
        return

    click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n", err=True)
    sent_id = 1
    while True:
        # The try deliberately wraps parsing too, so Ctrl+C during a slow
        # parse still exits cleanly with the goodbye message.
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            emit(parse_and_render(sentence, sent_id=sent_id))
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!", err=True)
            break
        except EOFError:
            break
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: calling the click command parses the command-line
    # arguments and runs `predict`.
    predict()
|
|
|