| import streamlit as st | |
| from transformers import AutoModelForTokenClassification | |
| from annotated_text import annotated_text | |
| import numpy as np | |
| import os, joblib | |
| from utils import get_idxs_from_text | |
# Token-classification model pulled from the Hugging Face Hub. It is loaded with
# trust_remote_code=True, i.e. code shipped in that model repository is allowed
# to run locally.
model = AutoModelForTokenClassification.from_pretrained(
    "CyberPeace-Institute/Cybersecurity-Knowledge-Graph",
    trust_remote_code=True,
)

# Role classifiers keyed by the part of the file name before the first ".",
# one per ``*.joblib`` file in the arg_role_models folder under the current
# working directory.
# NOTE(review): joblib.load unpickles the files, so they must come from a
# trusted source.
role_classifiers = {}
folder_path = '/arg_role_models'
for filename in os.listdir(os.getcwd() + folder_path):
    if not filename.endswith('.joblib'):
        continue  # ignore anything that is not a serialized classifier
    file_path = os.getcwd() + os.path.join(folder_path, filename)
    arg = filename.partition(".")[0]
    clf = joblib.load(file_path)
    role_classifiers[arg] = clf
def _strip_bio_prefix(label):
    """Return *label* without a leading ``B-``/``I-`` tag, if it has one.

    Only the two-character tag is removed, so hyphens inside the label itself
    (and labels that carry no tag at all) are left untouched.
    """
    return label[2:] if label.startswith(("B-", "I-")) else label


def annotate(name):
    """Render the tokens of the module-level ``output`` grouped by one label.

    Consecutive tokens whose ``name`` label is the same (ignoring a leading
    ``B-``/``I-`` tag) are merged into a single span. Spans labelled ``"O"``
    are passed to ``annotated_text`` as plain strings, every other span as a
    ``(text, label)`` tuple.

    Args:
        name: Key of the per-token label to display (the callers in this file
            use "nugget", "argument", "realis" and "role").

    Returns:
        None. The result is drawn through ``annotated_text``.
    """
    # ``output`` is not a parameter: it is the global assigned by the page code
    # before this function is called.
    if not output:
        # Nothing to draw; avoids emitting a single empty ("", "") annotation.
        return

    tokens = [item["token"].replace(" ", "") for item in output]
    text = model.tokenizer.decode([item["id"] for item in output])
    idxs = get_idxs_from_text(text, tokens)
    labels = [item[name] for item in output]

    annotated_text_list = []

    def flush(span_text, span_label):
        # "O" spans stay plain text; anything else becomes a labelled span.
        if span_label == "O":
            annotated_text_list.append(span_text)
        else:
            annotated_text_list.append((span_text, span_label))

    current_label = None  # None means no span has been opened yet
    current_text = ""
    last_end = 0
    for idx, label in zip(idxs, labels):
        label_short = _strip_bio_prefix(label)
        if label_short == current_label:
            # Extend with the raw slice of the decoded text so whatever sits
            # between the previous token and this one is kept inside the span.
            current_text += text[last_end:idx["end_idx"]]
        else:
            if current_label is not None:
                flush(current_text, current_label)
            current_label = label_short
            current_text = idx["word"]
        last_end = idx["end_idx"]

    # Emit the span that was still open when the tokens ran out.
    if current_label is not None:
        flush(current_text, current_label)

    annotated_text(annotated_text_list)
def _collect_argument_entities(output):
    """Group ``B-``/``I-`` tagged argument tokens of *output* into entity dicts.

    Each entity holds its ``label`` (tag prefix removed), the space-joined
    ``text`` of its tokens, and the inclusive ``start``/``end`` positions of
    those tokens in *output*.
    """
    entities = []
    current = None
    for position, item in enumerate(output):
        label = item["argument"]
        if label.startswith('B-'):
            # A new entity begins: close the one that was open, if any.
            if current is not None:
                entities.append(current)
            current = {
                'label': label[2:],
                'text': item["token"].replace(" ", ""),
                'start': position,
                'end': position,
            }
        elif label.startswith('I-') and current is not None:
            # Tokens with other labels (e.g. "O") are skipped without closing
            # the entity, so a later I- token extends the most recent one.
            current['text'] += ' ' + item["token"].replace(" ", "")
            current['end'] = position

    # Without this the entity still open at the end of the input is lost and
    # its tokens never receive a role.
    if current is not None:
        entities.append(current)
    return entities


def get_arg_roles(output):
    """Add a ``"role"`` entry to every token dict in *output*.

    Argument tokens are grouped into entities. When ``model.arg_2_role`` lists
    a single role for the entity's label, that role is used. Otherwise the
    classifier stored in ``role_classifiers`` for that label chooses one from
    the concatenated embeddings of the surrounding context and the entity text.
    Tokens that belong to no entity get the role ``"O"``.

    Args:
        output: List of per-token dicts; the keys read here are "argument",
            "token" and "id".

    Returns:
        The same list object, with ``"role"`` set on every item (modified in
        place).

    Raises:
        KeyError: If an entity label is missing from ``model.arg_2_role`` or,
            for multi-role labels, from ``role_classifiers``.
    """
    entities = _collect_argument_entities(output)

    for entity in entities:
        candidate_roles = model.arg_2_role[entity["label"]]
        if len(candidate_roles) > 1:
            # Context window: up to 15 tokens before the entity and up to the
            # slice bound end + 15 after it.
            window = output[max(0, entity["start"] - 15):min(len(output), entity["end"] + 15)]
            entity["context"] = model.tokenizer.decode([item["id"] for item in window])

            sent_embed = model.embed_model.encode(entity["context"])
            arg_embed = model.embed_model.encode(entity["text"])
            embed = np.concatenate((sent_embed, arg_embed))

            arg_clf = role_classifiers[entity["label"]]
            # The prediction is used as an index into the label's role list.
            role_id = arg_clf.predict(embed.reshape(1, -1))
            entity["role"] = candidate_roles[role_id[0]]
        else:
            entity["role"] = candidate_roles[0]

    # Default every token to "O", then overwrite the tokens inside an entity.
    for item in output:
        item["role"] = "O"
    for entity in entities:
        for i in range(entity["start"], entity["end"] + 1):
            output[i]["role"] = entity["role"]
    return output
# Streamlit page: a text area, then one annotated view per label type.
st.title("Create Knowledge Graphs from Cyber Incidents")

text_input = st.text_area("Enter your text here", height=100)
# Call st.button unconditionally so the button is always drawn; on the right
# of an `or` the call would be skipped whenever the text area has content.
apply_clicked = st.button('Apply')

if text_input and text_input.strip():
    # ``output`` has to be a module-level name: annotate() reads it as a global.
    output = model(text_input)

    st.subheader("Event Nuggets")
    annotate("nugget")

    st.subheader("Event Arguments")
    annotate("argument")

    st.subheader("Realis of Event Nuggets")
    annotate("realis")

    # Adds the "role" label to every token before the last view is drawn.
    output = get_arg_roles(output)
    st.subheader("Role of the Event Arguments")
    annotate("role")
elif apply_clicked:
    # Do not run the model on blank input; tell the user what is missing.
    st.warning("Please enter some text to analyse.")