Pranava Kailash commited on
Commit
d2369e6
Β·
1 Parent(s): 4b4796e

Fixed No data in Tensor error v1

Browse files
Files changed (2) hide show
  1. app.py +108 -40
  2. requirements.txt +5 -2
app.py CHANGED
@@ -1,17 +1,28 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  from collections import defaultdict
 
4
 
5
  # Load model and tokenizer
6
  path_to_checkpoint = 'PranavaKailash/CyNER-2.0-DeBERTa-v3-base'
7
- tokenizer = AutoTokenizer.from_pretrained(path_to_checkpoint, use_fast=True, max_length=768)
8
- model = AutoModelForTokenClassification.from_pretrained(path_to_checkpoint)
9
 
10
- # Ensure the model is loaded on CPU explicitly to avoid any device issues
11
- model.to('cpu')
 
 
 
 
 
 
 
12
 
13
  # Initialize the NER pipeline
14
- ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer)
 
 
 
 
 
15
 
16
  def tag_sentence(sentence, entities_dict):
17
  """
@@ -21,10 +32,10 @@ def tag_sentence(sentence, entities_dict):
21
  [(e['start'], e['end'], e['entity'], e['word']) for ents in entities_dict.values() for e in ents],
22
  key=lambda x: x[0]
23
  )
24
-
25
  merged_entities = []
26
  current_entity = None
27
-
28
  for start, end, entity_type, word in all_entities:
29
  if current_entity is None:
30
  current_entity = [start, end, entity_type, word]
@@ -35,57 +46,114 @@ def tag_sentence(sentence, entities_dict):
35
  else:
36
  merged_entities.append(tuple(current_entity))
37
  current_entity = [start, end, entity_type, word]
38
-
39
  if current_entity:
40
  merged_entities.append(tuple(current_entity))
41
-
42
  tagged_sentence = ""
43
  last_idx = 0
44
-
45
  for start, end, entity_type, _ in merged_entities:
46
  tagged_sentence += sentence[last_idx:start]
47
  entity_tag = entity_type.replace('I-', 'B-')
48
- tagged_sentence += f"<span style='color:blue'><{entity_tag}></span>{sentence[start:end]}<span style='color:blue'>/{entity_tag}></span>"
49
  last_idx = end
50
-
51
  tagged_sentence += sentence[last_idx:]
52
  return tagged_sentence
53
 
 
54
  def perform_ner(text):
55
  """
56
  Run NER pipeline and prepare results for display.
57
  """
58
- entities = ner_pipeline(text)
59
- entities_dict = defaultdict(list)
60
-
61
- for entity in entities:
62
- entities_dict[entity['entity']].append({
63
- "entity": entity['entity'],
64
- "score": entity['score'],
65
- "index": entity['index'],
66
- "word": entity['word'],
67
- "start": entity['start'],
68
- "end": entity['end']
69
- })
70
-
71
- tagged_sentence = tag_sentence(text, entities_dict)
72
- return dict(entities_dict), tagged_sentence
 
 
 
 
73
 
74
  # Streamlit UI
75
- st.title("CyNER 2.0 - Named Entity Recognition")
76
- st.write("Enter text to get named entity recognition results.")
 
 
 
77
 
78
- input_text = st.text_area("Input Text", "Type your text here...")
 
 
79
 
80
- if st.button("Analyze"):
81
- if input_text.strip():
82
- entities_dict, tagged_sentence = perform_ner(input_text)
83
-
84
- # Display results / visualization
85
- st.subheader("Tagged Entities")
86
- st.markdown(tagged_sentence, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- st.subheader("Entities and Details")
89
- st.json(entities_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  else:
91
- st.warning("Please enter some text for analysis.")
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  from collections import defaultdict
4
+ import torch
5
 
6
  # Load model and tokenizer
7
  path_to_checkpoint = 'PranavaKailash/CyNER-2.0-DeBERTa-v3-base'
 
 
8
 
9
+ # Load tokenizer
10
+ tokenizer = AutoTokenizer.from_pretrained(path_to_checkpoint, use_fast=True)
11
+
12
+ # Load model with proper device handling
13
+ model = AutoModelForTokenClassification.from_pretrained(
14
+ path_to_checkpoint,
15
+ torch_dtype='auto',
16
+ device_map='cpu'
17
+ )
18
 
19
  # Initialize the NER pipeline
20
+ ner_pipeline = pipeline(
21
+ "ner",
22
+ model=model,
23
+ tokenizer=tokenizer,
24
+ device=-1 # Explicitly use CPU
25
+ )
26
 
27
  def tag_sentence(sentence, entities_dict):
28
  """
 
32
  [(e['start'], e['end'], e['entity'], e['word']) for ents in entities_dict.values() for e in ents],
33
  key=lambda x: x[0]
34
  )
35
+
36
  merged_entities = []
37
  current_entity = None
38
+
39
  for start, end, entity_type, word in all_entities:
40
  if current_entity is None:
41
  current_entity = [start, end, entity_type, word]
 
46
  else:
47
  merged_entities.append(tuple(current_entity))
48
  current_entity = [start, end, entity_type, word]
