HoneyTian commited on
Commit
b328553
·
1 Parent(s): 4b9750d
examples/tutorial_pyltp/README.md CHANGED
@@ -2,9 +2,14 @@
2
 
3
  ```text
4
  工程路径:
 
5
  https://github.com/HuangFJ/pyltp
6
 
7
  模型文件:
8
  https://ltp.ai/download.html
9
 
10
- ```
 
 
 
 
 
2
 
3
  ```text
4
  工程路径:
5
+ https://github.com/HIT-SCIR/pyltp
6
  https://github.com/HuangFJ/pyltp
7
 
8
  模型文件:
9
  https://ltp.ai/download.html
10
 
11
+ 参考信息:
12
+ https://ltp.readthedocs.io/zh-cn/v3.3.0/appendix.html
13
+ https://blog.csdn.net/weixin_43758551/article/details/104266953
14
+
15
+ ```
examples/tutorial_pyltp/srl.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+
6
+ from pyltp import Parser, Postagger, Segmentor, SementicRoleLabeller
7
+
8
+ from project_settings import project_path
9
+
10
+
11
+ def get_args():
12
+ parser = argparse.ArgumentParser()
13
+
14
+ parser.add_argument(
15
+ "--text",
16
+ default="元芳你怎么看?我就趴窗口上看呗!",
17
+ type=str
18
+ )
19
+ parser.add_argument(
20
+ "--ltp_data_dir",
21
+ default=(project_path / "data/pyltp_models/ltp_data_v3.4.0").as_posix(),
22
+ type=str
23
+ )
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+
28
+ def main():
29
+ args = get_args()
30
+
31
+ cws_model_path = os.path.join(args.ltp_data_dir, 'cws.model')
32
+ pos_model_path = os.path.join(args.ltp_data_dir, 'pos.model')
33
+ parser_model_path = os.path.join(args.ltp_data_dir, 'parser.model')
34
+ srl_model_path = os.path.join(args.ltp_data_dir, 'pisrl_win.model')
35
+
36
+ segmentor = Segmentor(cws_model_path)
37
+ pos_tagger = Postagger(pos_model_path)
38
+ parser = Parser(parser_model_path)
39
+ srl_labeler = SementicRoleLabeller(srl_model_path)
40
+
41
+ words = segmentor.segment(args.text)
42
+ postags = pos_tagger.postag(words)
43
+ arcs = parser.parse(words, postags)
44
+ roles = srl_labeler.label(words, postags, arcs)
45
+
46
+ for role in roles:
47
+ index = role[0]
48
+ role_ = [("INDEX", (index, index))] + role[1]
49
+ role_ = list(sorted(role_, key=lambda x: x[1][1]))
50
+
51
+ row = ""
52
+ for r in role_:
53
+ name = r[0]
54
+ start = r[1][0]
55
+ end = r[1][1]
56
+ arg_text = "".join(words[start:end+1])
57
+ row += f"{arg_text}/{name}\t"
58
+ print(row)
59
+
60
+ segmentor.release()
61
+ pos_tagger.release()
62
+ parser.release()
63
+ srl_labeler.release()
64
+
65
+ return
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
main.py CHANGED
@@ -21,7 +21,16 @@ log.setup(log_directory=log_directory)
21
  import gradio as gr
22
 
23
  from toolbox.os.command import Command
24
- from toolbox.part_of_speech.part_of_speech import language_to_engines, engine_to_tagger, pos_tag
 
 
 
 
 
 
 
 
 
25
 
26
  main_logger = logging.getLogger("main")
27
 
@@ -34,6 +43,11 @@ def get_args():
34
  default=(project_path / "pos_examples.json").as_posix(),
35
  type=str
36
  )
 
 
 
 
 
37
  args = parser.parse_args()
38
  return args
39
 
@@ -44,7 +58,7 @@ def run_pos_tag(text: str, language: str, engine: str) -> str:
44
 
45
  begin = time.time()
46
 
47
- words, postags = pos_tag(text, engine)
48
  result = ""
49
  for word, postag in zip(words, postags):
50
  row = f"{word}/{postag}"
@@ -58,6 +72,33 @@ def run_pos_tag(text: str, language: str, engine: str) -> str:
58
  return result
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def shell(cmd: str):
62
  return Command.popen(cmd)
63
 
@@ -67,16 +108,8 @@ def main():
67
 
68
  with open(args.pos_example_json_file, "r", encoding="utf-8") as f:
69
  pos_examples: list = json.load(f)
70
-
71
- def get_languages_by_engine(engine: str):
72
- language_list = list()
73
- for k, v in language_to_engines.items():
74
- if engine in v:
75
- language_list.append(k)
76
- return gr.Dropdown(choices=language_list, value=language_list[0], label="language")
77
-
78
- pos_language_choices = list(language_to_engines.keys())
79
- pos_engine_choices = list(engine_to_tagger.keys())
80
 
81
  # blocks
82
  with gr.Blocks() as blocks:
@@ -84,6 +117,16 @@ def main():
84
 
