File size: 5,025 Bytes
b85c683 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "underthesea[deep]>=6.8.0",
# "click>=8.0.0",
# "torch>=2.0.0",
# "transformers>=4.30.0",
# ]
# ///
"""
Prediction script for Bamboo-1 Vietnamese Dependency Parser.

Usage:
    # Interactive mode
    uv run scripts/predict.py --model models/bamboo-1

    # File input
    uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu

    # Single sentence
    uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam"
"""
import sys
from pathlib import Path
import click
def format_tree_ascii(tokens, heads, deprels):
    """Render a parse as a position header followed by one line per arc.

    The first two lines list 1-based token positions and the first three
    characters of each token, each right-aligned in a 3-character column.
    Every following line names a token, its head word (or ROOT when the head
    is 0) and the dependency relation. The arrow is "<-" when the head lies
    to the right of the token (or is ROOT) and "->" otherwise.

    Args:
        tokens: Token strings in sentence order.
        heads: Head position per token (1-based, 0 for the root).
        deprels: Relation label per token.

    Returns:
        The rendered lines joined with "\\n" (no trailing newline).
    """
    position_row = " ".join(f"{pos:>3}" for pos in range(1, len(tokens) + 1))
    token_row = " ".join(f"{tok[:3]:>3}" for tok in tokens)
    rendered = [" " + position_row, " " + token_row]

    for pos, tok in enumerate(tokens, start=1):
        # Index heads/deprels by position so length mismatches surface as errors.
        head = heads[pos - 1]
        rel = deprels[pos - 1]
        if head == 0:
            rendered.append(f" {tok} <- ROOT ({rel})")
            continue
        direction = "<-" if head > pos else "->"
        rendered.append(f" {tok} {direction} {tokens[head - 1]} ({rel})")

    return "\n".join(rendered)
def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Render one parsed sentence as a CoNLL-U block.

    Only the ID, FORM, HEAD and DEPREL columns are filled; every other column
    is "_". The ``# sent_id`` and ``# text`` comment lines are emitted only
    when the corresponding argument is truthy.

    Args:
        tokens: Token strings in sentence order.
        heads: Head position per token (1-based, 0 for the root).
        deprels: Relation label per token.
        sent_id: Optional sentence identifier for the ``# sent_id`` comment.
        text: Optional raw sentence for the ``# text`` comment.

    Returns:
        The block with every line newline-terminated ("" when there is
        nothing to emit).
    """
    comments = []
    if sent_id:
        comments.append(f"# sent_id = {sent_id}")
    if text:
        comments.append(f"# text = {text}")

    rows = [
        f"{idx}\t{form}\t_\t_\t_\t_\t{head}\t{rel}\t_\t_"
        for idx, (form, head, rel) in enumerate(zip(tokens, heads, deprels), start=1)
    ]
    # The trailing "" makes join() terminate the final line with a newline.
    return "\n".join(comments + rows + [""])
def _emit_results(results, output_file):
    """Write formatted parse results to ``output_file``, or echo them to stdout.

    Every result is emitted as a block followed by exactly one blank line. For
    CoNLL-U this is the required sentence separator, including after the last
    sentence; for the other formats it yields one blank line between entries.

    Args:
        results: Formatted strings, one per parsed sentence.
        output_file: Destination path, or a falsy value to print to stdout.
    """
    # Results may or may not end with "\n" depending on the format; normalise.
    payload = "".join(result.rstrip("\n") + "\n\n" for result in results)
    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(payload)
        click.echo(f"Results saved to {output_file}", err=True)
    else:
        click.echo(payload, nl=False)


@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input", "-i",
    "input_file",
    type=click.Path(exists=True, dir_okay=False),
    help="Input file (one sentence per line)",
)
@click.option(
    "--output", "-o",
    "output_file",
    help="Output file (CoNLL-U format unless --format is given)",
)
@click.option(
    "--text", "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default="simple",
    help="Output format (conllu is used instead when writing to --output)",
    show_default=True,
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser.

    Parses a single sentence (--text), every non-empty line of a file
    (--input), or, when neither is given, sentences typed interactively.
    """
    from click.core import ParameterSource
    from underthesea.models.dependency_parser import DependencyParser

    # --output is documented as writing CoNLL-U, but the --format default is
    # "simple". Switch to CoNLL-U when results go to a file and the user did
    # not choose a format explicitly.
    ctx = click.get_current_context()
    format_source = ctx.get_parameter_source("output_format")
    if output_file and (text or input_file) and format_source is ParameterSource.DEFAULT:
        output_format = "conllu"

    # Status messages go to stderr so stdout carries only parse results and
    # can be redirected or piped without clean-up.
    click.echo(f"Loading model from {model}...", err=True)
    parser = DependencyParser.load(model)
    click.echo("Model loaded.\n", err=True)

    def format_parse(sentence, sent_id=None):
        """Parse ``sentence`` and return the result rendered as ``output_format``."""
        result = parser.predict(sentence)
        # Items are indexed as (token, head, deprel); head is a 1-based token
        # position, with 0 meaning the root.
        tokens = [r[0] for r in result]
        heads = [r[1] for r in result]
        deprels = [r[2] for r in result]
        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)
        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)
        # simple
        lines = [f"Input: {sentence}", "Output:"]
        for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels), 1):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            lines.append(f" {i}. {token} -> {head_word} ({deprel})")
        return "\n".join(lines) + "\n"

    # Single text mode
    if text:
        _emit_results([format_parse(text, sent_id=1)], output_file)
        return

    # File mode
    if input_file:
        click.echo(f"Reading from {input_file}...", err=True)
        # utf-8-sig drops a leading BOM (common in files saved on Windows);
        # str.strip() does not remove U+FEFF, so it would otherwise stay glued
        # to the first sentence.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            sentences = [line.strip() for line in f if line.strip()]
        total = len(sentences)
        click.echo(f"Parsing {total} sentences...", err=True)
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(format_parse(sentence, sent_id=i))
            if i % 100 == 0:
                click.echo(f" Processed {i}/{total}...", err=True)
        _emit_results(results, output_file)
        return

    # Interactive mode (--output is not used here).
    click.echo("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n")
    sent_id = 1
    while True:
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            click.echo(format_parse(sentence, sent_id=sent_id).rstrip("\n"))
            click.echo()
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!")
            break
        except EOFError:
            break
if __name__ == "__main__":
    # Run as a script: hand control to the click command defined above, which
    # reads its options from the command line.
    predict()
|