Pranava Kailash commited on
Commit
5dde192
Β·
1 Parent(s): d2369e6

Fixed No data in Tensor error v1.1

Browse files
Files changed (1) hide show
  1. app.py +115 -79
app.py CHANGED
@@ -1,91 +1,95 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  from collections import defaultdict
4
- import torch
5
 
6
  # Load model and tokenizer
7
  path_to_checkpoint = 'PranavaKailash/CyNER-2.0-DeBERTa-v3-base'
8
 
9
- # Load tokenizer
10
- tokenizer = AutoTokenizer.from_pretrained(path_to_checkpoint, use_fast=True)
11
-
12
- # Load model with proper device handling
13
- model = AutoModelForTokenClassification.from_pretrained(
14
- path_to_checkpoint,
15
- torch_dtype='auto',
16
- device_map='cpu'
17
- )
18
-
19
- # Initialize the NER pipeline
20
- ner_pipeline = pipeline(
21
- "ner",
22
- model=model,
23
- tokenizer=tokenizer,
24
- device=-1 # Explicitly use CPU
25
- )
 
26
 
27
- def tag_sentence(sentence, entities_dict):
28
  """
29
  Add HTML tags to entities for visualization.
30
  """
31
- all_entities = sorted(
32
- [(e['start'], e['end'], e['entity'], e['word']) for ents in entities_dict.values() for e in ents],
33
- key=lambda x: x[0]
34
- )
35
-
36
- merged_entities = []
37
- current_entity = None
38
-
39
- for start, end, entity_type, word in all_entities:
40
- if current_entity is None:
41
- current_entity = [start, end, entity_type, word]
42
- else:
43
- if start == current_entity[1] and entity_type == current_entity[2] and entity_type.startswith('I-'):
44
- current_entity[1] = end
45
- current_entity[3] += word.replace('▁', ' ')
46
- else:
47
- merged_entities.append(tuple(current_entity))
48
- current_entity = [start, end, entity_type, word]
49
 
50
- if current_entity:
51
- merged_entities.append(tuple(current_entity))
52
 
53
  tagged_sentence = ""
54
  last_idx = 0
55
 
56
- for start, end, entity_type, _ in merged_entities:
57
- tagged_sentence += sentence[last_idx:start]
58
- entity_tag = entity_type.replace('I-', 'B-')
59
- tagged_sentence += f"<span style='color:blue; background-color: #e6f3ff; padding: 2px; border-radius: 3px;'><strong>{entity_tag}</strong></span><span style='background-color: #fff3cd; padding: 2px; border-radius: 3px;'>{sentence[start:end]}</span>"
60
- last_idx = end
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
62
  tagged_sentence += sentence[last_idx:]
63
  return tagged_sentence
64
 
65
  @st.cache_data
66
- def perform_ner(text):
67
  """
68
  Run NER pipeline and prepare results for display.
69
  """
 
 
 
70
  try:
71
- entities = ner_pipeline(text)
72
- entities_dict = defaultdict(list)
73
 
 
 
74
  for entity in entities:
75
- entities_dict[entity['entity']].append({
76
- "entity": entity['entity'],
77
- "score": round(entity['score'], 4),
78
- "index": entity['index'],
79
- "word": entity['word'],
80
- "start": entity['start'],
81
- "end": entity['end']
82
  })
83
 
84
- tagged_sentence = tag_sentence(text, entities_dict)
85
- return dict(entities_dict), tagged_sentence
 
 
86
  except Exception as e:
87
  st.error(f"Error during NER processing: {str(e)}")
88
- return {}, text
89
 
90
  # Streamlit UI