85
  with gr.Tabs():
86
  with gr.TabItem("part of speech"):
 
 
 
 
 
 
 
 
 
 
87
  pos_text = gr.Textbox(value="学而时习之,不亦悦乎。", lines=4, max_lines=50, label="text")
88
 
89
  with gr.Row():
@@ -97,7 +140,7 @@ def main():
97
  )
98
 
99
  pos_engine.change(
100
- get_languages_by_engine,
101
  inputs=[pos_engine],
102
  outputs=[pos_language],
103
  )
@@ -116,6 +159,47 @@ def main():
116
  fn=run_pos_tag,
117
  )
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  with gr.TabItem("shell"):
120
  shell_text = gr.Textbox(label="cmd")
121
  shell_button = gr.Button("run")
 
21
  import gradio as gr
22
 
23
  from toolbox.os.command import Command
24
+ from toolbox.part_of_speech.part_of_speech import (
25
+ language_to_engines as pos_language_to_engines,
26
+ engine_to_tagger as pos_engine_to_tagger,
27
+ pos_tag
28
+ )
29
+ from toolbox.sementic_role_labeling.sementic_role_labeling import (
30
+ language_to_engines as srl_language_to_engines,
31
+ engine_to_tagger as srl_engine_to_tagger,
32
+ srl
33
+ )
34
 
35
  main_logger = logging.getLogger("main")
36
 
 
43
  default=(project_path / "pos_examples.json").as_posix(),
44
  type=str
45
  )
46
+ parser.add_argument(
47
+ "--srl_example_json_file",
48
+ default=(project_path / "srl_examples.json").as_posix(),
49
+ type=str
50
+ )
51
  args = parser.parse_args()
52
  return args
53
 
 
58
 
59
  begin = time.time()
60
 
61
+ words, postags = pos_tag(text, language, engine)
62
  result = ""
63
  for word, postag in zip(words, postags):
64
  row = f"{word}/{postag}"
 
72
  return result
73
 
74
 
75
+ def run_srl(text: str, language: str, engine: str) -> str:
76
+ try:
77
+ main_logger.info(f"srl started. text: {text}, language: {language}, engine: {engine}")
78
+
79
+ begin = time.time()
80
+
81
+ words, postags, arcs, roles = srl(text, language, engine)
82
+
83
+ result = ""
84
+ for role in roles:
85
+ row = ""
86
+ for r in role:
87
+ name = r[0]
88
+ start = r[1][0]
89
+ end = r[1][1]
90
+ arg_text = "".join(words[start:end+1])
91
+ row += f"{arg_text}/{name}\t"
92
+ result += f"{row}\n"
93
+
94
+ time_cost = time.time() - begin
95
+ result += f"\n\ntime_cost: {round(time_cost, 4)}"
96
+ return result
97
+ except Exception as e:
98
+ result = f"{type(e)}\n{str(e)}"
99
+ return result
100
+
101
+
102
  def shell(cmd: str):
103
  return Command.popen(cmd)
104
 
 
108
 
109
  with open(args.pos_example_json_file, "r", encoding="utf-8") as f:
110
  pos_examples: list = json.load(f)
111
+ with open(args.srl_example_json_file, "r", encoding="utf-8") as f:
112
+ srl_examples: list = json.load(f)
 
 
 
 
 
 
 
 
113
 
114
  # blocks
115
  with gr.Blocks() as blocks:
 
117
 
118
  with gr.Tabs():
119
  with gr.TabItem("part of speech"):
120
+ def pos_get_languages_by_engine(engine: str):
121
+ language_list = list()
122
+ for k, v in pos_language_to_engines.items():
123
+ if engine in v:
124
+ language_list.append(k)
125
+ return gr.Dropdown(choices=language_list, value=language_list[0], label="language")
126
+
127
+ pos_language_choices = list(pos_language_to_engines.keys())
128
+ pos_engine_choices = list(pos_engine_to_tagger.keys())
129
+
130
  pos_text = gr.Textbox(value="学而时习之,不亦悦乎。", lines=4, max_lines=50, label="text")
131
 
132
  with gr.Row():
 
140
  )
141
 
142
  pos_engine.change(
143
+ pos_get_languages_by_engine,
144
  inputs=[pos_engine],
145
  outputs=[pos_language],
146
  )
 
159
  fn=run_pos_tag,
160
  )
161
 