49
+
50
  if current_entity:
51
  merged_entities.append(tuple(current_entity))
52
+
53
  tagged_sentence = ""
54
  last_idx = 0
55
+
56
  for start, end, entity_type, _ in merged_entities:
57
  tagged_sentence += sentence[last_idx:start]
58
  entity_tag = entity_type.replace('I-', 'B-')
59
+ tagged_sentence += f"<span style='color:blue; background-color: #e6f3ff; padding: 2px; border-radius: 3px;'><strong>{entity_tag}</strong></span><span style='background-color: #fff3cd; padding: 2px; border-radius: 3px;'>{sentence[start:end]}</span>"
60
  last_idx = end
61
+
62
  tagged_sentence += sentence[last_idx:]
63
  return tagged_sentence
64
 
65
+ @st.cache_data
66
  def perform_ner(text):
67
  """
68
  Run NER pipeline and prepare results for display.
69
  """
70
+ try:
71
+ entities = ner_pipeline(text)
72
+ entities_dict = defaultdict(list)
73
+
74
+ for entity in entities:
75
+ entities_dict[entity['entity']].append({
76
+ "entity": entity['entity'],
77
+ "score": round(entity['score'], 4),
78
+ "index": entity['index'],
79
+ "word": entity['word'],
80
+ "start": entity['start'],
81
+ "end": entity['end']
82
+ })
83
+
84
+ tagged_sentence = tag_sentence(text, entities_dict)
85
+ return dict(entities_dict), tagged_sentence
86
+ except Exception as e:
87
+ st.error(f"Error during NER processing: {str(e)}")
88
+ return {}, text
89
 
90
  # Streamlit UI
91
+ st.set_page_config(
92
+ page_title="CyNER 2.0",
93
+ page_icon="πŸ”",
94
+ layout="wide"
95
+ )
96
 
97
+ st.title("πŸ” CyNER 2.0 - Cybersecurity Named Entity Recognition")
98
+ st.markdown("**Advanced NER for Cybersecurity Text Analysis**")
99
+ st.write("Enter cybersecurity-related text to identify and extract named entities using the CyNER 2.0 model.")
100
 
101
+ # Example texts
102
+ examples = {
103
+ "Malware Analysis": "The Zeus trojan was detected on the victim's Windows 10 system at IP address 192.168.1.100. The malware communicated with command and control server evil.example.com using port 8080.",
104
+ "Vulnerability Report": "CVE-2021-44228 affects Apache Log4j versions 2.0 to 2.15.0. The vulnerability allows remote code execution through LDAP injection.",
105
+ "Incident Response": "Suspicious network traffic detected from IP 203.0.113.1 attempting to access /admin/login.php on our web server nginx running on Ubuntu 20.04."
106
+ }
107
+
108
+ # Sidebar for examples
109
+ with st.sidebar:
110
+ st.header("Example Texts")
111
+ for title, text in examples.items():
112
+ if st.button(f"Load: {title}"):
113
+ st.session_state.input_text = text
114
+
115
+ # Main input
116
+ input_text = st.text_area(
117
+ "Input Text",
118
+ value=st.session_state.get('input_text', "Enter your cybersecurity text here..."),
119
+ height=150,
120
+ key='input_text'
121
+ )
122
+
123
+ col1, col2 = st.columns([1, 4])
124
+ with col1:
125
+ analyze_button = st.button("πŸ” Analyze Text", type="primary")
126
+
127
+ if analyze_button:
128
+ if input_text.strip() and input_text != "Enter your cybersecurity text here...":
129
+ with st.spinner("Processing text with CyNER 2.0..."):
130
+ entities_dict, tagged_sentence = perform_ner(input_text)
131
 
132
+ if entities_dict:
133
+ # Display results
134
+ st.subheader("πŸ“Š Analysis Results")
135
+
136
+ # Tagged visualization
137
+ st.markdown("**Tagged Entities:**")
138
+ st.markdown(tagged_sentence, unsafe_allow_html=True)
139
+
140
+ # Entity summary
141
+ st.markdown("**Entity Summary:**")
142
+ entity_counts = {k: len(v) for k, v in entities_dict.items()}
143
+
144
+ cols = st.columns(min(len(entity_counts), 4))
145
+ for i, (entity_type, count) in enumerate(entity_counts.items()):
146
+ with cols[i % 4]:
147
+ st.metric(entity_type.replace('B-', '').replace('I-', ''), count)
148
+
149
+ # Detailed results
150
+ with st.expander("πŸ“‹ Detailed Entity Information", expanded=False):
151
+ st.json(entities_dict)
152
+ else:
153
+ st.info("No entities detected in the provided text.")
154
  else:
155
+ st.warning("⚠️ Please enter some text for analysis.")
156
+
157
+ # Footer
158
+ st.markdown("---")
159
+ st.markdown("**CyNER 2.0** - Powered by DeBERTa-v3-base | Built with Streamlit")
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
- transformers
2
- torch
 
 
 
 
1
+ streamlit>=1.28.0
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ tokenizers>=0.13.0
5
+ numpy>=1.21.0