91
  st.set_page_config(
@@ -94,66 +98,98 @@ st.set_page_config(
94
  layout="wide"
95
  )
96
 
 
 
 
 
 
 
 
97
  st.title("πŸ” CyNER 2.0 - Cybersecurity Named Entity Recognition")
98
- st.markdown("**Advanced NER for Cybersecurity Text Analysis**")
99
- st.write("Enter cybersecurity-related text to identify and extract named entities using the CyNER 2.0 model.")
100
 
101
  # Example texts
102
  examples = {
103
  "Malware Analysis": "The Zeus trojan was detected on the victim's Windows 10 system at IP address 192.168.1.100. The malware communicated with command and control server evil.example.com using port 8080.",
104
  "Vulnerability Report": "CVE-2021-44228 affects Apache Log4j versions 2.0 to 2.15.0. The vulnerability allows remote code execution through LDAP injection.",
105
- "Incident Response": "Suspicious network traffic detected from IP 203.0.113.1 attempting to access /admin/login.php on our web server nginx running on Ubuntu 20.04."
 
106
  }
107
 
108
  # Sidebar for examples
109
  with st.sidebar:
110
- st.header("Example Texts")
 
111
  for title, text in examples.items():
112
- if st.button(f"Load: {title}"):
113
  st.session_state.input_text = text
114
 
115
  # Main input
116
  input_text = st.text_area(
117
- "Input Text",
118
  value=st.session_state.get('input_text', "Enter your cybersecurity text here..."),
119
  height=150,
 
120
  key='input_text'
121
  )
122
 
123
- col1, col2 = st.columns([1, 4])
124
  with col1:
125
  analyze_button = st.button("πŸ” Analyze Text", type="primary")
 
 
126
 
127
- if analyze_button:
 
 
 
 
128
  if input_text.strip() and input_text != "Enter your cybersecurity text here...":
129
- with st.spinner("Processing text with CyNER 2.0..."):
130
- entities_dict, tagged_sentence = perform_ner(input_text)
131
 
132
  if entities_dict:
 
 
133
  # Display results
134
  st.subheader("πŸ“Š Analysis Results")
135
 
136
  # Tagged visualization
137
- st.markdown("**Tagged Entities:**")
138
  st.markdown(tagged_sentence, unsafe_allow_html=True)
139
 
140
- # Entity summary
141
- st.markdown("**Entity Summary:**")
142
- entity_counts = {k: len(v) for k, v in entities_dict.items()}
 
 
 
 
 
 
 
143
 
144
- cols = st.columns(min(len(entity_counts), 4))
145
- for i, (entity_type, count) in enumerate(entity_counts.items()):
146
- with cols[i % 4]:
147
- st.metric(entity_type.replace('B-', '').replace('I-', ''), count)
 
 
148
 
149
- # Detailed results
150
- with st.expander("πŸ“‹ Detailed Entity Information", expanded=False):
151
  st.json(entities_dict)
152
  else:
153
- st.info("No entities detected in the provided text.")
154
  else:
155
  st.warning("⚠️ Please enter some text for analysis.")
156
 
157
  # Footer
158
  st.markdown("---")
159
- st.markdown("**CyNER 2.0** - Powered by DeBERTa-v3-base | Built with Streamlit")
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
3
  from collections import defaultdict
 
4
 
5
  # Load model and tokenizer
6
  path_to_checkpoint = 'PranavaKailash/CyNER-2.0-DeBERTa-v3-base'
7
 
8
+ @st.cache_resource
9
+ def load_model():
10
+ """Load model and tokenizer with proper error handling"""
11
+ try:
12
+ tokenizer = AutoTokenizer.from_pretrained(path_to_checkpoint, use_fast=True)
13
+ model = AutoModelForTokenClassification.from_pretrained(path_to_checkpoint)
14
+
15
+ # Initialize the NER pipeline (this handles device placement automatically)
16
+ ner_pipeline = pipeline(
17
+ "ner",
18
+ model=model,
19
+ tokenizer=tokenizer,
20
+ device=-1 # Force CPU usage
21
+ )
22
+ return ner_pipeline
23
+ except Exception as e:
24
+ st.error(f"Error loading model: {str(e)}")
25
+ return None
26
 
27
+ def tag_sentence(sentence, entities):
28
  """
29
  Add HTML tags to entities for visualization.
30
  """
31
+ if not entities:
32
+ return sentence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Sort entities by start position
35
+ sorted_entities = sorted(entities, key=lambda x: x['start'])
36
 
37
  tagged_sentence = ""
38
  last_idx = 0
39
 
40
+ for entity in sorted_entities:
41
+ # Add text before entity
42
+ tagged_sentence += sentence[last_idx:entity['start']]
43
+
44
+ # Add tagged entity
45
+ entity_text = sentence[entity['start']:entity['end']]
46
+ entity_label = entity['entity_group'] if 'entity_group' in entity else entity['entity']
47
+ confidence = entity.get('score', 0)
48
+
49
+ tagged_sentence += f"""
50
+ <span style='background-color: #e6f3ff; padding: 2px 6px; border-radius: 4px; border-left: 3px solid #007acc; margin: 1px;'>
51
+ <strong style='color: #005299;'>{entity_label}</strong>
52
+ <span style='color: #333;'>{entity_text}</span>
53
+ <small style='color: #666; font-size: 0.8em;'>({confidence:.2f})</small>
54
+ </span>
55
+ """
56
+
57
+ last_idx = entity['end']
58
 
59
+ # Add remaining text
60
  tagged_sentence += sentence[last_idx:]
61
  return tagged_sentence
62
 
63
  @st.cache_data
64
+ def perform_ner(text, _pipeline):
65
  """
66
  Run NER pipeline and prepare results for display.
67
  """
68
+ if not _pipeline:
69
+ return [], text
70
+
71
  try:
72
+ # Get entities from pipeline
73
+ entities = _pipeline(text)
74
 
75
+ # Group entities by type for summary
76
+ entities_by_type = defaultdict(list)
77
  for entity in entities:
78
+ entity_type = entity.get('entity_group', entity.get('entity', 'Unknown'))
79
+ entities_by_type[entity_type].append({
80
+ 'text': text[entity['start']:entity['end']],
81
+ 'confidence': round(entity['score'], 3),
82
+ 'start': entity['start'],
83
+ 'end': entity['end']
 
84
  })
85
 
86
+ # Create tagged sentence
87
+ tagged_sentence = tag_sentence(text, entities)
88
+
89
+ return dict(entities_by_type), tagged_sentence, entities
90
  except Exception as e:
91
  st.error(f"Error during NER processing: {str(e)}")
92
+ return {}, text, []
93
 
94
  # Streamlit UI
95
  st.set_page_config(
 
98
  layout="wide"
99
  )
100
 
101
+ # Load the pipeline
102
+ ner_pipeline = load_model()
103
+
104
+ if not ner_pipeline:
105
+ st.error("❌ Failed to load the model. Please refresh the page or contact support.")
106
+ st.stop()
107
+
108
  st.title("πŸ” CyNER 2.0 - Cybersecurity Named Entity Recognition")
109
+ st.markdown("**Advanced NER for Cybersecurity Text Analysis using DeBERTa-v3**")
110
+ st.write("Enter cybersecurity-related text to identify and extract named entities.")
111
 
112
  # Example texts
113
  examples = {
114
  "Malware Analysis": "The Zeus trojan was detected on the victim's Windows 10 system at IP address 192.168.1.100. The malware communicated with command and control server evil.example.com using port 8080.",
115
  "Vulnerability Report": "CVE-2021-44228 affects Apache Log4j versions 2.0 to 2.15.0. The vulnerability allows remote code execution through LDAP injection.",
116
+ "Incident Response": "Suspicious network traffic detected from IP 203.0.113.1 attempting to access /admin/login.php on our web server nginx running on Ubuntu 20.04.",
117
+ "Phishing Attack": "Users received emails from admin@secur3-bank.com asking them to update their credentials by clicking on https://phishing-site.malicious.com/login"
118
  }
119
 
120
  # Sidebar for examples
121
  with st.sidebar:
122
+ st.header("πŸ“ Example Texts")
123
+ st.write("Click to load example cybersecurity text:")
124
  for title, text in examples.items():
125
+ if st.button(f"πŸ“‹ {title}", key=f"example_{title}"):
126
  st.session_state.input_text = text
127
 
128
  # Main input
129
  input_text = st.text_area(
130
+ "**Input Text**",
131
  value=st.session_state.get('input_text', "Enter your cybersecurity text here..."),
132
  height=150,
133
+ help="Paste any cybersecurity-related text to analyze",
134
  key='input_text'
135
  )
136
 
137
+ col1, col2, col3 = st.columns([2, 1, 3])
138
  with col1:
139
  analyze_button = st.button("πŸ” Analyze Text", type="primary")
140
+ with col2:
141
+ clear_button = st.button("πŸ—‘οΈ Clear")
142
 
143
+ if clear_button:
144
+ st.session_state.input_text = ""
145
+ st.experimental_rerun()
146
+
147
+ if analyze_button and ner_pipeline:
148
  if input_text.strip() and input_text != "Enter your cybersecurity text here...":
149
+ with st.spinner("πŸ€– Processing text with CyNER 2.0..."):
150
+ entities_dict, tagged_sentence, raw_entities = perform_ner(input_text, ner_pipeline)
151
 
152
  if entities_dict:
153
+ st.success(f"βœ… Analysis complete! Found {sum(len(v) for v in entities_dict.values())} entities")
154
+
155
  # Display results
156
  st.subheader("πŸ“Š Analysis Results")
157
 
158
  # Tagged visualization
159
+ st.markdown("**🏷️ Tagged Entities:**")
160
  st.markdown(tagged_sentence, unsafe_allow_html=True)
161
 
162
+ # Entity summary metrics
163
+ st.markdown("**πŸ“ˆ Entity Summary:**")
164
+ if len(entities_dict) > 0:
165
+ cols = st.columns(min(len(entities_dict), 4))
166
+ for i, (entity_type, entities_list) in enumerate(entities_dict.items()):
167
+ with cols[i % 4]:
168
+ st.metric(
169
+ label=entity_type.replace('B-', '').replace('I-', ''),
170
+ value=len(entities_list)
171
+ )
172
 
173
+ # Detailed breakdown
174
+ with st.expander("πŸ“‹ Detailed Entity Breakdown", expanded=True):
175
+ for entity_type, entities_list in entities_dict.items():
176
+ st.markdown(f"**{entity_type}:**")
177
+ for entity in entities_list:
178
+ st.markdown(f"- `{entity['text']}` (confidence: {entity['confidence']})")
179
 
180
+ # Raw data for developers
181
+ with st.expander("πŸ”§ Raw JSON Data", expanded=False):
182
  st.json(entities_dict)
183
  else:
184
+ st.info("ℹ️ No cybersecurity entities detected in the provided text. Try using text with security-related terms like IP addresses, malware names, CVEs, etc.")
185
  else:
186
  st.warning("⚠️ Please enter some text for analysis.")
187
 
188
  # Footer
189
  st.markdown("---")
190
+ st.markdown("""
191
+ <div style='text-align: center; color: #666; font-size: 0.9em;'>
192
+ <strong>CyNER 2.0</strong> - Cybersecurity Named Entity Recognition<br>
193
+ Model: <code>PranavaKailash/CyNER-2.0-DeBERTa-v3-base</code> | Built with Streamlit
194
+ </div>
195
+ """, unsafe_allow_html=True)