# bamboo-1/scripts/predict.py
# rain1024 - Initial commit: Vietnamese dependency parser with Biaffine architecture (b85c683)
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "underthesea[deep]>=6.8.0",
# "click>=8.0.0",
# "torch>=2.0.0",
# "transformers>=4.30.0",
# ]
# ///
"""
Prediction script for Bamboo-1 Vietnamese Dependency Parser.
Usage:
# Interactive mode
uv run scripts/predict.py --model models/bamboo-1
# File input
uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu
# Single sentence
uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam"
"""
import sys
from pathlib import Path
import click
def format_tree_ascii(tokens, heads, deprels):
    """Format a dependency parse as a plain-text listing.

    The output starts with two header rows (the 1-based token positions, then
    the first three characters of every token, each right-aligned to width 3)
    followed by one line per token naming its head and relation label.

    Args:
        tokens: Token strings of the sentence, in order.
        heads: Head position of each token, 1-based; ``0`` marks the root.
        deprels: Dependency relation label of each token.

    Returns:
        The rendered lines joined with newlines (no trailing newline).
    """
    n = len(tokens)
    lines = [
        # Header
        " " + " ".join(f"{position:>3}" for position in range(1, n + 1)),
        " " + " ".join(f"{token[:3]:>3}" for token in tokens),
    ]

    # One line per token: "<token> <arrow> <head token> (<relation>)".
    rows = enumerate(zip(tokens, heads, deprels), start=1)
    for position, (token, head, deprel) in rows:
        if head == 0:
            lines.append(f" {token} <- ROOT ({deprel})")
            continue

        if 1 <= head <= n:
            head_label = tokens[head - 1]
        else:
            # Without this guard a head beyond the sentence raises IndexError
            # and a negative head silently picks a token counted from the
            # end; show the bad value instead of crashing or lying.
            head_label = f"<invalid head {head}>"

        # "<-" when the head comes later in the sentence, "->" otherwise.
        arrow = "<-" if head > position else "->"
        lines.append(f" {token} {arrow} {head_label} ({deprel})")

    return "\n".join(lines)
def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Format a parsed sentence as a CoNLL-U block.

    Only the ID, FORM, HEAD and DEPREL columns are filled in; the remaining
    six columns are written as ``_``.

    Args:
        tokens: Token strings of the sentence, in order.
        heads: Head position of each token (``0`` for the root).
        deprels: Dependency relation label of each token.
        sent_id: Optional sentence identifier, written as a ``# sent_id``
            comment whenever it is not ``None`` (so an id of ``0`` is kept).
        text: Optional raw sentence, written as a ``# text`` comment when
            non-empty.

    Returns:
        The comment and token lines, each terminated by a newline. Writing
        one more newline after the returned string yields the blank line that
        ends a CoNLL-U sentence.
    """
    lines = []
    # Compare against None explicitly: a plain truthiness test would drop a
    # legitimate sentence id of 0.
    if sent_id is not None:
        lines.append(f"# sent_id = {sent_id}")
    if text:
        lines.append(f"# text = {text}")

    lines.extend(
        f"{index}\t{token}\t_\t_\t_\t_\t{head}\t{deprel}\t_\t_"
        for index, (token, head, deprel) in enumerate(
            zip(tokens, heads, deprels), start=1
        )
    )

    # The empty last element makes the joined string end with a newline.
    lines.append("")
    return "\n".join(lines)
@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input", "-i",
    "input_file",
    # Validated by click, so a bad path is rejected before the model is loaded.
    type=click.Path(exists=True, dir_okay=False),
    help="Input file (one sentence per line)",
)
@click.option(
    "--output", "-o",
    "output_file",
    help="Output file (CoNLL-U format)",
)
@click.option(
    "--text", "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default="simple",
    help="Output format",
    show_default=True,
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
    # The docstring above doubles as the ``--help`` text, so the longer
    # description lives in comments.
    #
    # Modes, in priority order:
    #   1. --text   parse one sentence and print it
    #   2. --input  parse every non-blank line of a file, then write the
    #               results to --output or print them
    #   3. neither  interactive prompt until Ctrl+C / end of input
    #
    # Status and progress messages go to stderr so that stdout carries only
    # parser output and can be redirected into a file as-is.

    # Imported here rather than at module top, so ``--help`` and option
    # errors do not need underthesea to be imported.
    from underthesea.models.dependency_parser import DependencyParser

    click.echo(f"Loading model from {model}...", err=True)
    parser = DependencyParser.load(model)
    click.echo("Model loaded.\n", err=True)

    def parse_and_print(sentence, sent_id=None):
        """Parse one sentence and return it rendered in ``output_format``."""
        # Each prediction row is indexed as (token, head, deprel).
        rows = parser.predict(sentence)
        tokens = [row[0] for row in rows]
        heads = [row[1] for row in rows]
        deprels = [row[2] for row in rows]

        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)

        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)

        # simple
        parts = [f"Input: {sentence}\n", "Output:\n"]
        for index, (token, head, deprel) in enumerate(
            zip(tokens, heads, deprels), start=1
        ):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            parts.append(f" {index}. {token} -> {head_word} ({deprel})\n")
        return "".join(parts)

    def emit(rendered):
        """Print one result to stdout, followed by a separating blank line."""
        click.echo(rendered)
        # A CoNLL-U result already ends with a newline, so the echo above has
        # produced its terminating blank line; a second one would be extra.
        if output_format != "conllu":
            click.echo()

    # Single text mode
    if text:
        click.echo(parse_and_print(text, sent_id=1))
        return

    # File mode
    if input_file:
        click.echo(f"Reading from {input_file}...", err=True)
        with open(input_file, "r", encoding="utf-8") as f:
            sentences = [stripped for line in f if (stripped := line.strip())]

        total = len(sentences)
        click.echo(f"Parsing {total} sentences...", err=True)
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(parse_and_print(sentence, sent_id=i))
            if i % 100 == 0:
                click.echo(f" Processed {i}/{total}...", err=True)

        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                # Terminate every result, including the last one, with a
                # newline: joining with "\n" left the final CoNLL-U sentence
                # without its closing blank line.
                f.writelines(f"{result}\n" for result in results)
            click.echo(f"Results saved to {output_file}", err=True)
        else:
            for result in results:
                emit(result)
        return

    # Interactive mode
    click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n")
    sent_id = 1
    while True:
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            emit(parse_and_print(sentence, sent_id=sent_id))
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!")
            break
        except EOFError:
            # End of input (e.g. Ctrl+D or exhausted piped stdin): stop quietly.
            break
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    predict()