Commit
·
621df19
1
Parent(s):
8241ba7
Upload model
Browse files
- event_arg_predict.py +5 -5
- event_nugget_predict.py +3 -3
- event_realis_predict.py +5 -5
event_arg_predict.py
CHANGED
|
@@ -3,11 +3,11 @@ from annotated_text import annotated_text
|
|
| 3 |
import torch
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
|
| 10 |
-
from
|
| 11 |
import spacy
|
| 12 |
from transformers import AutoTokenizer
|
| 13 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
|
@@ -37,7 +37,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
| 37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 39 |
|
| 40 |
-
from
|
| 41 |
model_nugget = ArgumentModel(num_classes=43)
|
| 42 |
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
|
| 43 |
model_nugget.eval()
|
|
|
|
| 3 |
import torch
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
|
| 6 |
+
from .args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
| 7 |
+
from .nugget_model_utils import CustomRobertaWithPOS
|
| 8 |
+
from .utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
|
| 9 |
|
| 10 |
+
from .event_nugget_predict import get_event_nuggets
|
| 11 |
import spacy
|
| 12 |
from transformers import AutoTokenizer
|
| 13 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
|
|
|
| 37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 39 |
|
| 40 |
+
from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
|
| 41 |
model_nugget = ArgumentModel(num_classes=43)
|
| 42 |
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
|
| 43 |
model_nugget.eval()
|
event_nugget_predict.py
CHANGED
|
@@ -3,9 +3,9 @@ from annotated_text import annotated_text
|
|
| 3 |
import torch
|
| 4 |
from torch import nn
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
import spacy
|
| 10 |
from transformers import AutoTokenizer
|
| 11 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
|
|
|
| 3 |
import torch
|
| 4 |
from torch import nn
|
| 5 |
from torch.utils.data import DataLoader
|
| 6 |
+
from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel
|
| 7 |
+
from .nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
| 8 |
+
from .utils import get_idxs_from_text, event_nugget_list
|
| 9 |
import spacy
|
| 10 |
from transformers import AutoTokenizer
|
| 11 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
event_realis_predict.py
CHANGED
|
@@ -3,12 +3,12 @@ import spacy
|
|
| 3 |
import torch
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
from transformers import AutoTokenizer
|
| 6 |
-
from
|
| 7 |
import streamlit as st
|
| 8 |
from annotated_text import annotated_text
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
| 13 |
|
| 14 |
event_nugget_list = ['B-Phishing',
|
|
@@ -49,7 +49,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
| 49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 51 |
|
| 52 |
-
from
|
| 53 |
model_realis = RealisModel(num_classes_realis=4)
|
| 54 |
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
| 55 |
model_realis.eval()
|
|
|
|
| 3 |
import torch
|
| 4 |
from torch.utils.data import DataLoader
|
| 5 |
from transformers import AutoTokenizer
|
| 6 |
+
from .utils import get_idxs_from_text
|
| 7 |
import streamlit as st
|
| 8 |
from annotated_text import annotated_text
|
| 9 |
+
from .nugget_model_utils import CustomRobertaWithPOS
|
| 10 |
+
from .event_nugget_predict import get_event_nuggets
|
| 11 |
+
from .realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
|
| 12 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
| 13 |
|
| 14 |
event_nugget_list = ['B-Phishing',
|
|
|
|
| 49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
| 50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
| 51 |
|
| 52 |
+
from .realis_model_utils import CustomRobertaWithPOS as RealisModel
|
| 53 |
model_realis = RealisModel(num_classes_realis=4)
|
| 54 |
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
| 55 |
model_realis.eval()
|