Commit
·
9df8979
1
Parent(s):
008fd4d
Upload model
Browse files
model.py
CHANGED
|
@@ -88,43 +88,43 @@ class CybersecurityKnowledgeGraphModel(PreTrainedModel):
|
|
| 88 |
structured_output.extend(batch_output)
|
| 89 |
|
| 90 |
|
| 91 |
-
args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
|
| 92 |
|
| 93 |
-
entities = []
|
| 94 |
-
current_entity = None
|
| 95 |
-
for position, label, token in args:
|
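| 96 |
-
if label.startswith('B-'):
|
| 97 |
-
if current_entity is not None:
|
| 98 |
-
entities.append(current_entity)
|
| 99 |
-
current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
|
| 100 |
-
elif label.startswith('I-'):
|
| 101 |
-
if current_entity is not None:
|
| 102 |
-
current_entity['text'] += ' ' + token.replace(" ", "")
|
| 103 |
-
current_entity['end'] = position
|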
| 104 |
-
|
| 105 |
-
for entity in entities:
|
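| 106 |
-
context = self.tokenizer.decode([item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]])
|
| 107 |
-
entity["context"] = context
|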
| 108 |
|
| 109 |
-
for entity in entities:
|
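| 110 |
-
if len(self.arg_2_role[entity["label"]]) > 1:
|
| 111 |
-
sent_embed = self.embed_model.encode(entity["context"])
|
| 112 |
-
arg_embed = self.embed_model.encode(entity["text"])
|
| 113 |
-
embed = np.concatenate((sent_embed, arg_embed))
|
| 114 |
-
|
| 115 |
-
arg_clf = self.role_classifiers[entity["label"]]
|
| 116 |
-
role_id = arg_clf.predict(embed.reshape(1, -1))
|
| 117 |
-
role = self.arg_2_role[entity["label"]][role_id[0]]
|
| 118 |
-
|
| 119 |
-
entity["role"] = role
|
| 120 |
-
else:
|
| 121 |
-
entity["role"] = self.arg_2_role[entity["label"]][0]
|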
| 122 |
|
| 123 |
-
for item in structured_output:
|
| 124 |
-
item["role"] = "O"
|
| 125 |
-
for entity in entities:
|
| 126 |
-
for i in range(entity["start"], entity["end"] + 1):
|
| 127 |
-
structured_output[i]["role"] = entity["role"]
|
| 128 |
return structured_output
|
| 129 |
|
| 130 |
def forward_model(self, model, dataloader):
|
|
|
|
| 88 |
structured_output.extend(batch_output)
|
| 89 |
|
| 90 |
|
| 91 |
+
# args = [(idx, item["argument"], item["token"]) for idx, item in enumerate(structured_output) if item["argument"]!= "O"]
|
| 92 |
|
| 93 |
+
# entities = []
|
| 94 |
+
# current_entity = None
|
| 95 |
+
# for position, label, token in args:
|
| 96 |
+
# if label.startswith('B-'):
|
| 97 |
+
# if current_entity is not None:
|
| 98 |
+
# entities.append(current_entity)
|
| 99 |
+
# current_entity = {'label': label[2:], 'text': token.replace(" ", ""), 'start': position, 'end': position}
|
| 100 |
+
# elif label.startswith('I-'):
|
| 101 |
+
# if current_entity is not None:
|
| 102 |
+
# current_entity['text'] += ' ' + token.replace(" ", "")
|
| 103 |
+
# current_entity['end'] = position
|
| 104 |
+
|
| 105 |
+
# for entity in entities:
|
| 106 |
+
# context = self.tokenizer.decode([item["id"] for item in structured_output[max(0, entity["start"] - 15) : min(len(structured_output), entity["end"] + 15)]])
|
| 107 |
+
# entity["context"] = context
|
| 108 |
|
| 109 |
+
# for entity in entities:
|
| 110 |
+
# if len(self.arg_2_role[entity["label"]]) > 1:
|
| 111 |
+
# sent_embed = self.embed_model.encode(entity["context"])
|
| 112 |
+
# arg_embed = self.embed_model.encode(entity["text"])
|
| 113 |
+
# embed = np.concatenate((sent_embed, arg_embed))
|
| 114 |
+
|
| 115 |
+
# arg_clf = self.role_classifiers[entity["label"]]
|
| 116 |
+
# role_id = arg_clf.predict(embed.reshape(1, -1))
|
| 117 |
+
# role = self.arg_2_role[entity["label"]][role_id[0]]
|
| 118 |
+
|
| 119 |
+
# entity["role"] = role
|
| 120 |
+
# else:
|
| 121 |
+
# entity["role"] = self.arg_2_role[entity["label"]][0]
|
| 122 |
|
| 123 |
+
# for item in structured_output:
|
| 124 |
+
# item["role"] = "O"
|
| 125 |
+
# for entity in entities:
|
| 126 |
+
# for i in range(entity["start"], entity["end"] + 1):
|
| 127 |
+
# structured_output[i]["role"] = entity["role"]
|
| 128 |
return structured_output
|
| 129 |
|
| 130 |
def forward_model(self, model, dataloader):
|