DanielGallagherIRE committed
Commit 586650a · 1 Parent(s): 698e6dd

upload files
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
.requirements.txt.un~ ADDED
Binary file (947 Bytes).
app.py CHANGED
@@ -1,7 +1,199 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import pandas as pd
+ import tempfile
+ import ast
+ import sys
+ import os
+ from PIL import Image
+
+ # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+ from treetse.pipeline import Grewtse
+
+ grewtse = Grewtse()
+ treebank_path = None
+
+ def parse_treebank(path: str, treebank_selection: str) -> pd.DataFrame:
+     # module-level state so later steps know which treebank is active
+     global treebank_path
+     if treebank_selection == "None":
+         successful_treebank_parse = grewtse.parse_treebank(path)
+         treebank_path = path
+     else:
+         successful_treebank_parse = grewtse.parse_treebank(treebank_selection)
+         treebank_path = treebank_selection
+
+     return grewtse.get_morphological_features().head()
+
+ def to_masked_dataset(query, node) -> pd.DataFrame:
+     df = grewtse.generate_masked_dataset(query, node)
+     return df
+
+ def safe_str_to_dict(s):
+     try:
+         return ast.literal_eval(s)
+     except (ValueError, SyntaxError):
+         return None
+
+ def generate_minimal_pairs(query: str, node: str, alt_features: str):
+     if not grewtse.is_treebank_loaded():
+         raise ValueError("Please parse a treebank first.")
+
+     # mask each sentence matched by the query
+     resulting_dataset = to_masked_dataset(query, node)
+
+     # determine whether an alternative lexical item should be found
+     alt_features_as_dict = safe_str_to_dict(alt_features)
+     if alt_features_as_dict is not None:
+         resulting_dataset = grewtse.generate_minimal_pairs(alt_features_as_dict, {})
+
+     # save to a temporary CSV file for the download widget
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+     resulting_dataset.to_csv(temp_file.name, index=False)
+     return resulting_dataset, temp_file.name
+
+ def evaluate_model(model_repo: str, target_x_label: str, alt_x_label: str, x_axis_label: str, title: str):
+     if not grewtse.are_minimal_pairs_generated():
+         raise ValueError("Please parse a treebank, mask a dataset and generate minimal pairs first.")
+
+     mp_with_eval_dataset = grewtse.evaluate_bert_mlm(model_repo)
+     vis_filename = "vis.png"
+
+     grewtse.visualise_syntactic_performance(
+         vis_filename,
+         mp_with_eval_dataset,
+         target_x_label,
+         alt_x_label,
+         x_axis_label,
+         "Confidence",
+         title,
+     )
+
+     # save to a temporary CSV file for the download widget
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
+     mp_with_eval_dataset.to_csv(temp_file.name, index=False)
+     return mp_with_eval_dataset, temp_file.name, vis_filename
+
+ def show_df():
+     return gr.update(visible=True)
+
+ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
+     with gr.Row():
+         gr.Markdown("# GREW-TSE: A Pipeline for Query-based Targeted Syntactic Evaluation")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""
+             #### Load a Treebank
+             Begin by loading the treebank you'd like to work with.<br>
+             You can either select a treebank from the pre-loaded options below, or upload your own.<br>
+             """)
+
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Choose Treebank"):
+                     treebank_selection = gr.Dropdown(
+                         choices=["None", "spanish-test-sm.conllu", "polish-test-lg.conllu"],
+                         label="Select a treebank",
+                         value="spanish-test-sm.conllu"
+                     )
+
+                 with gr.TabItem("Upload Your Own"):
+                     gr.Markdown("## Upload a .conllu File")
+                     file_input = gr.File(
+                         label="Upload .conllu file",
+                         file_types=[".conllu"],
+                         type="filepath"
+                     )
+             parse_file_button = gr.Button("Parse Treebank", size='sm', scale=1)
+
+     gr.Markdown("## Isolate A Syntactic Phenomenon")
+     morph_table = gr.Dataframe(interactive=False, visible=False)
+
+     parse_file_button.click(
+         fn=parse_treebank,
+         inputs=[file_input, treebank_selection],
+         outputs=[morph_table]
+     )
+     parse_file_button.click(
+         fn=show_df,
+         outputs=morph_table
+     )
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""
+             **GREW (Graph Rewriting for Universal Dependencies)** is a query and transformation language used to search within and manipulate dependency treebanks. A GREW query allows linguists and NLP researchers to find specific syntactic patterns in parsed linguistic data (such as Universal Dependencies treebanks).
+             Queries are expressed as graph constraints using a concise pattern-matching syntax.
+
+             #### Example
+             The following short GREW query will target any verbs. Try it with one of the sample treebanks above.
+             Make sure to include the variable V as the target that we're trying to isolate.
+
+             ```grew
+             V [upos=\"VERB\"];
+             ```
+             """)
+         with gr.Column():
+             query_input = gr.Textbox(label="GREW Query", lines=5, placeholder="Enter your GREW query here...", value="V [upos=\"VERB\"];")
+             node_input = gr.Textbox(label="Node", placeholder="The variable in your GREW query to isolate, e.g., N", value="V")
+             feature_input = gr.Textbox(
+                 label="Enter Alternative Feature Values for Minimal Pair as a Dictionary",
+                 placeholder='e.g. {"case": "Acc", "number": "Sing"}',
+                 value="{\"mood\": \"Sub\"}",
+                 lines=3
+             )
+             run_button = gr.Button("Run Query", size='sm', scale=3)
+
+     output_table = gr.Dataframe(label="Output Table", visible=False)
+     download_file = gr.File(label="Download CSV")
+     run_button.click(
+         fn=generate_minimal_pairs,
+         inputs=[query_input, node_input, feature_input],
+         outputs=[output_table, download_file]
+     )
+     run_button.click(
+         fn=show_df,
+         outputs=output_table
+     )
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""
+             ## Evaluate A Model
+             You can evaluate any BERT-style masked language model by providing the name of its model repository.
+             """)
+         with gr.Column():
+             repository_input = gr.Textbox(label="Model Repository", lines=1, placeholder="Enter the model repository here...", value="dccuchile/distilbert-base-spanish-uncased")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""
+             ## Choose Visualisation Settings
+             The results will be displayed as a visualisation which you can edit using the following settings.
+             """)
+         with gr.Column():
+             target_x_label_textbox = gr.Textbox(label="Original Label Name, i.e. the type of the 'right' token", lines=1, placeholder="Genitive Version")
+             alt_x_label_textbox = gr.Textbox(label="Alternative Label Name, i.e. the type of the 'wrong' token", lines=1, placeholder="Accusative Version")
+             x_axis_label_textbox = gr.Textbox(label="X Axis Title, i.e. what features are you comparing?", lines=1, placeholder="Case of Nouns in Transitive Verbs")
+             title_textbox = gr.Textbox(label="Visualisation Title", lines=1, placeholder="Syntactic Performance of BERT on English Transitive Noun Case")
+
+     evaluate_button = gr.Button("Evaluate Model", size='sm', scale=3)
+
+     mp_with_eval_output_dataset = gr.Dataframe(label="Output Table", visible=False)
+     mp_with_eval_output_download = gr.File(label="Download CSV")
+     visualisation_widget = gr.Image(type="pil", label="Loaded Image")
+
+     evaluate_button.click(
+         fn=evaluate_model,
+         inputs=[repository_input, target_x_label_textbox, alt_x_label_textbox, x_axis_label_textbox, title_textbox],
+         outputs=[mp_with_eval_output_dataset, mp_with_eval_output_download, visualisation_widget]
+     )
+     evaluate_button.click(
+         fn=show_df,
+         outputs=[mp_with_eval_output_dataset]
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
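
Review note: the alternative-feature textbox above is parsed with `ast.literal_eval` via `safe_str_to_dict`, so malformed input quietly disables minimal-pair generation rather than raising. A quick sketch of that behaviour:

```python
import ast

def safe_str_to_dict(s):
    # mirrors the helper in app.py
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError):
        return None

print(safe_str_to_dict('{"mood": "Sub"}'))  # {'mood': 'Sub'}
print(safe_str_to_dict("not a dict"))       # None -> query runs without pairs
```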
polish-test-lg.conllu.conllu ADDED
The diff for this file is too large to render.
requirements.txt ADDED
@@ -0,0 +1,107 @@
+ aiofiles==24.1.0
+ annotated-types==0.7.0
+ anyio==4.9.0
+ black==25.1.0
+ Brotli==1.1.0
+ certifi==2025.4.26
+ charset-normalizer==3.4.2
+ click==8.2.0
+ colorama==0.4.6
+ conllu==6.0.0
+ contourpy==1.3.2
+ coverage==7.8.0
+ cycler==0.12.1
+ fastapi==0.116.1
+ ffmpy==0.6.0
+ filelock==3.18.0
+ fonttools==4.58.4
+ fsspec==2025.5.1
+ gradio==5.37.0
+ gradio_client==1.10.4
+ grewpy==0.6.0
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.1.5
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.32.2
+ idna==3.10
+ iniconfig==2.1.0
+ isort==6.0.1
+ Jinja2==3.1.6
+ kiwisolver==1.4.8
+ lark==1.2.2
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ matplotlib==3.10.3
+ mdurl==0.1.2
+ mizani==0.13.5
+ mpmath==1.3.0
+ mypy==1.15.0
+ mypy_extensions==1.1.0
+ networkx==3.4.2
+ numpy==2.2.5
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-nccl-cu12==2.26.2
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
+ orjson==3.10.18
+ packaging==25.0
+ pandas==2.2.3
+ pandas-stubs==2.2.3.250527
+ pathspec==0.12.1
+ patsy==1.0.1
+ pillow==11.2.1
+ platformdirs==4.3.8
+ plotnine==0.14.5
+ pluggy==1.5.0
+ pydantic==2.11.7
+ pydantic_core==2.33.2
+ pydub==0.25.1
+ Pygments==2.19.2
+ pyparsing==3.2.3
+ pytest==8.3.5
+ pytest-cov==6.1.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==14.0.0
+ ruff==0.11.9
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scipy==1.15.3
+ semantic-version==2.10.0
+ setuptools==80.4.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.47.1
+ statsmodels==0.14.4
+ sympy==1.14.0
+ tokenizers==0.21.1
+ tomlkit==0.13.3
+ torch==2.7.0
+ tqdm==4.67.1
+ transformers==4.52.3
+ triton==3.3.0
+ typer==0.16.0
+ types-pytz==2025.2.0.20250516
+ typing-inspection==0.4.1
+ typing_extensions==4.13.2
+ tzdata==2025.2
+ urllib3==2.4.0
+ uvicorn==0.35.0
+ websockets==15.0.1
spanish-test-sm.conllu ADDED
@@ -0,0 +1,100 @@
+ # global.columns = ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC
+ # sent_id = 3LB-CAST-c2-2-s18
+ # text = Sea enhorabuena.
+ # orig_file_sentence 011#21
+ 1 Sea ser AUX vssp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root _
+ 2 enhorabuena enhorabuena NOUN ncfs000 Gender=Fem|Number=Sing 1 nsubj 1:nsubj ArgTem=arg1:tem|SpaceAfter=No
+ 3 . . PUNCT fp PunctType=Peri 1 punct 1:punct _
+
+ # sent_id = 3LB-CAST-d2-12-s5
+ # text = Esperemos que muy pronto.
+ # orig_file_sentence 002#8
+ 1 Esperemos esperar VERB vmsp1p0 Mood=Sub|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin 0 root 0:root _
+ 2 que que SCONJ cs _ 4 mark 4:mark _
+ 3 muy mucho ADV rg _ 4 advmod 4:advmod _
+ 4 pronto pronto ADV rg _ 1 obj 1:obj ArgTem=arg1:pat|SpaceAfter=No
+ 5 . . PUNCT fp PunctType=Peri 1 punct 1:punct _
+
+ # sent_id = CESS-CAST-P-20000701-69-s8
+ # text = Mujer hubiera mostrado arrepentimiento.
+ # orig_file_sentence 131#40
+ 1 Mujer mujer NOUN ncfs000 Gender=Fem|Number=Sing 3 nsubj 3:nsubj ArgTem=arg0:agt|Entity=(NOCOREF:Gen--1-gstype:gen)
+ 2 hubiera haber AUX vasi3s0 Mood=Sub|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin 3 aux 3:aux _
+ 3 mostrado mostrar VERB vmp00sm Gender=Masc|Number=Sing|Tense=Past|VerbForm=Part 0 root 0:root _
+ 4 arrepentimiento arrepentimiento NOUN ncms000 Gender=Masc|Number=Sing 3 obj 3:obj ArgTem=arg1:pat|Entity=(NOCOREF:Gen--1-gstype:gen)|SpaceAfter=No
+ 5 . . PUNCT fp PunctType=Peri 3 punct 3:punct _
+
+ # sent_id = CESS-CAST-P-20000601-41-s10
+ # text = Que por ritmo no quede.
+ # orig_file_sentence 119#88
+ 1 Que que SCONJ cs _ 5 mark 5:mark _
+ 2 por por ADP sps00 _ 3 case 3:case _
+ 3 ritmo ritmo NOUN ncms000 Gender=Masc|Number=Sing 5 obl 5:obl ArgTem=argM:adv
+ 4 no no ADV rn Polarity=Neg 5 advmod 5:advmod _
+ 5 quede quedar VERB vmsp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root SpaceAfter=No
+ 6 . . PUNCT fp PunctType=Peri 5 punct 5:punct _
+
+ # sent_id = 3LB-CAST-t3-4-s23
+ # text = Yo esperar cuanto haga falta.
+ # orig_file_sentence 027#95
+ 1 Yo yo PRON pp1csn00 Case=Nom|Number=Sing|Person=1|PronType=Prs 2 obj 2:obj _
+ 2 esperar esperar VERB vmn0000 VerbForm=Inf 0 root 0:root _
+ 3 cuanto cuanto PRON pr0ms000 Gender=Masc|Number=Sing|NumType=Card|PronType=Int,Rel 4 nsubj 4:nsubj ArgTem=arg1:tem
+ 4 haga hacer VERB vmsp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 2 advcl 2:advcl MWE=haga_falta|MWEPOS=VERB
+ 5 falta falta NOUN _ Gender=Fem|Number=Sing 4 compound 4:compound SpaceAfter=No
+ 6 . . PUNCT fp PunctType=Peri 2 punct 2:punct _
+
+ # sent_id = 3LB-CAST-c1-4-s5
+ # text = Un soldador que me meta fuego.
+ # orig_file_sentence 010#41
+ 1 Un uno DET di0ms0 Definite=Ind|Gender=Masc|Number=Sing|PronType=Art 2 det 2:det _
+ 2 soldador soldador NOUN ncms000 Gender=Masc|Number=Sing 0 root 0:root _
+ 3 que que PRON pr0cn000 PronType=Rel 5 nsubj 5:nsubj ArgTem=arg0:agt
+ 4 me yo PRON pp1cs000 Case=Dat|Number=Sing|Person=1|PrepCase=Npr|PronType=Prs 5 obl:arg 5:obl:arg ArgTem=arg2:ben
+ 5 meta meter VERB vmsp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 2 acl 2:acl _
+ 6 fuego fuego NOUN ncms000 Gender=Masc|Number=Sing 5 obj 5:obj ArgTem=arg1:pat|SpaceAfter=No
+ 7 . . PUNCT fp PunctType=Peri 2 punct 2:punct _
+
+ # sent_id = CESS-CAST-P-20010701-38-s94
+ # text = Por mucho que me lo pidan.
+ # orig_file_sentence 117#2
+ 1 Por por ADP cs _ 6 mark 6:mark MWE=Por_mucho_que|MWEPOS=SCONJ
+ 2 mucho mucho ADV _ _ 1 fixed 1:fixed _
+ 3 que que SCONJ _ _ 1 fixed 1:fixed _
+ 4 me yo PRON pp1cs000 Case=Dat|Number=Sing|Person=1|PrepCase=Npr|PronType=Prs 6 obl:arg 6:obl:arg ArgTem=arg2:ben|Entity=(CESSCASTP2001070138c1-person-1-CorefType:ident,gstype:spec)
+ 5 lo él PRON pp3msa00 Case=Acc|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs 6 obj 6:obj ArgTem=arg1:pat|Entity=(CESSCASTP2001070138c51--1-CorefType:dx.prop)
+ 6 pidan pedir VERB vmsp3p0 Mood=Sub|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root SpaceAfter=No
+ 7 . . PUNCT fp PunctType=Peri 6 punct 6:punct _
+
+ # sent_id = 3LB-CAST-t6-3-s5
+ # text = Tal vez eso fuera lo peor.
+ # orig_file_sentence 030#40
+ 1 Tal tal NOUN rg _ 6 advmod 6:advmod MWE=Tal_vez|MWEPOS=ADV
+ 2 vez vez NOUN _ _ 1 fixed 1:fixed _
+ 3 eso ese PRON pd0ns000 Number=Sing|PronType=Dem 6 nsubj 6:nsubj ArgTem=arg1:tem
+ 4 fuera ser AUX vssi3s0 Mood=Sub|Number=Sing|Person=3|Tense=Imp|VerbForm=Fin 6 cop 6:cop _
+ 5 lo él PRON da0ns0 Case=Acc|Definite=Def|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs 6 det 6:det _
+ 6 peor peor ADJ aq0cs0 Degree=Cmp|Number=Sing 0 root 0:root ArgTem=arg2:atr|SpaceAfter=No
+ 7 . . PUNCT fp PunctType=Peri 6 punct 6:punct _
+
+ # sent_id = 3LB-CAST-n1-3-s13
+ # text = La bandera, que no falte.
+ # orig_file_sentence 024#5
+ 1 La el DET da0fs0 Definite=Def|Gender=Fem|Number=Sing|PronType=Art 2 det 2:det _
+ 2 bandera bandera NOUN ncfs000 Gender=Fem|Number=Sing 6 nsubj 6:nsubj ArgTem=arg1:tem|SpaceAfter=No
+ 3 , , PUNCT fc PunctType=Comm 2 punct 2:punct _
+ 4 que que SCONJ cs _ 6 mark 6:mark _
+ 5 no no ADV rn Polarity=Neg 6 advmod 6:advmod _
+ 6 falte faltar VERB vmsp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 0 root 0:root SpaceAfter=No
+ 7 . . PUNCT fp PunctType=Peri 6 punct 6:punct _
+
+ # sent_id = CESS-CAST-P-20020202-102-s11
+ # text = Diga lo que diga.
+ # orig_file_sentence 076#95
+ 1 Diga decir VERB vmm03s0 Mood=Imp|Number=Sing|Person=3|VerbForm=Fin 0 root 0:root _
+ 1.1 _ _ PRON p _ _ _ 1:nsubj ArgTem=arg0:agt|Entity=(CESSCASTP20020202102c2--1-CorefType:ident)|wordform=__EMPTY__
+ 2 lo él PRON da0ns0 Case=Acc|Definite=Def|Gender=Masc|Number=Sing|Person=3|PrepCase=Npr|PronType=Prs 4 det 4:det _
+ 3 que que PRON pr0cn000 PronType=Rel 4 obj 4:obj ArgTem=arg1:pat
+ 4 diga decir VERB vmsp3s0 Mood=Sub|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 ccomp 1:ccomp ArgTem=arg1:pat|SpaceAfter=No
+ 4.1 _ _ PRON p _ _ _ 4:nsubj ArgTem=arg0:agt|Entity=(CESSCASTP20020202102c2--1-CorefType:ident)|wordform=__EMPTY__
+ 5 . . PUNCT fp PunctType=Peri 1 punct 1:punct _
treetse/__init__.py ADDED
File without changes
treetse/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes).
treetse/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (161 Bytes).
treetse/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (8.34 kB).
treetse/evaluators/__pycache__/evaluator.cpython-312.pyc ADDED
Binary file (5.25 kB).
treetse/evaluators/__pycache__/evaluator.cpython-313.pyc ADDED
Binary file (5.35 kB).
treetse/evaluators/evaluator.py ADDED
@@ -0,0 +1,93 @@
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ from typing import Any, Tuple
+ import torch
+ import torch.nn.functional as F
+
+
+ class Evaluator:
+     def __init__(self):
+         self.mask_token_index: int = -1
+         self.mask_probs: torch.Tensor | None = None
+         self.tokeniser: Any = None
+         self.model: Any = None
+         self.logits: torch.Tensor | None = None
+
+     def setup_parameters(self, model_name: str) -> Tuple[Any, Any]:
+         # Q: what sort of tokenisers are being used?
+         self.tokeniser = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+
+         # set to eval mode, disabling things like dropout
+         self.model.eval()
+
+         return self.model, self.tokeniser
+
+     def run_masked_prediction(
+         self, model: Any, tokeniser: Any, sentence: str, target_token: str
+     ) -> Tuple[Any, Any]:
+         # check before rewriting "[MASK]" to the model-specific mask token;
+         # counting afterwards would always see zero "[MASK]" occurrences
+         # for models whose mask token differs (e.g. RoBERTa's <mask>)
+         if sentence.count("[MASK]") != 1:
+             raise ValueError("Only single-mask sentences are supported.")
+
+         mask_token = tokeniser.mask_token
+         sentence_masked = sentence.replace("[MASK]", mask_token)
+
+         inputs = tokeniser(sentence_masked, return_tensors="pt")
+
+         # get logits from the model
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+             self.logits = logits
+
+         self.mask_token_index = self._get_mask_index(inputs, tokeniser)
+         self.mask_probs = self._get_mask_probabilities(
+             self.mask_token_index, self.logits
+         )
+
+         return self.mask_token_index, self.mask_probs
+
+     def get_token_prob(self, token: str) -> float:
+         target_id = self.tokeniser.convert_tokens_to_ids(token)
+         prob = self.get_prob_by_id(target_id)
+         return prob
+
+     def get_top_pred(self) -> Tuple[str, float]:
+         top_pred_id = int(torch.argmax(self.mask_probs, dim=-1).item())
+         top_pred_token = self.tokeniser.convert_ids_to_tokens(top_pred_id)
+         top_token_prob = self.get_prob_by_id(top_pred_id)
+         return top_pred_token, top_token_prob
+
+     def get_prob_by_id(self, token_id: int) -> float:
+         if self.mask_probs is not None:
+             return self.mask_probs[token_id].item()
+         else:
+             raise KeyError("Please evaluate a dataset first. Results empty")
+
+     def _get_mask_index(self, inputs: Any, tokeniser: Any) -> int:
+         if "input_ids" not in inputs:
+             raise ValueError("Missing 'input_ids' in inputs.")
+
+         if tokeniser.mask_token_id is None:
+             raise ValueError("The tokeniser does not have a defined mask_token_id.")
+
+         input_ids = inputs["input_ids"]
+         mask_positions = torch.where(input_ids == tokeniser.mask_token_id)
+
+         if len(mask_positions[0]) == 0:
+             raise ValueError("No mask token found in input_ids.")
+
+         if len(mask_positions[0]) > 1:
+             raise ValueError("Multiple mask tokens found; expected only one.")
+
+         # input_ids is (batch, seq), so torch.where returns
+         # (batch indices, sequence indices); the sequence index is wanted
+         return int(mask_positions[1].item())
+
+     def _get_mask_probabilities(
+         self, mask_token_index: int, logits: Any
+     ) -> torch.Tensor:
+         mask_logits = logits[0, mask_token_index, :]  # shape: (vocab_size,)
+         probs = F.softmax(mask_logits, dim=-1)  # shape: (vocab_size,)
+         return probs
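
Review note: a minimal usage sketch of the `Evaluator` class above. The model repo here is illustrative, and the label is only meaningful if it exists as a single token in the model's vocabulary:

```python
from treetse.evaluators.evaluator import Evaluator

evaluator = Evaluator()
model, tokeniser = evaluator.setup_parameters("bert-base-multilingual-cased")

# compute the distribution over the vocabulary at the masked position
evaluator.run_masked_prediction(
    model, tokeniser, "La bandera, que no [MASK].", "falte"
)
print(evaluator.get_token_prob("falte"))  # P(label | context)
print(evaluator.get_top_pred())           # (most likely token, its probability)
```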
treetse/evaluators/perplexity.py ADDED
@@ -0,0 +1,32 @@
+ class PerplexityEvaluator:
+     def __init__(self) -> None:
+         pass
+
+     def compute_perplexity(self, logits: list) -> list:
+         pass
+
+     def compute_classic_tse(self) -> None:
+         """
+         -- Classic TSE --
+         Evaluates based on minimal pairs, where a particular feature
+         is chosen and two values of that feature are compared.
+
+         1. Accepts the inputs, logits, feature name, and feature values as input.
+            Finds the lexical items which are the same except for these values of
+            this feature, including in UPOS and lemma.
+         2. Computes the perplexity scores for the correct value and the
+            alternative syntactic option.
+         """
+         pass
+
+     def compute_generalised_tse(self) -> None:
+         """
+         -- Generalised TSE --
+         Evaluates based on minimal syntactic pairs, that is, a candidate set is
+         created for the correct token as well as for the alternate values of
+         that particular feature.
+         """
+         pass
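
Review note: `compute_perplexity` is still a stub, so as a reference point, here is a minimal sketch of the pseudo-perplexity usually computed for masked LMs; this is an assumption about the intended semantics, not this repo's API:

```python
import math

def pseudo_perplexity(token_log_probs: list[float]) -> float:
    # exponentiated mean negative log-probability over the tokens
    if not token_log_probs:
        raise ValueError("Need at least one token log-probability.")
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)

print(pseudo_perplexity([math.log(0.2), math.log(0.05)]))  # ~10.0
```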
treetse/pipeline.py ADDED
@@ -0,0 +1,200 @@
+ import pandas as pd
+ import logging
+ from pathlib import Path
+
+ from treetse.preprocessing.conllu_parser import ConlluParser
+ from treetse.evaluators.evaluator import Evaluator
+ from treetse.visualise.visualiser import Visualiser
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(message)s",
+     handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
+ )
+
+ class Grewtse:
+     def __init__(self):
+         self.parser = ConlluParser()
+         self.evaluator = Evaluator()
+         self.visualiser = Visualiser()
+
+         self.treebank_path: str | None = None
+         self.lexical_items: pd.DataFrame | None = None
+         self.masked_dataset: pd.DataFrame | None = None
+         self.exception_dataset: pd.DataFrame | None = None
+         self.evaluation_results: pd.DataFrame | None = None
+
+     def parse_treebank(self, filepath: str) -> bool:
+         try:
+             self.treebank_path = filepath
+             self.lexical_items = self.parser._build_lexical_item_dataset(filepath)
+             return True
+         except Exception as e:
+             logging.error(f"Failed to parse treebank: {e}")
+             self.treebank_path = None
+             self.lexical_items = None
+             return False
+
+     def is_treebank_loaded(self) -> bool:
+         return self.lexical_items is not None
+
+     def is_dataset_masked(self) -> bool:
+         return self.masked_dataset is not None
+
+     def get_lexical_items(self) -> pd.DataFrame:
+         return self.lexical_items
+
+     def get_morphological_features(self) -> pd.DataFrame:
+         if self.lexical_items is None:
+             raise ValueError("Cannot get features: You must parse a treebank first.")
+
+         morph_df = self.lexical_items
+         morph_df.columns = [
+             col.replace("feats__", "") if col.startswith("feats__") else col
+             for col in morph_df.columns
+         ]
+
+         return morph_df
+
+     def generate_masked_dataset(
+         self, query: str, target_node: str, mask_token: str = "[MASK]"
+     ) -> pd.DataFrame:
+         if self.treebank_path is None:
+             raise ValueError("Cannot create masked dataset: no treebank filepath provided.")
+
+         results = self.parser._build_masked_dataset_grew(
+             self.treebank_path, query, target_node, mask_token
+         )
+         self.masked_dataset = results['masked']
+         self.exception_dataset = results['exception']
+         return self.masked_dataset
+
+     def get_masked_dataset(self) -> pd.DataFrame:
+         return self.masked_dataset
+
+     def generate_minimal_pairs(self, morph_features: dict, upos_features: dict | None) -> pd.DataFrame:
+         if self.masked_dataset is None:
+             raise ValueError("Cannot generate minimal pairs: treebank must be parsed and masked first.")
+
+         def convert_row_to_feature(row):
+             return self.parser.to_syntactic_feature(
+                 row['sentence_id'],
+                 row['match_id'] - 1,  # match ids are 1-based; the LI index is 0-based
+                 morph_features,
+                 upos_features or {},
+             )
+
+         alternative_row = self.masked_dataset.apply(convert_row_to_feature, axis=1)
+         self.masked_dataset['alternative'] = alternative_row
+         return self.masked_dataset
+
+     def are_minimal_pairs_generated(self) -> bool:
+         return self.is_treebank_loaded() and \
+             self.is_dataset_masked() and \
+             ('alternative' in self.masked_dataset.columns)
+
+     def evaluate_bert_mlm(self, model_repo: str, row_limit: int | None = None) -> pd.DataFrame:
+         if self.masked_dataset is None:
+             raise ValueError("Cannot evaluate: treebank must be parsed and masked first.")
+
+         test_model, test_tokeniser = self.evaluator.setup_parameters(model_repo)
+         results = []
+
+         counter = 0
+         for row in self.masked_dataset.itertuples():
+             masked_sentence = row.masked_text
+             label = row.match_token
+             alternative_form = row.alternative
+
+             row_results = {
+                 "sentence_id": row.sentence_id,
+                 "token_id": row.match_id,
+                 "masked_sentence": masked_sentence,
+                 "num_tokens": row.num_tokens,
+                 "label": label,
+                 "label_prob": None,
+                 "alternative": alternative_form,
+                 "alternative_prob": None,
+                 "top_pred_label": None,
+                 "top_pred_prob": None,
+             }
+
+             try:
+                 self.evaluator.run_masked_prediction(
+                     test_model, test_tokeniser, masked_sentence, label
+                 )
+             except Exception as e:
+                 raise RuntimeError("There was an issue with the model or tokeniser") from e
+
+             # -- LABEL PROB --
+             label_prob = self.evaluator.get_token_prob(label)
+             row_results["label_prob"] = label_prob
+
+             # -- ALTERNATIVE FORM --
+             if alternative_form:
+                 logging.info("----")
+                 logging.info(f"Label Form: {label}")
+                 logging.info(f"Alternative Form: {alternative_form}")
+                 logging.info("----")
+
+                 alt_form_prob = self.evaluator.get_token_prob(alternative_form)
+                 row_results["alternative_prob"] = alt_form_prob
+
+             # -- HIGHEST PROB --
+             top_pred_label, top_pred_prob = self.evaluator.get_top_pred()
+             row_results["top_pred_label"] = top_pred_label
+             row_results["top_pred_prob"] = top_pred_prob
+
+             results.append(row_results)
+
+             if row_limit:
+                 counter += 1
+                 if counter == row_limit:
+                     break
+
+         results_df = pd.DataFrame(results)
+         self.evaluation_results = results_df
+         return results_df
+
+     def visualise_syntactic_performance(
+         self,
+         filename: str,
+         results: pd.DataFrame,
+         target_x_label: str,
+         alt_x_label: str,
+         x_axis_label: str,
+         y_axis_label: str,
+         title: str,
+     ) -> None:
+         visualiser = Visualiser()
+         visualiser.visualise_slope(
+             filename,
+             results,
+             target_x_label,
+             alt_x_label,
+             x_axis_label,
+             y_axis_label,
+             title,
+         )
+
+
+ """
+ def store_results(
+     results_filename: str,
+     li_set_filename: str,
+     model_results: pd.DataFrame,
+     li_set: pd.DataFrame,
+ ):
+     try:
+         model_results.to_csv(base_dir / "output" / results_filename, index=False)
+         li_set.to_csv(base_dir / li_set_filename, index=True)
+
+         model_results["difference"] = (
+             model_results["label_prob"] - model_results["alternative_prob"]
+         )
+         model_results = model_results.sort_values("difference")
+         model_results.dropna().to_csv(
+             base_dir / "output" / f"filtered_{results_filename}", index=False
+         )
+     except Exception as e:
+         logging.error(f"Failed to output to CSV: {e}")
+         raise
+ """
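
Review note: an end-to-end sketch of the `Grewtse` pipeline as app.py wires it up; the treebank and model repo are this commit's defaults. The `get_morphological_features()` call mirrors the app's parse step, which normalises the `feats__` column prefixes that `to_syntactic_feature` expects to be stripped:

```python
from treetse.pipeline import Grewtse

pipeline = Grewtse()
pipeline.parse_treebank("spanish-test-sm.conllu")
pipeline.get_morphological_features()  # normalises feature column names
pipeline.generate_masked_dataset('V [upos="VERB"];', "V")
pipeline.generate_minimal_pairs({"mood": "Sub"}, {})
results = pipeline.evaluate_bert_mlm(
    "dccuchile/distilbert-base-spanish-uncased", row_limit=5
)
print(results[["masked_sentence", "label", "label_prob", "alternative_prob"]])
```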
treetse/preprocessing/__init__.py ADDED
File without changes
treetse/preprocessing/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (179 Bytes).
treetse/preprocessing/__pycache__/conllu_parser.cpython-312.pyc ADDED
Binary file (16.4 kB).
treetse/preprocessing/__pycache__/conllu_parser.cpython-313.pyc ADDED
Binary file (12.3 kB).
treetse/preprocessing/__pycache__/grew_dependencies.cpython-312.pyc ADDED
Binary file (1.05 kB).
treetse/preprocessing/__pycache__/reconstruction.cpython-312.pyc ADDED
Binary file (2.33 kB).
treetse/preprocessing/__pycache__/reconstruction.cpython-313.pyc ADDED
Binary file (2.42 kB).
treetse/preprocessing/conllu_parser.py ADDED
@@ -0,0 +1,402 @@
+ from treetse.preprocessing.grew_dependencies import match_dependencies
+ from treetse.preprocessing.reconstruction import Lexer
+ from conllu import parse_incr, Token
+ from pathlib import Path
+ from typing import Any
+ import pandas as pd
+ import numpy as np
+ import logging
+
+ def test_function():
+     return True
+
+ class ConlluParser:
+     def __init__(self) -> None:
+         self.li_feature_set: pd.DataFrame | None = None
+         self.masked_dataset: pd.DataFrame | None = None
+         self.exception_dataset: pd.DataFrame | None = None
+         self.lexer: Lexer = Lexer()
+
+     # todo: add error handling here
+     def parse_grew(
+         self, path: str, grew_query: str, grew_variable_to_mask: str, mask_token: str = "[MASK]"
+     ) -> tuple[pd.DataFrame, pd.DataFrame]:
+         self.li_feature_set = self._build_lexical_item_dataset(path)
+
+         masking_results = self._build_masked_dataset_grew(
+             path, grew_query, grew_variable_to_mask, mask_token
+         )
+         self.masked_dataset = masking_results["masked"]
+         self.exception_dataset = masking_results["exception"]
+
+         return self.masked_dataset, self.exception_dataset
+
+     # todo: add error handling here
+     def parse(
+         self, path: str, morphological_constraints: dict, universal_constraints: dict, mask_token: str = "[MASK]"
+     ) -> bool:
+         self.li_feature_set = self._build_lexical_item_dataset(path)
+
+         upos_constraint = universal_constraints["upos"] if "upos" in universal_constraints else None
+
+         masking_results = self._build_masked_dataset(
+             path, morphological_constraints, upos_constraint, mask_token
+         )
+         self.masked_dataset = masking_results["masked"]
+         self.exception_dataset = masking_results["exception"]
+
+         return True
+
+     def get_masked_dataset(self) -> pd.DataFrame:
+         return self.masked_dataset
+
+     def get_lexical_item_dataset(self) -> pd.DataFrame:
+         return self.li_feature_set
+
+     # this shouldn't be hard coded
+     def get_feature_names(self) -> list:
+         return self.li_feature_set.columns[4:].to_list()
+
+     # todo: add more safety
+     def get_features(self, sentence_id: str, token_id: int) -> dict:
+         return self.li_feature_set.loc[(sentence_id, token_id)][self.get_feature_names()].to_dict()
+
+     def get_lemma(self, sentence_id: str, token_id: int) -> str:
+         return self.li_feature_set.loc[(sentence_id, token_id)]["lemma"]
+
+     # todo: handle making sure that it is the exact same as the lemma
+     def to_syntactic_feature(self, sentence_id: str, token_id: int, alt_morph_constraints: dict, alt_universal_constraints: dict) -> str | None:
+         # distinguish morphological from universal features
+         # todo: find a better way to do this
+         # prefix = 'feats__'
+         prefix = ''
+         alt_morph_constraints = {prefix + key: value for key, value in alt_morph_constraints.items()}
+
+         token_features = self.get_features(sentence_id, token_id)
+
+         token_features.update(alt_morph_constraints)
+         token_features.update(alt_universal_constraints)
+         lexical_items = self.li_feature_set
+
+         # get only those items which have the same lemma
+         lemma = self.get_lemma(sentence_id, token_id)
+         lemma_mask = lexical_items['lemma'] == lemma
+         lexical_items = lexical_items[lemma_mask]
+         logging.info(f"Looking for form {lemma}")
+
+         for feat, value in token_features.items():
+             # ensure feature is a valid feature in feature set
+             if feat not in lexical_items.columns:
+                 raise KeyError(
+                     "Invalid feature provided to confound set: {}".format(feat)
+                 )
+
+             # slim the mask down using each feature
+             # interesting edge case: np.nan == np.nan returns false!
+             mask = (lexical_items[feat] == value) | (lexical_items[feat].isna() & pd.isna(value))
+             lexical_items = lexical_items[mask]
+
+         if len(lexical_items) > 0:
+             return lexical_items["form"].iloc[0]
+         else:
+             return None
+
+     def get_candidate_set(self, universal_constraints: dict, morph_constraints: dict) -> pd.DataFrame:
+         has_parsed_conllu = self.li_feature_set is not None
+         if not has_parsed_conllu:
+             raise ValueError("Please parse a ConLLU file first.")
+
+         morph_constraints = {f"feats__{k}": v for k, v in morph_constraints.items()}
+         are_morph_features_valid = all(
+             f in self.li_feature_set.columns for f in morph_constraints.keys()
+         )
+         are_universal_features_valid = all(
+             f in self.li_feature_set.columns for f in universal_constraints.keys()
+         )
+         if not are_morph_features_valid or not are_universal_features_valid:
+             raise KeyError(
+                 "Features provided for candidate set are not valid features in the dataset."
+             )
+
+         all_constraints = {**universal_constraints, **morph_constraints}
+         candidate_set = self._construct_candidate_set(
+             self.li_feature_set, all_constraints
+         )
+         return candidate_set
+
+     def _build_masked_dataset_grew(self, filepath: Path, grew_query: str, dependency_node: str,
+                                    mask_token, encoding: str = "utf-8") -> dict[str, pd.DataFrame]:
+         masked_dataset = []
+         exception_dataset = []
+
+         tokens_to_mask = match_dependencies(filepath, grew_query, dependency_node)
+
+         try:
+             with open(filepath, "r", encoding=encoding) as data_file:
+                 for sentence in parse_incr(data_file):
+                     logging.info(f"Processing sentence: {sentence.metadata['sent_id']}")
+
+                     sentence_id = sentence.metadata["sent_id"]
+                     sentence_text = sentence.metadata["text"]
+                     if sentence_id in tokens_to_mask:
+                         for i in range(len(sentence)):
+                             sentence[i]["index"] = i
+
+                         token_to_mask_id = tokens_to_mask[sentence_id]
+
+                         try:
+                             t_match = [tok for tok in sentence if tok.get("id") == token_to_mask_id][0]
+                             t_match_form = t_match["form"]
+                             t_match_index = t_match["index"]
+                             sentence_as_str_list = [t["form"] for t in sentence]
+                         except (IndexError, KeyError):
+                             logging.info("There was a mismatch for the GREW-based ID and the Conllu ID.")
+                             exception_dataset.append(
+                                 {
+                                     "sentence_id": sentence_id,
+                                     "match_id": None,
+                                     "all_tokens": None,
+                                     "match_token": None,
+                                     "original_text": sentence_text,
+                                 }
+                             )
+                             continue
+
+                         try:
+                             matched_token_start_index = self.lexer.recursive_match_token(
+                                 sentence_text,  # the original string
+                                 sentence_as_str_list.copy(),  # the string as a list of tokens
+                                 t_match_index,  # the index of the token to be replaced
+                                 [
+                                     "_",
+                                     " ",
+                                 ],  # todo: skip lines where we encounter unaccounted-for tokens
+                             )
+                         except ValueError:
+                             logging.info("Token not found. Saving as exception.")
+                             exception_dataset.append(
+                                 {
+                                     "sentence_id": sentence_id,
+                                     "match_id": token_to_mask_id,
+                                     "all_tokens": sentence_as_str_list,
+                                     "match_token": t_match_form,
+                                     "original_text": sentence_text,
+                                 }
+                             )
+                             continue
+
+                         # let's replace the matched token with a MASK token
+                         masked_sentence = self.lexer.perform_token_surgery(
+                             sentence_text,
+                             t_match_form,
+                             mask_token,
+                             matched_token_start_index,
+                         )
+
+                         # the sentence ID and match ID are together a primary key
+                         masked_dataset.append(
+                             {
+                                 "sentence_id": sentence_id,
+                                 "match_id": token_to_mask_id,
+                                 "all_tokens": sentence_as_str_list,
+                                 "num_tokens": len(sentence_as_str_list),
+                                 "match_token": t_match_form,
+                                 "original_text": sentence_text,
+                                 "masked_text": masked_sentence,
+                             }
+                         )
+         except FileNotFoundError:
+             logging.error(f"Error: The file '{filepath}' was not found.")
+
+         masked_dataset_df = pd.DataFrame(masked_dataset)
+         exception_dataset_df = pd.DataFrame(exception_dataset)
+
+         return {"masked": masked_dataset_df, "exception": exception_dataset_df}
+
+     def _build_masked_dataset(
+         self, filepath: str, morph_constraints: dict, upos_constraint: str | None, mask_token: str, encoding: str = "utf-8"
+     ) -> dict[str, pd.DataFrame]:
+         masked_dataset = []
+         exception_dataset = []
+
+         try:
+             with open(filepath, "r", encoding=encoding) as data_file:
+                 constraints_kwargs = {f"feats__{k.capitalize()}": v for k, v in morph_constraints.items()}
+
+                 for sentence in parse_incr(data_file):
+                     logging.info(f"Processing sentence: {sentence.metadata['sent_id']}")
+
+                     # MORPHOLOGICAL FILTER
+                     token_constraint_matches = sentence.filter(**constraints_kwargs)
+
+                     # UNIVERSAL POS FILTER
+                     # note: filter the morphological matches, not the whole
+                     # sentence, so both constraints apply together
+                     if upos_constraint:
+                         token_constraint_matches = token_constraint_matches.filter(upos=upos_constraint)
+
+                     if token_constraint_matches:
+                         for i in range(len(sentence)):
+                             sentence[i]["index"] = i
+
+                         # sentence_text = " ".join(token["form"] for token in sentence)
+                         sentence_text = sentence.metadata["text"]
+                         sentence_id = sentence.metadata["sent_id"]
+
+                         matches = [t["form"] for t in token_constraint_matches]
+                         match_indices = [t["index"] for t in token_constraint_matches]
+
+                         # iterate over each match in the sentence
+                         for t_match_index, t_match in zip(match_indices, matches):
+                             # we want to create one sentence entry per example,
+                             # so if we have two subjunctives in one sentence,
+                             # there will be two test sentences
+
+                             # at what point in the string does the matched token start?
+                             sentence_as_str_list = [t["form"] for t in sentence]
+
+                             try:
+                                 matched_token_start_index = self.lexer.recursive_match_token(
+                                     sentence_text,
+                                     sentence_as_str_list.copy(),
+                                     t_match_index,
+                                     [
+                                         "_",
+                                         " ",
+                                     ],  # todo: skip lines where we encounter unaccounted-for tokens
+                                 )
+                             except ValueError:
+                                 logging.info("Token not found. Saving as exception.")
+                                 exception_dataset.append(
+                                     {
+                                         "sentence_id": sentence_id,
+                                         "match_id": t_match_index,
+                                         "all_tokens": sentence_as_str_list,
+                                         "match_token": t_match,
+                                         "original_text": sentence_text,
+                                     }
+                                 )
+                                 continue
+
+                             # let's replace the matched token with a MASK token
+                             masked_sentence = self.lexer.perform_token_surgery(
+                                 sentence_text,
+                                 t_match,
+                                 mask_token,
+                                 matched_token_start_index,
+                             )
+
+                             # the sentence ID and match ID are together a primary key
+                             masked_dataset.append(
+                                 {
+                                     "sentence_id": sentence_id,
+                                     "match_id": t_match_index,
+                                     "all_tokens": sentence_as_str_list,
+                                     "num_tokens": len(sentence_as_str_list),
+                                     "match_token": t_match,
+                                     "original_text": sentence_text,
+                                     "masked_text": masked_sentence,
+                                 }
+                             )
+         except FileNotFoundError:
+             logging.error(f"Error: The file '{filepath}' was not found.")
+
+         masked_dataset_df = pd.DataFrame(masked_dataset)
+         exception_dataset_df = pd.DataFrame(exception_dataset)
+
+         return {"masked": masked_dataset_df, "exception": exception_dataset_df}
+
+     def _is_valid_token(self, token: Token) -> bool:
+         punctuation = [".", ",", "!", "?", "*"]
+
+         # skip multiword tokens, malformed entries and punctuation
+         is_punctuation = token.get("form") in punctuation
+         is_valid_type = isinstance(token, dict)
+         has_valid_id = isinstance(token.get("id"), int)
+         return is_valid_type and has_valid_id and not is_punctuation
+
+     def _build_token_row(self, token: Token, sentence_id: str) -> dict[str, Any]:
+         # get all token features such as Person, Mood, etc.
+         feats = token.get("feats") or {}
+
+         row = {
+             "sentence_id": sentence_id,
+             "token_id": token.get("id") - 1,  # IDs are reduced by one to start at 0
+             "form": token.get("form"),
+             "lemma": token.get("lemma"),
+             "upos": token.get("upos"),
+             "xpos": token.get("xpos"),
+         }
+
+         # add each morphological feature as a column
+         for feat_name, feat_value in feats.items():
+             row["feats__" + feat_name.lower()] = feat_value
+
+         return row
+
+     def _build_lexical_item_dataset(self, conllu_path: str) -> pd.DataFrame:
+         rows = []
+
+         with open(conllu_path, "r", encoding="utf-8") as f:
+             for i, tokenlist in enumerate(parse_incr(f)):
+                 # get the sentence ID in the dataset
+                 sent_id = tokenlist.metadata["sent_id"]
+                 logging.info(f"Building LI Set For Sentence: {sent_id}")
+
+                 # iterate over each token
+                 for token in tokenlist:
+                     # check if it's worth saving to our lexical item dataset
+                     is_valid_token = self._is_valid_token(token)
+                     if not is_valid_token:
+                         continue
+
+                     # from the token object create a dict and append
+                     row = self._build_token_row(token, sent_id)
+                     rows.append(row)
+
+         lexical_item_df = pd.DataFrame(rows)
+
+         # make sure our nan values are interpreted as such
+         lexical_item_df.replace("nan", np.nan, inplace=True)
+
+         # create the (Sentence ID, Token ID) primary key
+         lexical_item_df.set_index(["sentence_id", "token_id"], inplace=True)
+
+         self.li_feature_set = lexical_item_df
+
+         return lexical_item_df
+
+     def _construct_candidate_set(
+         self, li_feature_set: pd.DataFrame, target_features: dict
+     ) -> pd.DataFrame:
+         """
+         -- Candidate Set --
+         Constructs a list of words which have the same feature set as the
+         target features which are passed as an argument.
+         """
+         # optionally restrict search to a certain type of lexical item
+         subset = li_feature_set
+
+         # continuously filter the dataframe so as to be left only with those
+         # lexical items which match the target features, including NaN-valued ones
+         for feat, value in target_features.items():
+             # ensure feature is a valid feature in feature set
+             if feat not in subset.columns:
+                 raise KeyError(
+                     "Invalid feature provided to confound set: {}".format(feat)
+                 )
+
+             # slim the mask down using each feature
+             # interesting edge case: np.nan == np.nan returns false!
+             mask = (subset[feat] == value) | (subset[feat].isna() & pd.isna(value))
+             subset = subset[mask]
+
+         return subset
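
Review note: a small illustration of the candidate-set filtering above, using the Spanish sample treebank from this commit; the printed forms are indicative of what the sample data contains:

```python
from treetse.preprocessing.conllu_parser import ConlluParser

parser = ConlluParser()
parser._build_lexical_item_dataset("spanish-test-sm.conllu")

# all singular feminine nouns in the treebank
candidates = parser.get_candidate_set(
    {"upos": "NOUN"}, {"gender": "Fem", "number": "Sing"}
)
print(candidates["form"].tolist())  # e.g. ['enhorabuena', 'falta', 'bandera']
```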
treetse/preprocessing/grew_dependencies.py ADDED
@@ -0,0 +1,20 @@
+ from grewpy import Corpus, Request, set_config
+ from pathlib import Path
+
+
+ def match_dependencies(
+     corpus_path: Path, grew_query: str, dependency_node: str
+ ) -> dict:
+     set_config("sud")  # alternatives include "ud" and "basic"
+     # run the GREW request on the corpus
+     corpus = Corpus(str(corpus_path))
+     request = Request().pattern(grew_query)
+     occurrences = corpus.search(request)
+
+     # map each matched sentence to the id of the queried node;
+     # note: only the last match per sentence is kept
+     dep_matches = {}
+     for occ in occurrences:
+         sent_id = occ["sent_id"]
+         object_node_id = int(occ["matching"]["nodes"][dependency_node])
+         dep_matches[sent_id] = object_node_id
+     return dep_matches
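
Review note: a sketch of a call against the sample treebank; the returned mapping is `sent_id -> node id` following CoNLL-U token numbering (output shape inferred from the code above):

```python
from pathlib import Path
from treetse.preprocessing.grew_dependencies import match_dependencies

matches = match_dependencies(
    Path("spanish-test-sm.conllu"), 'V [upos="VERB"];', "V"
)
print(matches)  # e.g. {'3LB-CAST-d2-12-s5': 1, ...}
```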
treetse/preprocessing/reconstruction.py ADDED
@@ -0,0 +1,78 @@
+ class Lexer:
+     def __init__(self) -> None:
+         pass
+
+     def perform_token_surgery(
+         self,
+         sentence: str,
+         original_token: str,
+         replacement_token: str,
+         start_index: int,
+     ) -> str:
+         # splice the replacement token into the sentence at the given offset
+         t_len = len(original_token)
+
+         return (
+             sentence[:start_index] + replacement_token + sentence[start_index + t_len :]
+         )
+
+     def recursive_match_token(
+         self,
+         full_sentence: str,
+         token_list: list[str],
+         token_list_mask_index: int,
+         skippable_tokens: list[str],
+     ) -> int:
+         # ensure we can retrieve another token
+         n_remaining_tokens = len(token_list)
+         if n_remaining_tokens == 0:
+             raise ValueError(
+                 "Mask index not reached but token list has been iterated for sentence: {}".format(
+                     full_sentence
+                 )
+             )
+         t = token_list[0]
+
+         # returns the index of the first occurrence of the token t
+         match_index = full_sentence.find(t)
+         is_match_found = match_index != -1
+         has_reached_mask_token = token_list_mask_index == 0
+
+         # BASE CASE: this is the target token; its offset is the answer
+         if has_reached_mask_token and is_match_found:
+             return match_index
+         # RECURSIVE CASE: consume one token and accumulate the offset
+         elif is_match_found:
+             sliced_sentence = full_sentence[match_index + len(t) :]
+             token_list.pop(0)
+
+             return (
+                 match_index
+                 + len(t)
+                 + self.recursive_match_token(
+                     sliced_sentence,
+                     token_list,
+                     token_list_mask_index - 1,
+                     skippable_tokens,
+                 )
+             )
+         else:
+             # no match found, is t irrelevant?
+             if t in skippable_tokens:
+                 # need to watch out with the slicing here; tests are important
+                 # (assumes skippable tokens are single characters, e.g. "_" or " ")
+                 sliced_sentence = full_sentence[len(t) - 1 :]
+                 token_list.pop(0)
+                 return self.recursive_match_token(
+                     sliced_sentence,
+                     token_list,
+                     token_list_mask_index - 1,
+                     skippable_tokens,
+                 )
+             else:
+                 raise ValueError(
+                     "Token not found in string nor has it been specified as skippable: {}".format(
+                         t
+                     )
+                 )
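
Review note: a worked example of the two methods together. `recursive_match_token` resolves the character offset of the token at list index 5 (0-based), and `perform_token_surgery` splices the mask in at that offset:

```python
from treetse.preprocessing.reconstruction import Lexer

lexer = Lexer()
sentence = "La bandera, que no falte."
tokens = ["La", "bandera", ",", "que", "no", "falte", "."]

# offset of "falte" in the raw sentence string (here: 19)
start = lexer.recursive_match_token(sentence, tokens.copy(), 5, ["_", " "])
print(lexer.perform_token_surgery(sentence, "falte", "[MASK]", start))
# -> "La bandera, que no [MASK]."
```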
treetse/visualise/__pycache__/visualiser.cpython-312.pyc ADDED
Binary file (4.78 kB).
treetse/visualise/__pycache__/visualiser.cpython-313.pyc ADDED
Binary file (4.4 kB).
treetse/visualise/visualiser.py ADDED
@@ -0,0 +1,114 @@
+ import pandas as pd
+ from plotnine import labs, theme, theme_bw, guides, position_nudge, aes, geom_violin, geom_boxplot, geom_line, geom_jitter, scale_x_discrete, ggplot
+ from pathlib import Path
+ import math
+
+
+ class Visualiser:
+     def __init__(self) -> None:
+         self.data = None
+
+     def load_dataset(self, results: pd.DataFrame) -> None:
+         self.data = results
+
+     def visualise_slope(
+         self,
+         path: Path,
+         results: pd.DataFrame,
+         target_x_label: str,
+         alt_x_label: str,
+         x_axis_label: str,
+         y_axis_label: str,
+         title: str,
+     ) -> None:
+         lsize = 0.65
+         fill_alpha = 0.7
+
+         # x-axis: the two minimal-pair conditions (e.g. Acc, Gen)
+         # y-axis: model confidence
+         # keep only rows for which an alternative form was found
+         filtered_df = results[
+             results["alternative"].notna() & (results["alternative"].str.strip() != "")
+         ].copy()
+         print("Number of filtered results: ", len(filtered_df))
+
+         filtered_df["subject_id"] = filtered_df.index
+
+         # melt the dataframe into long format: one row per (sentence, condition)
+         df_long = pd.melt(
+             filtered_df,
+             id_vars=["subject_id", "num_tokens"],
+             value_vars=["label_prob", "alternative_prob"],
+             var_name="source",
+             value_name="log_prob",
+         )
+
+         # map source to fixed x-axis labels
+         df_long["x_label"] = df_long["source"].map(
+             {"label_prob": target_x_label, "alternative_prob": alt_x_label}
+         )
+
+         def surprisal(p: float) -> float:
+             return -math.log2(p)
+
+         def confidence(p: float) -> float:
+             return math.log2(p)
+
+         # note: despite the column name, the plotted quantity is
+         # confidence (log2 p); use `surprisal` instead for -log2 p
+         df_long["surprisal"] = df_long["log_prob"].apply(confidence)
+
+         p = (
+             ggplot(df_long, aes(x="x_label", y="surprisal", fill="x_label"))
+             + scale_x_discrete(limits=[target_x_label, alt_x_label])
+             + geom_jitter(
+                 aes(color="x_label", size="num_tokens"), width=0.01, alpha=0.7
+             )
+             # + geom_text(aes(label='label'), nudge_y=0.1)
+             + geom_line(aes(group="subject_id"), color="gray", alpha=0.7, size=0.2)
+             + geom_boxplot(
+                 df_long[df_long["x_label"] == target_x_label],
+                 aes(x="x_label", y="surprisal", group="x_label"),
+                 width=0.2,
+                 alpha=0.4,
+                 size=0.6,
+                 outlier_shape=None,
+                 show_legend=False,
+                 position=position_nudge(x=-0.2),
+             )
+             + geom_boxplot(
+                 df_long[df_long["x_label"] == alt_x_label],
+                 aes(x="x_label", y="surprisal", group="x_label"),
+                 width=0.2,
+                 alpha=0.4,
+                 size=0.6,
+                 outlier_shape=None,
+                 show_legend=False,
+                 position=position_nudge(x=0.2),
+             )
+             + geom_violin(
+                 df_long[df_long["x_label"] == target_x_label],
+                 aes(x="x_label", y="surprisal", group="x_label"),
+                 position=position_nudge(x=-0.4),
+                 style="left-right",
+                 alpha=fill_alpha,
+                 size=lsize,
+             )
+             + geom_violin(
+                 df_long[df_long["x_label"] == alt_x_label],
+                 aes(x="x_label", y="surprisal", group="x_label"),
+                 position=position_nudge(x=0.4),
+                 style="right-left",
+                 alpha=fill_alpha,
+                 size=lsize,
+             )
+             + guides(fill=False)
+             + theme_bw()
+             + theme(figure_size=(8, 4), legend_position="none")
+             + labs(x=x_axis_label, y=y_axis_label, title=title)
+         )
+         p.save(path, width=14, height=8, dpi=300)
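
Review note: a sketch of how app.py drives this method; the labels are illustrative and `results_df` stands for the dataframe returned by `evaluate_bert_mlm` above:

```python
from treetse.visualise.visualiser import Visualiser

vis = Visualiser()
vis.visualise_slope(
    "vis.png", results_df,
    "Subjunctive", "Indicative",            # the two x-axis conditions
    "Mood of Spanish Verbs", "Confidence",  # axis labels
    "Syntactic Performance of DistilBERT",  # plot title
)
```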