162
+ with gr.TabItem("srl"):
163
+ def srl_get_languages_by_engine(engine: str):
164
+ language_list = list()
165
+ for k, v in pos_language_to_engines.items():
166
+ if engine in v:
167
+ language_list.append(k)
168
+ return gr.Dropdown(choices=language_list, value=language_list[0], label="language")
169
+
170
+ srl_language_choices = list(srl_language_to_engines.keys())
171
+ srl_engine_choices = list(srl_engine_to_tagger.keys())
172
+
173
+ srl_text = gr.Textbox(value="学而时习之,不亦悦乎。", lines=4, max_lines=50, label="text")
174
+
175
+ with gr.Row():
176
+ srl_language = gr.Dropdown(
177
+ choices=srl_language_choices, value=srl_language_choices[0],
178
+ label="language"
179
+ )
180
+ srl_engine = gr.Dropdown(
181
+ choices=srl_engine_choices, value=srl_engine_choices[0],
182
+ label="engine"
183
+ )
184
+ srl_engine.change(
185
+ srl_get_languages_by_engine,
186
+ inputs=[srl_engine],
187
+ outputs=[srl_language],
188
+ )
189
+ srl_output = gr.Textbox(lines=4, max_lines=50, label="output")
190
+ srl_button = gr.Button(value="pos_tag", variant="primary")
191
+ srl_button.click(
192
+ run_srl,
193
+ inputs=[srl_text, srl_language, srl_engine],
194
+ outputs=[srl_output],
195
+ )
196
+ gr.Examples(
197
+ examples=srl_examples,
198
+ inputs=[srl_text, srl_language, srl_engine],
199
+ outputs=[srl_output],
200
+ fn=run_srl,
201
+ )
202
+
203
  with gr.TabItem("shell"):
204
  shell_text = gr.Textbox(label="cmd")
205
  shell_button = gr.Button("run")
srl_examples.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [
2
+ ["元芳你怎么看?我就趴窗口上看呗!", "chinese", "pyltp"]
3
+ ]
toolbox/part_of_speech/pyltp_pos_tagger.py CHANGED
@@ -9,6 +9,12 @@ ltp_data_dir = os.environ.get("LTP_DATA_DIR")
9
  from pyltp import Postagger, Segmentor
10
 
11
 
 
 
 
 
 
 
12
  @lru_cache(maxsize=5)
13
  def get_pyltp_pos_tagger():
14
  global ltp_data_dir
 
9
  from pyltp import Postagger, Segmentor
10
 
11
 
12
+ pos_name_amp = {
13
+ "nh": "人名",
14
+
15
+ }
16
+
17
+
18
  @lru_cache(maxsize=5)
19
  def get_pyltp_pos_tagger():
20
  global ltp_data_dir
toolbox/sementic_role_labeling/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/sementic_role_labeling/pyltp_srl.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from functools import lru_cache
4
+ import os
5
+ from typing import List, Union
6
+
7
+ ltp_data_dir = os.environ.get("LTP_DATA_DIR")
8
+
9
+ from pyltp import Parser, Postagger, Segmentor, SementicRoleLabeller
10
+
11
+
12
+ @lru_cache(maxsize=5)
13
+ def get_pyltp_srl_tagger():
14
+ global ltp_data_dir
15
+
16
+ cws_model_path = os.path.join(ltp_data_dir, 'cws.model')
17
+ pos_model_path = os.path.join(ltp_data_dir, 'pos.model')
18
+ parser_model_path = os.path.join(ltp_data_dir, 'parser.model')
19
+ srl_model_path = os.path.join(ltp_data_dir, 'pisrl_win.model')
20
+
21
+ segmentor = Segmentor(cws_model_path)
22
+ pos_tagger = Postagger(pos_model_path)
23
+ parser = Parser(parser_model_path)
24
+ srl_labeler = SementicRoleLabeller(srl_model_path)
25
+
26
+ return segmentor, pos_tagger, parser, srl_labeler
27
+
28
+
29
+ def pyltp_srl(text: str, language: str) -> list:
30
+ segmentor, pos_tagger, parser, srl_labeler = get_pyltp_srl_tagger()
31
+
32
+ words = segmentor.segment(text)
33
+ postags = pos_tagger.postag(words)
34
+ arcs = parser.parse(words, postags)
35
+ roles = srl_labeler.label(words, postags, arcs)
36
+
37
+ roles_ = list()
38
+ for role in roles:
39
+ index = role[0]
40
+ role_ = [("INDEX", (index, index))] + role[1]
41
+ role_ = list(sorted(role_, key=lambda x: x[1][1]))
42
+ roles_.append(role_)
43
+
44
+ return words, postags, arcs, roles_
45
+
46
+
47
+ if __name__ == "__main__":
48
+ pass
toolbox/sementic_role_labeling/sementic_role_labeling.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable, Dict, List, Tuple, Union
4
+
5
+ from toolbox.sementic_role_labeling.pyltp_srl import pyltp_srl
6
+
7
+
8
+ language_to_engines = {
9
+ "chinese": ["pyltp"]
10
+ }
11
+
12
+
13
+ engine_to_tagger: Dict[str, Callable] = {
14
+ "pyltp": pyltp_srl
15
+ }
16
+
17
+
18
+ def srl(text: str, language: str, engine: str):
19
+ srl_tagger = engine_to_tagger.get(engine)
20
+ if srl_tagger is None:
21
+ raise AssertionError(f"engine {engine} not supported.")
22
+
23
+ words, postags, arcs, roles = srl_tagger(text, language)
24
+ return words, postags, arcs, roles
25
+
26
+
27
+ if __name__ == "__main__":
28
+ pass