Commit 586650a
Parent(s): 698e6dd

upload files

Files changed:
- .gradio/certificate.pem +31 -0
- .requirements.txt.un~ +0 -0
- app.py +196 -4
- polish-test-lg.conllu.conllu +0 -0
- requirements.txt +107 -0
- spanish-test-sm.conllu +100 -0
- treetse/__init__.py +0 -0
- treetse/__pycache__/__init__.cpython-312.pyc +0 -0
- treetse/__pycache__/__init__.cpython-313.pyc +0 -0
- treetse/__pycache__/pipeline.cpython-312.pyc +0 -0
- treetse/evaluators/__pycache__/evaluator.cpython-312.pyc +0 -0
- treetse/evaluators/__pycache__/evaluator.cpython-313.pyc +0 -0
- treetse/evaluators/evaluator.py +93 -0
- treetse/evaluators/perplexity.py +32 -0
- treetse/pipeline.py +200 -0
- treetse/preprocessing/__init__.py +0 -0
- treetse/preprocessing/__pycache__/__init__.cpython-312.pyc +0 -0
- treetse/preprocessing/__pycache__/conllu_parser.cpython-312.pyc +0 -0
- treetse/preprocessing/__pycache__/conllu_parser.cpython-313.pyc +0 -0
- treetse/preprocessing/__pycache__/grew_dependencies.cpython-312.pyc +0 -0
- treetse/preprocessing/__pycache__/reconstruction.cpython-312.pyc +0 -0
- treetse/preprocessing/__pycache__/reconstruction.cpython-313.pyc +0 -0
- treetse/preprocessing/conllu_parser.py +402 -0
- treetse/preprocessing/grew_dependencies.py +20 -0
- treetse/preprocessing/reconstruction.py +78 -0
- treetse/visualise/__pycache__/visualiser.cpython-312.pyc +0 -0
- treetse/visualise/__pycache__/visualiser.cpython-313.pyc +0 -0
- treetse/visualise/visualiser.py +114 -0
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
.requirements.txt.un~
ADDED
Binary file (947 Bytes).
app.py
CHANGED
@@ -1,7 +1,199 @@
 import gradio as gr
+import pandas as pd
+import tempfile
+import ast
+import sys
+import os
+from PIL import Image
 
-
-
+# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from treetse.pipeline import Grewtse
 
-
-
+grewtse = Grewtse()
+treebank_path = None
+
+def parse_treebank(path: str, treebank_selection: str) -> pd.DataFrame:
+    if treebank_selection == "None":
+        successful_treebank_parse = grewtse.parse_treebank(path)
+        treebank_path = path
+    else:
+        successful_treebank_parse = grewtse.parse_treebank(treebank_selection)
+        treebank_path = treebank_selection
+
+    print("changing treebank parse success")
+    is_treebank_parse_success = True
+    return grewtse.get_morphological_features().head()
+
+def to_masked_dataset(query, node) -> pd.DataFrame:
+    df = grewtse.generate_masked_dataset(query, node)
+    return df
+
+def safe_str_to_dict(s):
+    try:
+        return ast.literal_eval(s)
+    except (ValueError, SyntaxError):
+        return None
+
+def generate_minimal_pairs(query: str, node: str, alt_features: str):
+    if not grewtse.is_treebank_loaded():
+        raise ValueError("Please parse a treebank first.")
+
+    # mask each sentence
+    resulting_dataset = to_masked_dataset(query, node)
+
+    # determine whether an alternative LI should be found
+    alt_features_as_dict = safe_str_to_dict(alt_features)
+    if alt_features_as_dict is not None:
+        resulting_dataset = grewtse.generate_minimal_pairs(alt_features_as_dict, {})
+        # resulting_dataset = grewtse.get_masked_dataset()
+    print(resulting_dataset)
+
+    # save to a temporary CSV file
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+    resulting_dataset.to_csv(temp_file.name, index=False)
+    return resulting_dataset, temp_file.name
+
+def evaluate_model(model_repo: str, target_x_label: str, alt_x_label: str, x_axis_label: str, title: str):
+    if not grewtse.are_minimal_pairs_generated():
+        raise ValueError("Please parse a treebank, mask a dataset and generate minimal pairs first.")
+
+    mp_with_eval_dataset = grewtse.evaluate_bert_mlm(model_repo)
+    vis_filename = "vis.png"
+
+    grewtse.visualise_syntactic_performance(vis_filename,
+                                            mp_with_eval_dataset,
+                                            target_x_label,
+                                            alt_x_label,
+                                            x_axis_label,
+                                            "Confidence",
+                                            title)
+
+    # save to a temporary CSV file
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+    mp_with_eval_dataset.to_csv(temp_file.name, index=False)
+    return mp_with_eval_dataset, temp_file.name, vis_filename
+
+def show_df():
+    return gr.update(visible=True)
+
+with gr.Blocks(theme=gr.themes.Ocean()) as demo:
+    is_treebank_parse_success = False
+
+    with gr.Row():
+        gr.Markdown("# GREW-TSE: A Pipeline for Query-based Targeted Syntactic Evaluation")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("""
+            #### Load a Treebank
+            You can begin by loading a particular treebank that you'd like to work with.<br>
+            You can either select a treebank from the pre-loaded options below, or upload your own.<br>
+            """)
+
+        with gr.Column():
+            with gr.Tabs():
+                with gr.TabItem("Choose Treebank"):
+                    treebank_selection = gr.Dropdown(
+                        choices=["None", "spanish-test-sm.conllu", "polish-test-lg.conllu"],
+                        label="Select a treebank",
+                        value="spanish-test-sm.conllu"
+                    )
+
+                with gr.TabItem("Upload Your Own"):
+                    gr.Markdown("## Upload a .conllu File")
+                    file_input = gr.File(
+                        label="Upload .conllu file",
+                        file_types=[".conllu"],
+                        type="filepath"
+                    )
+            parse_file_button = gr.Button("Parse Treebank", size='sm', scale=1)
+
+    gr.Markdown("## Isolate A Syntactic Phenomenon")
+    morph_table = gr.Dataframe(interactive=False, visible=False)
+
+    parse_file_button.click(
+        fn=parse_treebank,
+        inputs=[file_input, treebank_selection],
+        outputs=[morph_table]
+    )
+    parse_file_button.click(
+        fn=show_df,
+        outputs=morph_table
+    )
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("""
+            **GREW (Graph Rewriting for Universal Dependencies)** is a query and transformation language used to search within and manipulate dependency treebanks. A GREW query allows linguists and NLP researchers to find specific syntactic patterns in parsed linguistic data (such as Universal Dependencies treebanks).
+            Queries are expressed as graph constraints using a concise pattern-matching syntax.
+
+            #### Example
+            The following short GREW query will target any verbs. Try it with one of the sample treebanks above.
+            Make sure to include the variable V as the target that we're trying to isolate.
+
+            ```grew
+            V [upos=\"VERB\"];
+            ```
+            """)
+        with gr.Column():
+            query_input = gr.Textbox(label="GREW Query", lines=5, placeholder="Enter your GREW query here...", value="V [upos=\"VERB\"];")
+            node_input = gr.Textbox(label="Node", placeholder="The variable in your GREW query to isolate, e.g., N", value="V")
+            feature_input = gr.Textbox(
+                label="Enter Alternative Feature Values for Minimal Pair as a Dictionary",
+                placeholder='e.g. {"case": "Acc", "number": "Sing"}',
+                value="{\"mood\": \"Sub\"}",
+                lines=3
+            )
+            run_button = gr.Button("Run Query", size='sm', scale=3)
+
+    output_table = gr.Dataframe(label="Output Table", visible=False)
+    download_file = gr.File(label="Download CSV")
+    run_button.click(
+        fn=generate_minimal_pairs,
+        inputs=[query_input, node_input, feature_input],
+        outputs=[output_table, download_file]
+    )
+    run_button.click(
+        fn=show_df,
+        outputs=output_table
+    )
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("""
+            ## Evaluate A Model
+            You can evaluate any BERT for MLM model by providing the name of the model repository.
+            """)
+        with gr.Column():
+            repository_input = gr.Textbox(label="Model Repository", lines=1, placeholder="Enter the model repository here...", value="dccuchile/distilbert-base-spanish-uncased")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("""
+            ## Choose Visualisation Settings
+            The results will be displayed as a visualisation which you can edit using the following settings.
+            """)
+        with gr.Column():
+            target_x_label_textbox = gr.Textbox(label="Original Label Name, i.e. type of the 'right' token", lines=1, placeholder="Genitive Version")
+            alt_x_label_textbox = gr.Textbox(label="Alternative Label Name, i.e. type of the 'wrong' token", lines=1, placeholder="Accusative Version")
+            x_axis_label_textbox = gr.Textbox(label="X Axis Title, i.e. what features are you comparing?", lines=1, placeholder="Case of Nouns in Transitive Verbs")
+            title_textbox = gr.Textbox(label="Visualisation Title", lines=1, placeholder="Syntactic Performance of BERT on English Transitive Noun Case")
+
+    evaluate_button = gr.Button("Evaluate Model", size='sm', scale=3)
+
+    mp_with_eval_output_dataset = gr.Dataframe(label="Output Table", visible=False)
+    mp_with_eval_output_download = gr.File(label="Download CSV")
+    visualisation_widget = gr.Image(type="pil", label="Loaded Image")
+
+    evaluate_button.click(
+        fn=evaluate_model,
+        inputs=[repository_input, target_x_label_textbox, alt_x_label_textbox, x_axis_label_textbox, title_textbox],
+        outputs=[mp_with_eval_output_dataset, mp_with_eval_output_download, visualisation_widget]
+    )
+    evaluate_button.click(
+        fn=show_df,
+        outputs=[mp_with_eval_output_dataset]
+    )
+
+if __name__ == "__main__":
+    demo.launch(share=True)
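The Gradio handlers above reduce to four calls on the `Grewtse` pipeline. A minimal sketch of the same flow driven from a plain script, reusing the demo's own defaults (treebank file, query, feature dictionary, and model repository all appear in the diff above); this is illustrative and not part of the commit:

```python
from treetse.pipeline import Grewtse

grewtse = Grewtse()

# 1. parse the treebank into a lexical-item table
grewtse.parse_treebank("spanish-test-sm.conllu")

# 2. mask every token bound to the query's V node
masked = grewtse.generate_masked_dataset('V [upos="VERB"];', "V")

# 3. add an 'alternative' column holding the minimal-pair form
pairs = grewtse.generate_minimal_pairs({"mood": "Sub"}, {})

# 4. score a masked-LM checkpoint on each pair and save the results
results = grewtse.evaluate_bert_mlm("dccuchile/distilbert-base-spanish-uncased")
results.to_csv("results.csv", index=False)
```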
polish-test-lg.conllu.conllu
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,107 @@
+aiofiles==24.1.0
+annotated-types==0.7.0
+anyio==4.9.0
+black==25.1.0
+Brotli==1.1.0
+certifi==2025.4.26
+charset-normalizer==3.4.2
+click==8.2.0
+colorama==0.4.6
+conllu==6.0.0
+contourpy==1.3.2
+coverage==7.8.0
+cycler==0.12.1
+fastapi==0.116.1
+ffmpy==0.6.0
+filelock==3.18.0
+fonttools==4.58.4
+fsspec==2025.5.1
+gradio==5.37.0
+gradio_client==1.10.4
+grewpy==0.6.0
+groovy==0.1.2
+h11==0.16.0
+hf-xet==1.1.5
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.32.2
+idna==3.10
+iniconfig==2.1.0
+isort==6.0.1
+Jinja2==3.1.6
+kiwisolver==1.4.8
+lark==1.2.2
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.10.3
+mdurl==0.1.2
+mizani==0.13.5
+mpmath==1.3.0
+mypy==1.15.0
+mypy_extensions==1.1.0
+networkx==3.4.2
+numpy==2.2.5
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+orjson==3.10.18
+packaging==25.0
+pandas==2.2.3
+pandas-stubs==2.2.3.250527
+pathspec==0.12.1
+patsy==1.0.1
+pillow==11.2.1
+platformdirs==4.3.8
+plotnine==0.14.5
+pluggy==1.5.0
+pydantic==2.11.7
+pydantic_core==2.33.2
+pydub==0.25.1
+Pygments==2.19.2
+pyparsing==3.2.3
+pytest==8.3.5
+pytest-cov==6.1.1
+python-dateutil==2.9.0.post0
+python-multipart==0.0.20
+pytz==2025.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==14.0.0
+ruff==0.11.9
+safehttpx==0.1.6
+safetensors==0.5.3
+scipy==1.15.3
+semantic-version==2.10.0
+setuptools==80.4.0
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+starlette==0.47.1
+statsmodels==0.14.4
+sympy==1.14.0
+tokenizers==0.21.1
+tomlkit==0.13.3
+torch==2.7.0
+tqdm==4.67.1
+transformers==4.52.3
+triton==3.3.0
+typer==0.16.0
+types-pytz==2025.2.0.20250516
+typing-inspection==0.4.1
+typing_extensions==4.13.2
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.35.0
+websockets==15.0.1
spanish-test-sm.conllu
ADDED
@@ -0,0 +1,100 @@
+# global.columns = ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC
+# sent_id = 3LB-CAST-c2-2-s18
+# text = Sea enhorabuena.
+# orig_file_sentence 011#21
+1	Sea	ser	AUX	vssp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	_
+2	enhorabuena	enhorabuena	NOUN	ncfs000	Gender=Fem|Number=Sing	1	nsubj	1:nsubj	ArgTem=arg1:tem|SpaceAfter=No
+3	.	.	PUNCT	fp	PunctType=Peri	1	punct	1:punct	_
+
+# sent_id = 3LB-CAST-d2-12-s5
+# text = Esperemos que muy pronto.
+# orig_file_sentence 002#8
+1	Esperemos	esperar	VERB	vmsp1p0	Mood=Sub|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin	0	root	0:root	_
+2	que	que	SCONJ	cs	_	4	mark	4:mark	_
+3	muy	mucho	ADV	rg	_	4	advmod	4:advmod	_
+4	pronto	pronto	ADV	rg	_	1	obj	1:obj	ArgTem=arg1:pat|SpaceAfter=No
+5	.	.	PUNCT	fp	PunctType=Peri	1	punct	1:punct	_
+
+# sent_id = CESS-CAST-P-20000701-69-s8
+# text = Mujer hubiera mostrado arrepentimiento.
+# orig_file_sentence 131#40
+1	Mujer	mujer	NOUN	ncfs000	Gender=Fem|Number=Sing	3	nsubj	3:nsubj	ArgTem=arg0:agt|Entity=(NOCOREF:Gen--1-gstype:gen)
+2	hubiera	haber	AUX	vasi3s0	Mood=Sub|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin	3	aux	3:aux	_
+3	mostrado	mostrar	VERB	vmp00sm	Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part	0	root	0:root	_
+4	arrepentimiento	arrepentimiento	NOUN	ncms000	Gender=Masc|Number=Sing	3	obj	3:obj	ArgTem=arg1:pat|Entity=(NOCOREF:Gen--1-gstype:gen)|SpaceAfter=No
+5	.	.	PUNCT	fp	PunctType=Peri	3	punct	3:punct	_
+
+# sent_id = CESS-CAST-P-20000601-41-s10
+# text = Que por ritmo no quede.
+# orig_file_sentence 119#88
+1	Que	que	SCONJ	cs	_	5	mark	5:mark	_
+2	por	por	ADP	sps00	_	3	case	3:case	_
+3	ritmo	ritmo	NOUN	ncms000	Gender=Masc|Number=Sing	5	obl	5:obl	ArgTem=argM:adv
+4	no	no	ADV	rn	Polarity=Neg	5	advmod	5:advmod	_
+5	quede	quedar	VERB	vmsp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	SpaceAfter=No
+6	.	.	PUNCT	fp	PunctType=Peri	5	punct	5:punct	_
+
+# sent_id = 3LB-CAST-t3-4-s23
+# text = Yo esperar cuanto haga falta.
+# orig_file_sentence 027#95
+1	Yo	yo	PRON	pp1csn00	Case=Nom|Number=Sing|Person=1|PronType=Prs	2	obj	2:obj	_
+2	esperar	esperar	VERB	vmn0000	VerbForm=Inf	0	root	0:root	_
+3	cuanto	cuanto	PRON	pr0ms000	Gender=Masc|Number=Sing|NumType=Card|PronType=Int,Rel	4	nsubj	4:nsubj	ArgTem=arg1:tem
+4	haga	hacer	VERB	vmsp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	2	advcl	2:advcl	MWE=haga_falta|MWEPOS=VERB
+5	falta	falta	NOUN	_	Gender=Fem|Number=Sing	4	compound	4:compound	SpaceAfter=No
+6	.	.	PUNCT	fp	PunctType=Peri	2	punct	2:punct	_
+
+# sent_id = 3LB-CAST-c1-4-s5
+# text = Un soldador que me meta fuego.
+# orig_file_sentence 010#41
+1	Un	uno	DET	di0ms0	Definite=Ind|Gender=Masc|Number=Sing|PronType=Art	2	det	2:det	_
+2	soldador	soldador	NOUN	ncms000	Gender=Masc|Number=Sing	0	root	0:root	_
+3	que	que	PRON	pr0cn000	PronType=Rel	5	nsubj	5:nsubj	ArgTem=arg0:agt
+4	me	yo	PRON	pp1cs000	Case=Dat|Number=Sing|Person=1|PrepCase=Npr|PronType=Prs	5	obl:arg	5:obl:arg	ArgTem=arg2:ben
+5	meta	meter	VERB	vmsp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	2	acl	2:acl	_
+6	fuego	fuego	NOUN	ncms000	Gender=Masc|Number=Sing	5	obj	5:obj	ArgTem=arg1:pat|SpaceAfter=No
+7	.	.	PUNCT	fp	PunctType=Peri	2	punct	2:punct	_
+
+# sent_id = CESS-CAST-P-20010701-38-s94
+# text = Por mucho que me lo pidan.
+# orig_file_sentence 117#2
+1	Por	por	ADP	cs	_	6	mark	6:mark	MWE=Por_mucho_que|MWEPOS=SCONJ
+2	mucho	mucho	ADV	_	_	1	fixed	1:fixed	_
+3	que	que	SCONJ	_	_	1	fixed	1:fixed	_
+4	me	yo	PRON	pp1cs000	Case=Dat|Number=Sing|Person=1|PrepCase=Npr|PronType=Prs	6	obl:arg	6:obl:arg	ArgTem=arg2:ben|Entity=(CESSCASTP2001070138c1-person-1-CorefType:ident,gstype:spec)
+5	lo	él	PRON	pp3msa00	Case=Acc|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs	6	obj	6:obj	ArgTem=arg1:pat|Entity=(CESSCASTP2001070138c51--1-CorefType:dx.prop)
+6	pidan	pedir	VERB	vmsp3p0	Mood=Sub|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	SpaceAfter=No
+7	.	.	PUNCT	fp	PunctType=Peri	6	punct	6:punct	_
+
+# sent_id = 3LB-CAST-t6-3-s5
+# text = Tal vez eso fuera lo peor.
+# orig_file_sentence 030#40
+1	Tal	tal	NOUN	rg	_	6	advmod	6:advmod	MWE=Tal_vez|MWEPOS=ADV
+2	vez	vez	NOUN	_	_	1	fixed	1:fixed	_
+3	eso	ese	PRON	pd0ns000	Number=Sing|PronType=Dem	6	nsubj	6:nsubj	ArgTem=arg1:tem
+4	fuera	ser	AUX	vssi3s0	Mood=Sub|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin	6	cop	6:cop	_
+5	lo	él	PRON	da0ns0	Case=Acc|Definite=Def|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs	6	det	6:det	_
+6	peor	peor	ADJ	aq0cs0	Degree=Cmp|Number=Sing	0	root	0:root	ArgTem=arg2:atr|SpaceAfter=No
+7	.	.	PUNCT	fp	PunctType=Peri	6	punct	6:punct	_
+
+# sent_id = 3LB-CAST-n1-3-s13
+# text = La bandera, que no falte.
+# orig_file_sentence 024#5
+1	La	el	DET	da0fs0	Definite=Def|Gender=Fem|Number=Sing|PronType=Art	2	det	2:det	_
+2	bandera	bandera	NOUN	ncfs000	Gender=Fem|Number=Sing	6	nsubj	6:nsubj	ArgTem=arg1:tem|SpaceAfter=No
+3	,	,	PUNCT	fc	PunctType=Comm	2	punct	2:punct	_
+4	que	que	SCONJ	cs	_	6	mark	6:mark	_
+5	no	no	ADV	rn	Polarity=Neg	6	advmod	6:advmod	_
+6	falte	faltar	VERB	vmsp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	SpaceAfter=No
+7	.	.	PUNCT	fp	PunctType=Peri	6	punct	6:punct	_
+
+# sent_id = CESS-CAST-P-20020202-102-s11
+# text = Diga lo que diga.
+# orig_file_sentence 076#95
+1	Diga	decir	VERB	vmm03s0	Mood=Imp|Number=Sing|Person=3|VerbForm=Fin	0	root	0:root	_
+1.1	_	_	PRON	p	_	_	_	1:nsubj	ArgTem=arg0:agt|Entity=(CESSCASTP20020202102c2--1-CorefType:ident)|wordform=__EMPTY__
+2	lo	él	PRON	da0ns0	Case=Acc|Definite=Def|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs	4	det	4:det	_
+3	que	que	PRON	pr0cn000	PronType=Rel	4	obj	4:obj	ArgTem=arg1:pat
+4	diga	decir	VERB	vmsp3s0	Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	1	ccomp	1:ccomp	ArgTem=arg1:pat|SpaceAfter=No
+4.1	_	_	PRON	p	_	_	_	4:nsubj	ArgTem=arg0:agt|Entity=(CESSCASTP20020202102c2--1-CorefType:ident)|wordform=__EMPTY__
+5	.	.	PUNCT	fp	PunctType=Peri	1	punct	1:punct	_
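Each of the ten sentences in this sample contains a Mood=Sub token, which is exactly what the demo's default query and feature dictionary exercise. A minimal sketch of streaming the file with the `conllu` package pinned in requirements.txt (the snippet is illustrative, not part of the commit):

```python
from conllu import parse_incr

with open("spanish-test-sm.conllu", "r", encoding="utf-8") as f:
    for sentence in parse_incr(f):
        sent_id = sentence.metadata["sent_id"]
        # collect the subjunctive forms this sample is built around
        subj = [t["form"] for t in sentence
                if (t.get("feats") or {}).get("Mood") == "Sub"]
        print(sent_id, subj)
```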
treetse/__init__.py
ADDED
File without changes
treetse/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (165 Bytes).
treetse/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (161 Bytes).
treetse/__pycache__/pipeline.cpython-312.pyc
ADDED
Binary file (8.34 kB).
treetse/evaluators/__pycache__/evaluator.cpython-312.pyc
ADDED
Binary file (5.25 kB).
treetse/evaluators/__pycache__/evaluator.cpython-313.pyc
ADDED
Binary file (5.35 kB).
treetse/evaluators/evaluator.py
ADDED
@@ -0,0 +1,93 @@
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+from typing import Any, Tuple
+import torch
+import torch.nn.functional as F
+
+
+class Evaluator:
+    def __init__(self):
+        self.mask_token_index: int = -1
+        self.mask_probs: torch.Tensor | None = None
+        self.tokeniser: Any = None
+        self.model: Any = None
+        self.logits: torch.Tensor | None = None
+
+    def setup_parameters(self, model_name: str) -> Tuple[Any, Any]:
+        # Q: what sort of tokenisers are being used?
+        self.tokeniser = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+
+        # set to eval mode, disabling things like dropout
+        self.model.eval()
+
+        return self.model, self.tokeniser
+
+    def run_masked_prediction(
+        self, model: Any, tokeniser: Any, sentence: str, target_token: str
+    ) -> Tuple[Any, Any]:
+        # validate on the incoming [MASK] placeholder before swapping in the model's own mask token
+        if sentence.count("[MASK]") != 1:
+            raise ValueError("Only single-mask sentences are supported.")
+
+        mask_token = tokeniser.mask_token
+        sentence_masked = sentence.replace("[MASK]", mask_token)
+
+        inputs = tokeniser(sentence_masked, return_tensors="pt")
+
+        # Get logits from model
+        with torch.no_grad():
+            outputs = model(**inputs)
+            logits = outputs.logits
+            self.logits = logits
+
+        self.mask_token_index = self._get_mask_index(inputs, tokeniser)
+        self.mask_probs = self._get_mask_probabilities(
+            self.mask_token_index, self.logits
+        )
+
+        return self.mask_token_index, self.mask_probs
+
+    def get_token_prob(self, token: str) -> float:
+        target_id = self.tokeniser.convert_tokens_to_ids(token)
+        prob = self.get_prob_by_id(target_id)
+        return prob
+
+    def get_top_pred(self) -> Tuple[str, float]:
+        top_pred_id = int(torch.argmax(self.mask_probs, dim=-1).item())
+        top_pred_token = self.tokeniser.convert_ids_to_tokens(top_pred_id)
+        top_token_prob = self.get_prob_by_id(top_pred_id)
+        return top_pred_token, top_token_prob
+
+    def get_prob_by_id(self, id: int) -> float:
+        if self.mask_probs is not None:
+            return self.mask_probs[id].item()
+        else:
+            raise KeyError("Please evaluate a dataset first. Results empty")
+
+    def _get_mask_index(self, inputs: Any, tokeniser: Any) -> int:
+        if "input_ids" not in inputs:
+            raise ValueError("Missing 'input_ids' in inputs.")
+
+        if tokeniser.mask_token_id is None:
+            raise ValueError("The tokeniser does not have a defined mask_token_id.")
+
+        input_ids = inputs["input_ids"]
+        mask_positions = torch.where(input_ids == tokeniser.mask_token_id)
+
+        if len(mask_positions[0]) == 0:
+            raise ValueError("No mask token found in input_ids.")
+
+        if len(mask_positions[0]) > 1:
+            raise ValueError("Multiple mask tokens found; expected only one.")
+
+        return (
+            mask_positions[1].item()
+            if len(mask_positions) > 1
+            else mask_positions[0].item()
+        )
+
+    def _get_mask_probabilities(
+        self, mask_token_index: int, logits: Any
+    ) -> torch.Tensor:
+        mask_logits = logits[0, mask_token_index, :]  # shape: (vocab_size,)
+        probs = F.softmax(mask_logits, dim=-1)  # shape: (vocab_size,)
+        return probs
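A minimal sketch of using `Evaluator` on its own for one masked sentence, taking the sentence from the sample treebank above and the demo's default model repository; purely illustrative:

```python
from treetse.evaluators.evaluator import Evaluator

evaluator = Evaluator()
model, tokeniser = evaluator.setup_parameters("dccuchile/distilbert-base-spanish-uncased")

# the pipeline writes masked sentences with a literal [MASK] placeholder
evaluator.run_masked_prediction(model, tokeniser, "La bandera, que no [MASK].", "falte")

print(evaluator.get_token_prob("falte"))  # probability of the treebank form (single-token forms only)
print(evaluator.get_top_pred())           # (token, probability) of the model's argmax
```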
treetse/evaluators/perplexity.py
ADDED
@@ -0,0 +1,32 @@
+class PerplexityEvaluator:
+    def __init__(self) -> None:
+        pass
+
+    def compute_perplexity(self, logits: list) -> list:
+        pass
+
+    """
+    -- Classic TSE --
+    Evaluates based on minimal pairs, where a particular feature
+    is chosen and two values of that feature are compared.
+
+    1. Accepts the inputs, logits, feature name, and feature values as input.
+    Finds the lexical items which are the same except for these values of this
+    feature, including in UPOS and lemma.
+    2. Computes the perplexity scores for the correct value and the alternative syntactic
+    option.
+    """
+
+    def compute_classic_tse(self) -> None:
+        pass
+
+    """
+    -- Generalised TSE --
+    Evaluates based on minimal syntactic pairs, that is, a candidate set is created for the
+    correct token as well as the alternate values for that particular feature.
+    """
+
+    def compute_generalised_tse(
+        self,
+    ) -> None:
+        pass
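The methods above are unimplemented stubs. One common way to realise `compute_perplexity` for a masked LM is pseudo-perplexity: mask each position in turn and exponentiate the mean negative log-likelihood of the true tokens; classic TSE then prefers the minimal-pair member with the lower score. A sketch under that assumption, not the repository's method:

```python
import torch
import torch.nn.functional as F

def pseudo_perplexity(model, tokeniser, sentence: str) -> float:
    # mask one position at a time and score the true token at that position
    enc = tokeniser(sentence, return_tensors="pt")
    input_ids = enc["input_ids"][0]
    nlls = []
    with torch.no_grad():
        for i in range(1, input_ids.size(0) - 1):  # skip [CLS] and [SEP]
            masked = input_ids.clone()
            masked[i] = tokeniser.mask_token_id
            logits = model(masked.unsqueeze(0)).logits[0, i]
            nlls.append(F.cross_entropy(logits.unsqueeze(0), input_ids[i : i + 1]))
    # exponentiated mean NLL; lower means the model finds the sentence more likely
    return torch.exp(torch.stack(nlls).mean()).item()
```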
treetse/pipeline.py
ADDED
@@ -0,0 +1,200 @@
+import pandas as pd
+import logging
+from pathlib import Path
+
+from treetse.preprocessing.conllu_parser import ConlluParser
+from treetse.evaluators.evaluator import Evaluator
+from treetse.visualise.visualiser import Visualiser
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
+)
+
+class Grewtse:
+    def __init__(self):
+        self.parser = ConlluParser()
+        self.evaluator = Evaluator()
+        self.visualiser = Visualiser()
+
+        self.treebank_path: str | None = None
+        self.lexical_items: pd.DataFrame | None = None
+        self.masked_dataset: pd.DataFrame | None = None
+        self.exception_dataset: pd.DataFrame | None = None
+        self.evaluation_results: pd.DataFrame | None = None
+
+    def parse_treebank(self, filepath: str) -> bool:
+        try:
+            self.treebank_path = filepath
+            self.lexical_items = self.parser._build_lexical_item_dataset(filepath)
+            return True
+        except Exception:
+            self.treebank_path = None
+            self.lexical_items = None
+            return False
+
+    def is_treebank_loaded(self) -> bool:
+        return self.lexical_items is not None
+
+    def is_dataset_masked(self) -> bool:
+        return self.masked_dataset is not None
+
+    def get_lexical_items(self) -> pd.DataFrame:
+        return self.lexical_items
+
+    def get_morphological_features(self) -> pd.DataFrame:
+        if self.lexical_items is None:
+            raise ValueError("Cannot get features: You must parse a treebank first.")
+
+        morph_df = self.lexical_items
+        morph_df.columns = [col.replace("feats__", "") if col.startswith("feats__") else col for col in morph_df.columns]
+
+        return morph_df
+
+    def generate_masked_dataset(
+        self, query: str, target_node: str, mask_token: str = "[MASK]"
+    ) -> pd.DataFrame:
+        if self.treebank_path is None:
+            raise ValueError("Cannot create masked dataset: no treebank filepath provided.")
+
+        results = self.parser._build_masked_dataset_grew(
+            self.treebank_path, query, target_node, mask_token
+        )
+        self.masked_dataset = results['masked']
+        self.exception_dataset = results['exception']
+        return self.masked_dataset
+
+    def get_masked_dataset(self) -> pd.DataFrame:
+        return self.masked_dataset
+
+    def generate_minimal_pairs(self, morph_features: dict, upos_features: dict | None) -> pd.DataFrame:
+        if self.masked_dataset is None:
+            raise ValueError("Cannot generate minimal pairs: treebank must be parsed and masked first.")
+
+        def convert_row_to_feature(row):
+            return self.parser.to_syntactic_feature(
+                row['sentence_id'],
+                row['match_id'] - 1,
+                morph_features,
+                {},
+            )
+        alternative_row = self.masked_dataset.apply(convert_row_to_feature, axis=1)
+        self.masked_dataset['alternative'] = alternative_row
+        return self.masked_dataset
+
+    def are_minimal_pairs_generated(self) -> bool:
+        return self.is_treebank_loaded() and \
+            self.is_dataset_masked() and \
+            ('alternative' in self.masked_dataset.columns)
+
+    def evaluate_bert_mlm(self, model_repo: str, row_limit: int | None = None) -> pd.DataFrame:
+        if self.masked_dataset is None:
+            raise ValueError("Cannot evaluate: treebank must be parsed and masked first.")
+
+        test_model, test_tokeniser = self.evaluator.setup_parameters(model_repo)
+        results = []
+
+        counter = 0
+        for row in self.masked_dataset.itertuples():
+            masked_sentence = row.masked_text
+            label = row.match_token
+            alternative_form = row.alternative
+
+            row_results = {
+                "sentence_id": row.sentence_id,
+                "token_id": row.match_id,
+                "masked_sentence": masked_sentence,
+                "num_tokens": row.num_tokens,
+                "label": label,
+                "label_prob": None,
+                "alternative": alternative_form,
+                "alternative_prob": None,
+                "top_pred_label": None,
+                "top_pred_prob": None,
+            }
+
+            try:
+                self.evaluator.run_masked_prediction(
+                    test_model, test_tokeniser, masked_sentence, label
+                )
+            except Exception as e:
+                raise Exception("There was an issue with the model or tokeniser") from e
+
+            # -- LABEL PROB --
+            label_prob = self.evaluator.get_token_prob(label)
+            row_results["label_prob"] = label_prob
+
+            # -- ALTERNATIVE FORM --
+            if alternative_form:
+                logging.info("----")
+                logging.info(f"Label Form: {label}")
+                logging.info(f"Alternative Form: {alternative_form}")
+                logging.info("----")
+
+                alt_form_prob = self.evaluator.get_token_prob(alternative_form)
+                row_results["alternative_prob"] = alt_form_prob
+
+            # -- HIGHEST PROB --
+            top_pred_label, top_pred_prob = self.evaluator.get_top_pred()
+            row_results["top_pred_label"] = top_pred_label
+            row_results["top_pred_prob"] = top_pred_prob
+
+            results.append(row_results)
+
+            if row_limit:
+                counter += 1
+                if counter == row_limit:
+                    break
+
+        results_df = pd.DataFrame(results)
+        self.evaluation_results = results_df
+        return results_df
+
+    def visualise_syntactic_performance(
+        self,
+        filename: str,
+        results: pd.DataFrame,
+        target_x_label: str,
+        alt_x_label: str,
+        x_axis_label: str,
+        y_axis_label: str,
+        title: str,
+    ) -> None:
+        visualiser = Visualiser()
+        visualiser.visualise_slope(
+            filename,
+            results,
+            target_x_label,
+            alt_x_label,
+            x_axis_label,
+            y_axis_label,
+            title,
+        )
+
+
+"""
+def store_results(
+    results_filename: str,
+    li_set_filename: str,
+    model_results: pd.DataFrame,
+    li_set: pd.DataFrame,
+):
+    try:
+        model_results.to_csv(base_dir / "output" / results_filename, index=False)
+        li_set.to_csv(base_dir / li_set_filename, index=True)
+
+        model_results["difference"] = (
+            model_results["label_prob"] - model_results["alternative_prob"]
+        )
+        model_results = model_results.sort_values("difference")
+        model_results.dropna().to_csv(
+            base_dir / "output" / f"filtered_{results_filename}", index=False
+        )
+    except Exception as e:
+        logging.error(f"Failed to output to CSV: {e}")
+        raise
+"""
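A small sketch of summarising the frame that `evaluate_bert_mlm` returns: count a row as correct when the treebank form outscores its minimal-pair alternative. The column names (`label_prob`, `alternative_prob`) come from the code above; the helper itself is hypothetical:

```python
import pandas as pd

def tse_accuracy(results: pd.DataFrame) -> float:
    # ignore rows where no alternative form was found
    scored = results.dropna(subset=["label_prob", "alternative_prob"])
    correct = (scored["label_prob"] > scored["alternative_prob"]).sum()
    return correct / len(scored) if len(scored) else float("nan")
```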
treetse/preprocessing/__init__.py
ADDED
File without changes
treetse/preprocessing/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (179 Bytes).
treetse/preprocessing/__pycache__/conllu_parser.cpython-312.pyc
ADDED
Binary file (16.4 kB).
treetse/preprocessing/__pycache__/conllu_parser.cpython-313.pyc
ADDED
Binary file (12.3 kB).
treetse/preprocessing/__pycache__/grew_dependencies.cpython-312.pyc
ADDED
Binary file (1.05 kB).
treetse/preprocessing/__pycache__/reconstruction.cpython-312.pyc
ADDED
Binary file (2.33 kB).
treetse/preprocessing/__pycache__/reconstruction.cpython-313.pyc
ADDED
Binary file (2.42 kB).
treetse/preprocessing/conllu_parser.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from treetse.preprocessing.grew_dependencies import match_dependencies
|
| 2 |
+
from treetse.preprocessing.reconstruction import Lexer
|
| 3 |
+
from conllu import parse_incr, Token
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
def test_function():
|
| 11 |
+
return True
|
| 12 |
+
|
| 13 |
+
class ConlluParser:
|
| 14 |
+
def __init__(self) -> None:
|
| 15 |
+
self.li_feature_set: pd.DataFrame = None
|
| 16 |
+
self.masked_dataset: pd.DataFrame = None
|
| 17 |
+
self.exception_dataset: pd.DataFrame = None
|
| 18 |
+
self.lexer: Lexer = Lexer()
|
| 19 |
+
|
| 20 |
+
# todo: add error handling here
|
| 21 |
+
def parse_grew(
|
| 22 |
+
self, path: str, grew_query: str, grew_variable_to_mask: str, mask_token: str = "[MASK]"
|
| 23 |
+
) -> bool:
|
| 24 |
+
self.li_feature_set = self._build_lexical_item_dataset(path)
|
| 25 |
+
|
| 26 |
+
masking_results = self._build_masked_dataset_grew(
|
| 27 |
+
path, grew_query, grew_variable_to_mask, mask_token
|
| 28 |
+
)
|
| 29 |
+
self.masked_dataset = masking_results["masked"]
|
| 30 |
+
self.exception_dataset = masking_results["exception"]
|
| 31 |
+
|
| 32 |
+
return self.masked_dataset, self.exception_dataset
|
| 33 |
+
|
| 34 |
+
# todo: add error handling here
|
| 35 |
+
def parse(
|
| 36 |
+
self, path: str, morphological_constraints: dict, universal_constraints: dict, mask_token: str = "[MASK]"
|
| 37 |
+
) -> bool:
|
| 38 |
+
self.li_feature_set = self._build_lexical_item_dataset(path)
|
| 39 |
+
|
| 40 |
+
upos_constraint = universal_constraints["upos"] if "upos" in universal_constraints else None
|
| 41 |
+
|
| 42 |
+
masking_results = self._build_masked_dataset(
|
| 43 |
+
path, morphological_constraints, upos_constraint, mask_token
|
| 44 |
+
)
|
| 45 |
+
self.masked_dataset = masking_results["masked"]
|
| 46 |
+
self.exception_dataset = masking_results["exception"]
|
| 47 |
+
|
| 48 |
+
return True
|
| 49 |
+
|
| 50 |
+
def get_masked_dataset(self) -> pd.DataFrame:
|
| 51 |
+
return self.masked_dataset
|
| 52 |
+
|
| 53 |
+
def get_lexical_item_dataset(self) -> pd.DataFrame:
|
| 54 |
+
return self.li_feature_set
|
| 55 |
+
|
| 56 |
+
# this shouldn't be hard coded
|
| 57 |
+
def get_feature_names(self) -> list:
|
| 58 |
+
return self.li_feature_set.columns[4:].to_list()
|
| 59 |
+
|
| 60 |
+
# todo: add more safety
|
| 61 |
+
def get_features(self, sentence_id: str, token_id: int) -> dict:
|
| 62 |
+
print(sentence_id)
|
| 63 |
+
print(token_id)
|
| 64 |
+
print(self.li_feature_set.index)
|
| 65 |
+
return self.li_feature_set.loc[(sentence_id, token_id)][self.get_feature_names()].to_dict()
|
| 66 |
+
|
| 67 |
+
def get_lemma(self, sentence_id: str, token_id: str) -> str:
|
| 68 |
+
return self.li_feature_set.loc[(sentence_id, token_id)]["lemma"]
|
| 69 |
+
|
| 70 |
+
# todo: handle making sure that it is the exact same as the lemma
|
| 71 |
+
def to_syntactic_feature(self, sentence_id: str, token_id: str, alt_morph_constraints: dict, alt_universal_constraints: dict) -> str | None:
|
| 72 |
+
|
| 73 |
+
# distinguish morphological from universal features
|
| 74 |
+
# todo: find a better way to do this
|
| 75 |
+
# prefix = 'feats__'
|
| 76 |
+
prefix = ''
|
| 77 |
+
alt_morph_constraints = {prefix + key: value for key, value in alt_morph_constraints.items()}
|
| 78 |
+
|
| 79 |
+
token_features = self.get_features(sentence_id, token_id)
|
| 80 |
+
|
| 81 |
+
token_features.update(alt_morph_constraints)
|
| 82 |
+
token_features.update(alt_universal_constraints)
|
| 83 |
+
lexical_items = self.li_feature_set
|
| 84 |
+
|
| 85 |
+
# get only those items which are the same lemma
|
| 86 |
+
lemma = self.get_lemma(sentence_id, token_id)
|
| 87 |
+
lemma_mask = lexical_items['lemma'] == lemma
|
| 88 |
+
lexical_items = lexical_items[lemma_mask]
|
| 89 |
+
logging.info(f"Looking for form {lemma}")
|
| 90 |
+
logging.info(lexical_items)
|
| 91 |
+
print(token_features.items())
|
| 92 |
+
|
| 93 |
+
for feat, value in token_features.items():
|
| 94 |
+
# ensure feature is a valid feature in feature set
|
| 95 |
+
if feat not in lexical_items.columns:
|
| 96 |
+
raise KeyError(
|
| 97 |
+
"Invalid feature provided to confound set: {}".format(feat)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# slim the mask down using each feature
|
| 101 |
+
# interesting edge case: np.nan == np.nan returns false!
|
| 102 |
+
mask = (lexical_items[feat] == value) | (lexical_items[feat].isna() & pd.isna(value))
|
| 103 |
+
lexical_items = lexical_items[mask]
|
| 104 |
+
|
| 105 |
+
if len(lexical_items) > 0:
|
| 106 |
+
return lexical_items["form"].iloc[0]
|
| 107 |
+
else:
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
def get_candidate_set(self, universal_constraints: dict, morph_constraints: dict) -> pd.DataFrame:
|
| 111 |
+
has_parsed_conllu = self.li_feature_set is not None
|
| 112 |
+
if not has_parsed_conllu:
|
| 113 |
+
raise ValueError("Please parse a ConLLU file first.")
|
| 114 |
+
|
| 115 |
+
morph_constraints = {f"feats__{k}": v for k, v in morph_constraints.items()}
|
| 116 |
+
are_morph_features_valid = all(
|
| 117 |
+
f in self.li_feature_set.columns for f in morph_constraints.keys()
|
| 118 |
+
)
|
| 119 |
+
are_universal_features_valid = all(
|
| 120 |
+
f in self.li_feature_set.columns for f in universal_constraints.keys()
|
| 121 |
+
)
|
| 122 |
+
if not are_morph_features_valid or not are_universal_features_valid:
|
| 123 |
+
raise KeyError(
|
| 124 |
+
"Features provided for candidate set are not valid features in the dataset."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
all_constraints = {**universal_constraints, **morph_constraints}
|
| 128 |
+
candidate_set = self._construct_candidate_set(
|
| 129 |
+
self.li_feature_set, all_constraints
|
| 130 |
+
)
|
| 131 |
+
return candidate_set
|
| 132 |
+
|
| 133 |
+
def _build_masked_dataset_grew(self, filepath: Path, grew_query: str, dependency_node: str,
|
| 134 |
+
mask_token, encoding: str = "utf-8"):
|
| 135 |
+
masked_dataset = []
|
| 136 |
+
exception_dataset = []
|
| 137 |
+
|
| 138 |
+
get_tokens_to_mask = match_dependencies(filepath, grew_query, dependency_node)
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
with open(filepath, "r", encoding=encoding) as data_file:
|
| 142 |
+
for sentence in parse_incr(data_file):
|
| 143 |
+
logging.info(f"Processing sentence: {sentence.metadata["sent_id"]}")
|
| 144 |
+
|
| 145 |
+
sentence_id = sentence.metadata["sent_id"]
|
| 146 |
+
sentence_text = sentence.metadata["text"]
|
| 147 |
+
if sentence_id in get_tokens_to_mask:
|
| 148 |
+
for i in range(len(sentence)):
|
| 149 |
+
sentence[i]["index"] = i
|
| 150 |
+
|
| 151 |
+
token_to_mask_id = get_tokens_to_mask[sentence_id]
|
| 152 |
+
|
| 153 |
+
try:
|
| 154 |
+
t_match = [tok for tok in sentence if tok.get("id") == token_to_mask_id][0]
|
| 155 |
+
t_match_form = t_match["form"]
|
| 156 |
+
t_match_index = t_match["index"]
|
| 157 |
+
sentence_as_str_list = [t["form"] for t in sentence]
|
| 158 |
+
except KeyError:
|
| 159 |
+
logging.info("There was a mismatch for the GREW-based ID and the Conllu ID.")
|
| 160 |
+
exception_dataset.append(
|
| 161 |
+
{
|
| 162 |
+
"sentence_id": sentence_id,
|
| 163 |
+
"match_id": None,
|
| 164 |
+
"all_tokens": None,
|
| 165 |
+
"match_token": None,
|
| 166 |
+
"original_text": sentence_text,
|
| 167 |
+
}
|
| 168 |
+
)
|
| 169 |
+
continue
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
matched_token_start_index = self.lexer.recursive_match_token(
|
| 173 |
+
sentence_text, # the original string
|
| 174 |
+
sentence_as_str_list.copy(), # the string as a list of tokens
|
| 175 |
+
t_match_index, # the index of the token to be replaced
|
| 176 |
+
[
|
| 177 |
+
"_",
|
| 178 |
+
" ",
|
| 179 |
+
], # todo: skip lines where we don't encounter accounted for tokens
|
| 180 |
+
)
|
| 181 |
+
except ValueError:
|
| 182 |
+
print("Token not found. Saving as exception.")
|
| 183 |
+
exception_dataset.append(
|
| 184 |
+
{
|
| 185 |
+
"sentence_id": sentence_id,
|
| 186 |
+
"match_id": token_to_mask_id,
|
| 187 |
+
"all_tokens": sentence_as_str_list,
|
| 188 |
+
"match_token": t_match_form,
|
| 189 |
+
"original_text": sentence_text,
|
| 190 |
+
}
|
| 191 |
+
)
|
| 192 |
+
continue
|
| 193 |
+
|
| 194 |
+
# let's replace the matched token with a MASK token
|
| 195 |
+
masked_sentence = self.lexer.perform_token_surgery(
|
| 196 |
+
sentence_text,
|
| 197 |
+
t_match_form,
|
| 198 |
+
mask_token,
|
| 199 |
+
matched_token_start_index,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# the sentence ID and match ID are together a primary key
|
| 203 |
+
masked_dataset.append(
|
| 204 |
+
{
|
| 205 |
+
"sentence_id": sentence_id,
|
| 206 |
+
"match_id": token_to_mask_id,
|
| 207 |
+
"all_tokens": sentence_as_str_list,
|
| 208 |
+
"num_tokens": len(sentence_as_str_list),
|
| 209 |
+
"match_token": t_match_form,
|
| 210 |
+
"original_text": sentence_text,
|
| 211 |
+
"masked_text": masked_sentence,
|
| 212 |
+
}
|
| 213 |
+
)
|
| 214 |
+
except FileNotFoundError:
|
| 215 |
+
print(f"Error: The file '{filepath}' was not found.")
|
| 216 |
+
|
| 217 |
+
masked_dataset_df = pd.DataFrame(masked_dataset)
|
| 218 |
+
exception_dataset_df = pd.DataFrame(exception_dataset)
|
| 219 |
+
|
| 220 |
+
return {"masked": masked_dataset_df, "exception": exception_dataset_df}

    def _build_masked_dataset(
        self, filepath: str, morph_constraints: dict, upos_constraint: str | None, mask_token: str, encoding: str = "utf-8"
    ) -> dict[str, pd.DataFrame]:
        masked_dataset = []
        exception_dataset = []

        try:
            with open(filepath, "r", encoding=encoding) as data_file:
                constraints_kwargs = {f"feats__{k.capitalize()}": v for k, v in morph_constraints.items()}

                for sentence in parse_incr(data_file):
                    logging.info(f"Processing sentence: {sentence.metadata['sent_id']}")

                    # MORPHOLOGICAL FILTER
                    token_constraint_matches = sentence.filter(**constraints_kwargs)

                    # UNIVERSAL POS FILTER
                    # narrow the morphological matches rather than re-filtering
                    # the whole sentence, which would silently discard the
                    # morphological constraints
                    if upos_constraint:
                        token_constraint_matches = token_constraint_matches.filter(upos=upos_constraint)

                    if token_constraint_matches:
                        for i in range(len(sentence)):
                            sentence[i]["index"] = i

                        # sentence_text = " ".join(token["form"] for token in sentence)
                        sentence_text = sentence.metadata["text"]
                        sentence_id = sentence.metadata["sent_id"]

                        matches = [t["form"] for t in token_constraint_matches]
                        match_indices = [t["index"] for t in token_constraint_matches]

                        # iterate over each match in the sentence:
                        # we want one test sentence per match, so a sentence
                        # containing, for instance, two subjunctives yields
                        # two test sentences
                        for t_match_index, t_match in zip(match_indices, matches):
                            # at what point in the string does the matched token start?
                            sentence_as_str_list = [t["form"] for t in sentence]

                            try:
                                matched_token_start_index = self.lexer.recursive_match_token(
                                    sentence_text,
                                    sentence_as_str_list.copy(),
                                    t_match_index,
                                    [
                                        "_",
                                        " ",
                                    ],  # todo: skip lines containing unaccounted-for tokens
                                )
                            except ValueError:
                                print("Token not found. Saving as exception.")
                                exception_dataset.append(
                                    {
                                        "sentence_id": sentence_id,
                                        "match_id": t_match_index,
                                        "all_tokens": sentence_as_str_list,
                                        "match_token": t_match,
                                        "original_text": sentence_text,
                                    }
                                )
                                continue

                            # replace the matched token with the mask token
                            masked_sentence = self.lexer.perform_token_surgery(
                                sentence_text,
                                t_match,
                                mask_token,
                                matched_token_start_index,
                            )

                            # the sentence ID and match ID together form a primary key
                            masked_dataset.append(
                                {
                                    "sentence_id": sentence_id,
                                    "match_id": t_match_index,
                                    "all_tokens": sentence_as_str_list,
                                    "num_tokens": len(sentence_as_str_list),
                                    "match_token": t_match,
                                    "original_text": sentence_text,
                                    "masked_text": masked_sentence,
                                }
                            )
        except FileNotFoundError:
            print(f"Error: The file '{filepath}' was not found.")

        masked_dataset_df = pd.DataFrame(masked_dataset)
        exception_dataset_df = pd.DataFrame(exception_dataset)

        return {"masked": masked_dataset_df, "exception": exception_dataset_df}

    def _is_valid_token(self, token: Token) -> bool:
        punctuation = [".", ",", "!", "?", "*"]

        # skip multiword tokens, malformed entries and punctuation
        is_punctuation = token.get("form") in punctuation
        is_valid_type = isinstance(token, dict)
        has_valid_id = isinstance(token.get("id"), int)
        return is_valid_type and has_valid_id and not is_punctuation

    def _build_token_row(self, token: Token, sentence_id: str) -> dict[str, Any]:
        # get all token features such as Person, Mood, etc.
        feats = token.get("feats") or {}

        row = {
            "sentence_id": sentence_id,
            "token_id": token.get("id") - 1,  # IDs are shifted down by one to start at 0
            "form": token.get("form"),
            "lemma": token.get("lemma"),
            "upos": token.get("upos"),
            "xpos": token.get("xpos"),
        }

        # add each morphological feature as a column
        for feat_name, feat_value in feats.items():
            row["feats__" + feat_name.lower()] = feat_value

        return row

    def _build_lexical_item_dataset(self, conllu_path: str) -> pd.DataFrame:
        rows = []

        with open(conllu_path, "r", encoding="utf-8") as f:
            for tokenlist in parse_incr(f):
                # get the sentence ID in the dataset
                sent_id = tokenlist.metadata["sent_id"]
                logging.info(f"Building LI Set For Sentence: {sent_id}")

                # iterate over each token
                for token in tokenlist:
                    # check whether it is worth saving to our lexical item dataset
                    if not self._is_valid_token(token):
                        continue

                    # from the token object create a dict and append
                    row = self._build_token_row(token, sent_id)
                    rows.append(row)

        lexical_item_df = pd.DataFrame(rows)

        # make sure our NaN values are interpreted as such
        lexical_item_df.replace("nan", np.nan, inplace=True)

        # create the (sentence_id, token_id) primary key
        lexical_item_df.set_index(["sentence_id", "token_id"], inplace=True)

        self.li_feature_set = lexical_item_df

        return lexical_item_df
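
For reference, the rows assembled by _build_token_row index each token by the (sentence_id, token_id) pair, with one lowercased feats__* column per morphological feature. A small standalone sketch with hypothetical values:

import pandas as pd

# hypothetical rows in the layout _build_token_row produces
rows = [
    {"sentence_id": "s1", "token_id": 0, "form": "canta", "lemma": "cantar",
     "upos": "VERB", "xpos": None, "feats__mood": "Ind", "feats__person": "3"},
    {"sentence_id": "s1", "token_id": 1, "form": "cante", "lemma": "cantar",
     "upos": "VERB", "xpos": None, "feats__mood": "Sub", "feats__person": "3"},
]
df = pd.DataFrame(rows).set_index(["sentence_id", "token_id"])
print(df.loc[("s1", 1), "feats__mood"])  # Sub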

    """
    -- Candidate Set --
    Constructs the set of lexical items whose feature values match the
    target features passed as an argument.
    """

    def _construct_candidate_set(
        self, li_feature_set: pd.DataFrame, target_features: dict
    ) -> pd.DataFrame:
        # optionally restrict the search to a certain type of lexical item
        subset = li_feature_set

        # repeatedly filter the dataframe so that only the lexical items
        # matching the target features remain; this includes the case where
        # both the target value and the stored value are NaN
        for feat, value in target_features.items():
            # ensure the feature is a valid column in the feature set
            if feat not in subset.columns:
                raise KeyError(
                    "Invalid feature provided to candidate set: {}".format(feat)
                )

            # slim the mask down using each feature
            # interesting edge case: np.nan == np.nan returns False!
            mask = (subset[feat] == value) | (subset[feat].isna() & pd.isna(value))
            subset = subset[mask]

        return subset
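
The NaN-aware mask is the load-bearing line here: because np.nan == np.nan is False, a plain equality filter would drop every candidate whose target feature is absent. A standalone sketch of the same logic over hypothetical data:

import numpy as np
import pandas as pd

li = pd.DataFrame({
    "upos": ["VERB", "VERB", "NOUN"],
    "feats__mood": ["Sub", np.nan, np.nan],
})
target = {"upos": "VERB", "feats__mood": np.nan}

subset = li
for feat, value in target.items():
    # exact match, or both target and cell are NaN
    mask = (subset[feat] == value) | (subset[feat].isna() & pd.isna(value))
    subset = subset[mask]

print(subset)  # only row 1 survives: a VERB with no Mood feature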

treetse/preprocessing/grew_dependencies.py
ADDED
@@ -0,0 +1,20 @@
from grewpy import Corpus, Request, set_config
from pathlib import Path


def match_dependencies(
    corpus_path: Path, grew_query: str, dependency_node: str
) -> dict:
    set_config("sud")  # alternatives: "ud" or "basic"
    # run the Grew request on the corpus
    corpus = Corpus(str(corpus_path))
    request = Request().pattern(grew_query)
    occurrences = corpus.search(request)

    # map each sentence ID to the node matched by the query
    dep_matches = {}
    for occ in occurrences:
        sent_id = occ["sent_id"]
        object_node_id = int(occ["matching"]["nodes"][dependency_node])
        dep_matches[sent_id] = object_node_id
    return dep_matches
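
A minimal usage sketch, assuming the SUD-annotated test corpus shipped with this commit; the Grew pattern and the node name "O" are illustrative, not prescribed by the function:

from pathlib import Path

from treetse.preprocessing.grew_dependencies import match_dependencies

# find direct objects: "O" must name a node bound by the pattern
matches = match_dependencies(
    corpus_path=Path("spanish-test-sm.conllu"),
    grew_query="V -[comp:obj]-> O",
    dependency_node="O",
)
print(matches)  # {sent_id: node_id, ...}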

treetse/preprocessing/reconstruction.py
ADDED
@@ -0,0 +1,78 @@
class Lexer:
    def __init__(self) -> None:
        pass

    def perform_token_surgery(
        self,
        sentence: str,
        original_token: str,
        replacement_token: str,
        start_index: int,
    ) -> str:
        t_len = len(original_token)

        return (
            sentence[:start_index] + replacement_token + sentence[start_index + t_len :]
        )

    def recursive_match_token(
        self,
        full_sentence: str,
        token_list: list[str],
        token_list_mask_index: int,
        skippable_tokens: list[str],
    ) -> int:
        # ensure we can retrieve another token
        n_remaining_tokens = len(token_list)
        if n_remaining_tokens == 0:
            raise ValueError(
                "Mask index not reached but token list has been exhausted for sentence: {}".format(
                    full_sentence
                )
            )
        t = token_list[0]

        # index of the first occurrence of the token t
        match_index = full_sentence.find(t)
        is_match_found = match_index != -1
        has_reached_mask_token = token_list_mask_index == 0

        # BASE CASE: we have reached the masked token
        if has_reached_mask_token and is_match_found:
            return match_index
        # RECURSIVE CASE: consume this token and keep searching
        elif is_match_found:
            sliced_sentence = full_sentence[match_index + len(t) :]
            token_list.pop(0)

            return (
                match_index
                + len(t)
                + self.recursive_match_token(
                    sliced_sentence,
                    token_list,
                    token_list_mask_index - 1,
                    skippable_tokens,
                )
            )
        else:
            # no match found; is t irrelevant?
            if t in skippable_tokens:
                # careful with the slicing here: for the single-character
                # skippables used so far ("_", " ") this slice is a no-op,
                # but longer skippable tokens would need tests
                sliced_sentence = full_sentence[len(t) - 1 :]
                token_list.pop(0)
                return self.recursive_match_token(
                    sliced_sentence,
                    token_list,
                    token_list_mask_index - 1,
                    skippable_tokens,
                )
            else:
                raise ValueError(
                    "Token not found in string nor has it been specified as skippable: {}".format(
                        t
                    )
                )
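
A worked example of the two Lexer methods together: recursive_match_token resolves a token-list index to a character offset in the surface string, and perform_token_surgery splices a mask in at that offset. The sentence and mask token here are hypothetical:

from treetse.preprocessing.reconstruction import Lexer

lexer = Lexer()
sentence = "No lo sé."
tokens = ["No", "lo", "sé", "."]

# character offset of the token at list index 2 ("sé")
start = lexer.recursive_match_token(sentence, tokens.copy(), 2, ["_", " "])
print(start)  # 6

masked = lexer.perform_token_surgery(sentence, "sé", "[MASK]", start)
print(masked)  # No lo [MASK].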
treetse/visualise/__pycache__/visualiser.cpython-312.pyc
ADDED
Binary file (4.78 kB)

treetse/visualise/__pycache__/visualiser.cpython-313.pyc
ADDED
Binary file (4.4 kB)

treetse/visualise/visualiser.py
ADDED
@@ -0,0 +1,114 @@
import math
from pathlib import Path

import pandas as pd
from plotnine import (
    aes,
    geom_boxplot,
    geom_jitter,
    geom_line,
    geom_violin,
    ggplot,
    guides,
    labs,
    position_nudge,
    scale_x_discrete,
    theme,
    theme_bw,
)


class Visualiser:
    def __init__(self) -> None:
        self.data = None

    def load_dataset(self, results: pd.DataFrame) -> None:
        self.data = results

    def visualise_slope(
        self,
        path: Path,
        results: pd.DataFrame,
        target_x_label: str,
        alt_x_label: str,
        x_axis_label: str,
        y_axis_label: str,
        title: str,
    ):
        lsize = 0.65
        fill_alpha = 0.7

        # X-axis: target vs. alternative form (e.g. Acc, Gen)
        # Y-axis: surprisal
        # keep only rows with a non-empty alternative; .copy() avoids
        # pandas' SettingWithCopyWarning when subject_id is added below
        filtered_df = results[
            results["alternative"].notna() & (results["alternative"].str.strip() != "")
        ].copy()
        print("Number of filtered results: ", len(filtered_df))
        print(filtered_df.head())

        filtered_df["subject_id"] = filtered_df.index

        print(filtered_df.head())

        # melt the dataframe into long format: one row per (sentence, source)
        df_long = pd.melt(
            filtered_df,
            id_vars=["subject_id", "num_tokens"],
            value_vars=["label_prob", "alternative_prob"],
            var_name="source",
            value_name="log_prob",
        )

        # map source to fixed x-axis labels
        df_long["x_label"] = df_long["source"].map(
            {"label_prob": target_x_label, "alternative_prob": alt_x_label}
        )

        print(df_long.head())

        def surprisal(p: float) -> float:
            return -math.log2(p)

        def confidence(p: float) -> float:
            return math.log2(p)

        # note: despite the column name, this stores log2(p) ("confidence"),
        # the negation of surprisal, so it flips the plot's vertical orientation
        df_long["surprisal"] = df_long["log_prob"].apply(confidence)
        print(df_long.head())

        p = (
            ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
            + scale_x_discrete(limits=[target_x_label, alt_x_label])
            + geom_jitter(
                aes(color="x_label", size="num_tokens"), width=0.01, alpha=0.7
            )
            # + geom_text(aes(label="label"), nudge_y=0.1)
            + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
            + geom_boxplot(
                df_long[df_long["x_label"] == target_x_label],
                aes(x="x_label", y="surprisal", group="x_label"),
                width=0.2,
                alpha=0.4,
                size=0.6,
                outlier_shape=None,
                show_legend=False,
                position=position_nudge(x=-0.2),
            )
            + geom_boxplot(
                df_long[df_long["x_label"] == alt_x_label],
                aes(x="x_label", y="surprisal", group="x_label"),
                width=0.2,
                alpha=0.4,
                size=0.6,
                outlier_shape=None,
                show_legend=False,
                position=position_nudge(x=0.2),
            )
            + geom_violin(
                df_long[df_long["x_label"] == target_x_label],
                aes(x="x_label", y="surprisal", group="x_label"),
                position=position_nudge(x=-0.4),
                style="left-right",
                alpha=fill_alpha,
                size=lsize,
            )
            + geom_violin(
                df_long[df_long["x_label"] == alt_x_label],
                aes(x="x_label", y="surprisal", group="x_label"),
                position=position_nudge(x=0.4),
                style="right-left",
                alpha=fill_alpha,
                size=lsize,
            )
            + guides(fill=False)
            + theme_bw()
            + theme(figure_size=(8, 4), legend_position="none")
            + labs(x=x_axis_label, y=y_axis_label, title=title)
        )
        p.save(path, width=14, height=8, dpi=300)
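
A usage sketch for visualise_slope, assuming a results frame carrying the columns the method reads (alternative, label_prob, alternative_prob, num_tokens); every value below is hypothetical:

from pathlib import Path

import pandas as pd

from treetse.visualise.visualiser import Visualiser

results = pd.DataFrame({
    "alternative": ["cantara", "dijera"],
    "label_prob": [0.62, 0.55],
    "alternative_prob": [0.21, 0.30],
    "num_tokens": [7, 9],
})

Visualiser().visualise_slope(
    Path("slope.png"),
    results,
    target_x_label="Target",
    alt_x_label="Alternative",
    x_axis_label="Form",
    y_axis_label="log2 probability",
    title="Target vs. alternative form",
)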