Update model
Browse files- model_consts.py +2 -2
- segmenter.ckpt +2 -2
- train.py +1 -1
- utils.py +60 -64
model_consts.py
CHANGED
|
@@ -4,6 +4,6 @@ else:
|
|
| 4 |
from .utils import get_upenn_tags_dict
|
| 5 |
|
| 6 |
input_size = len(get_upenn_tags_dict())
|
| 7 |
-
embedding_size =
|
| 8 |
-
hidden_size =
|
| 9 |
num_layers = 2
|
|
|
|
| 4 |
from .utils import get_upenn_tags_dict
|
| 5 |
|
| 6 |
input_size = len(get_upenn_tags_dict())
|
| 7 |
+
embedding_size = 256
|
| 8 |
+
hidden_size = 256
|
| 9 |
num_layers = 2
|
segmenter.ckpt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8e6209584d0021684bb3a09ec1b717843f3086dfcc6411c57276f743f8e62fa
|
| 3 |
+
size 10584544
|
train.py
CHANGED
|
@@ -26,6 +26,6 @@ if __name__ == "__main__":
|
|
| 26 |
|
| 27 |
model.to(device)
|
| 28 |
|
| 29 |
-
train_bidirlstm_embedding_model(model, dataset, num_epochs=
|
| 30 |
|
| 31 |
torch.save(model.state_dict(), "segmenter.ckpt")
|
|
|
|
| 26 |
|
| 27 |
model.to(device)
|
| 28 |
|
| 29 |
+
train_bidirlstm_embedding_model(model, dataset, num_epochs=150, batch_size=2)
|
| 30 |
|
| 31 |
torch.save(model.state_dict(), "segmenter.ckpt")
|
utils.py
CHANGED
|
@@ -4,6 +4,64 @@ from stable_whisper.result import WordTiming
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
def bind_wordtimings_to_tags(wt: list[WordTiming]):
|
| 8 |
raw_words = [w.word for w in wt]
|
| 9 |
|
|
@@ -16,6 +74,7 @@ def bind_wordtimings_to_tags(wt: list[WordTiming]):
|
|
| 16 |
tokens_wordtiming_map.append(len(tokens_word))
|
| 17 |
|
| 18 |
tagged_words = nltk.pos_tag(tokenized_raw_words)
|
|
|
|
| 19 |
|
| 20 |
grouped_tags = []
|
| 21 |
|
|
@@ -49,6 +108,7 @@ def tag_training_data(filename: str):
|
|
| 49 |
|
| 50 |
tokenized_full_text = nltk.word_tokenize(full_text)
|
| 51 |
tagged_full_text = nltk.pos_tag(tokenized_full_text)
|
|
|
|
| 52 |
|
| 53 |
tagged_full_text_copy = tagged_full_text
|
| 54 |
|
|
@@ -75,70 +135,6 @@ def tag_training_data(filename: str):
|
|
| 75 |
|
| 76 |
return reconstructed_tags
|
| 77 |
|
| 78 |
-
def get_upenn_tags_dict():
|
| 79 |
-
# tagger = PerceptronTagger()
|
| 80 |
-
|
| 81 |
-
# tags = list(tagger.tagdict.values())
|
| 82 |
-
|
| 83 |
-
# # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
|
| 84 |
-
# tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
|
| 85 |
-
# tags = list(set(tags))
|
| 86 |
-
# tags.sort()
|
| 87 |
-
# tags.append("BREAK")
|
| 88 |
-
|
| 89 |
-
# tags_dict = dict()
|
| 90 |
-
|
| 91 |
-
# for index, tag in enumerate(tags):
|
| 92 |
-
# tags_dict[tag] = index
|
| 93 |
-
|
| 94 |
-
return {'#': 0,
|
| 95 |
-
'$': 1,
|
| 96 |
-
"''": 2,
|
| 97 |
-
'(': 3,
|
| 98 |
-
')': 4,
|
| 99 |
-
',': 5,
|
| 100 |
-
'.': 6,
|
| 101 |
-
':': 7,
|
| 102 |
-
'CC': 8,
|
| 103 |
-
'CD': 9,
|
| 104 |
-
'DT': 10,
|
| 105 |
-
'EX': 11,
|
| 106 |
-
'FW': 12,
|
| 107 |
-
'IN': 13,
|
| 108 |
-
'JJ': 14,
|
| 109 |
-
'JJR': 15,
|
| 110 |
-
'JJS': 16,
|
| 111 |
-
'LS': 17,
|
| 112 |
-
'MD': 18,
|
| 113 |
-
'NN': 19,
|
| 114 |
-
'NNP': 20,
|
| 115 |
-
'NNPS': 21,
|
| 116 |
-
'NNS': 22,
|
| 117 |
-
'PDT': 23,
|
| 118 |
-
'POS': 24,
|
| 119 |
-
'PRP': 25,
|
| 120 |
-
'PRP$': 26,
|
| 121 |
-
'RB': 27,
|
| 122 |
-
'RBR': 28,
|
| 123 |
-
'RBS': 29,
|
| 124 |
-
'RP': 30,
|
| 125 |
-
'SYM': 31,
|
| 126 |
-
'TO': 32,
|
| 127 |
-
'UH': 33,
|
| 128 |
-
'VB': 34,
|
| 129 |
-
'VBD': 35,
|
| 130 |
-
'VBG': 36,
|
| 131 |
-
'VBN': 37,
|
| 132 |
-
'VBP': 38,
|
| 133 |
-
'VBZ': 39,
|
| 134 |
-
'WDT': 40,
|
| 135 |
-
'WP': 41,
|
| 136 |
-
'WP$': 42,
|
| 137 |
-
'WRB': 43,
|
| 138 |
-
'``': 44,
|
| 139 |
-
'BREAK': 45}
|
| 140 |
-
|
| 141 |
-
|
| 142 |
def parse_tags(reconstructed_tags):
|
| 143 |
"""
|
| 144 |
Parse reconstructed tags into input/tag datapoint.
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
|
| 7 |
+
additional_tags = {
|
| 8 |
+
"as": "`AS",
|
| 9 |
+
"and": "`AND",
|
| 10 |
+
"of": "`OF",
|
| 11 |
+
"how": "`HOW",
|
| 12 |
+
"but": "`BUT",
|
| 13 |
+
"the": "`THE",
|
| 14 |
+
"a": "`A",
|
| 15 |
+
"an": "`A",
|
| 16 |
+
"which": "`WHICH",
|
| 17 |
+
"what": "`WHAT",
|
| 18 |
+
"where": "`WHERE",
|
| 19 |
+
"that": "`THAT",
|
| 20 |
+
"who": "`WHO",
|
| 21 |
+
"when": "`WHEN",
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
def get_upenn_tags_dict():
|
| 25 |
+
# tagger = PerceptronTagger()
|
| 26 |
+
|
| 27 |
+
# tags = list(tagger.tagdict.values())
|
| 28 |
+
|
| 29 |
+
# # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
|
| 30 |
+
# tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
|
| 31 |
+
# tags = list(set(tags))
|
| 32 |
+
# tags.sort()
|
| 33 |
+
# tags.append("BREAK")
|
| 34 |
+
|
| 35 |
+
# tags_dict = dict()
|
| 36 |
+
|
| 37 |
+
# for index, tag in enumerate(tags):
|
| 38 |
+
# tags_dict[tag] = index
|
| 39 |
+
|
| 40 |
+
return {'#': 0, '$': 1, "''": 2,'(': 3,')': 4,',': 5,'.': 6,':': 7,'CC': 8,'CD': 9,'DT': 10,'EX': 11,'FW': 12,'IN': 13,'JJ': 14,'JJR': 15,'JJS': 16,'LS': 17,'MD': 18,'NN': 19,'NNP': 20,'NNPS': 21,'NNS': 22,'PDT': 23,'POS': 24,'PRP': 25,'PRP$': 26,'RB': 27,'RBR': 28,'RBS': 29,'RP': 30,'SYM': 31,'TO': 32,'UH': 33,'VB': 34,'VBD': 35,'VBG': 36,'VBN': 37,'VBP': 38,'VBZ': 39,'WDT': 40,'WP': 41,'WP$': 42,'WRB': 43,'``': 44,'BREAK': 45,
|
| 41 |
+
'`AS': 46,
|
| 42 |
+
'`AND': 47,
|
| 43 |
+
'`OF': 48,
|
| 44 |
+
'`HOW': 49,
|
| 45 |
+
'`BUT': 50,
|
| 46 |
+
'`THE': 51,
|
| 47 |
+
'`A': 52,
|
| 48 |
+
'`WHICH': 53,
|
| 49 |
+
'`WHAT': 54,
|
| 50 |
+
'`WHERE': 55,
|
| 51 |
+
'`THAT': 56,
|
| 52 |
+
'`WHO': 57,
|
| 53 |
+
'`WHEN': 58
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
def nltk_extend_tags(tagged_text: list[tuple[str, str]]):
|
| 57 |
+
result = []
|
| 58 |
+
for text, tag in tagged_text:
|
| 59 |
+
text_lower = text.lower().strip()
|
| 60 |
+
if text_lower in additional_tags:
|
| 61 |
+
yield (text, additional_tags[text_lower])
|
| 62 |
+
else:
|
| 63 |
+
yield (text, tag)
|
| 64 |
+
|
| 65 |
def bind_wordtimings_to_tags(wt: list[WordTiming]):
|
| 66 |
raw_words = [w.word for w in wt]
|
| 67 |
|
|
|
|
| 74 |
tokens_wordtiming_map.append(len(tokens_word))
|
| 75 |
|
| 76 |
tagged_words = nltk.pos_tag(tokenized_raw_words)
|
| 77 |
+
tagged_words = list(nltk_extend_tags(tagged_words))
|
| 78 |
|
| 79 |
grouped_tags = []
|
| 80 |
|
|
|
|
| 108 |
|
| 109 |
tokenized_full_text = nltk.word_tokenize(full_text)
|
| 110 |
tagged_full_text = nltk.pos_tag(tokenized_full_text)
|
| 111 |
+
tagged_full_text = list(nltk_extend_tags(tagged_full_text))
|
| 112 |
|
| 113 |
tagged_full_text_copy = tagged_full_text
|
| 114 |
|
|
|
|
| 135 |
|
| 136 |
return reconstructed_tags
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
def parse_tags(reconstructed_tags):
|
| 139 |
"""
|
| 140 |
Parse reconstructed tags into input/tag datapoint.
|