FilipL009 committed on
Commit
3460f59
·
verified ·
1 Parent(s): ef6ebd0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +265 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,267 @@
1
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Interactive controls for the spiral demo.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parametric spiral: the radius grows linearly with the parameter t,
# while the angle sweeps num_turns full revolutions.
t = np.linspace(0, 1, num_points)
angle = 2 * np.pi * num_turns * t

df = pd.DataFrame({
    "x": t * np.cos(angle),
    "y": t * np.sin(angle),
    "idx": t,
    "rand": np.random.randn(num_points),
})

# Scatter plot of the spiral: colour encodes position along the curve,
# point size is random (legend-free, axes hidden).
chart = (
    alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    )
)
st.altair_chart(chart)
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
3
+ import torch
4
+ import pypdf
5
+ import os
6
+ import pandas as pd
7
+ import json
8
+ from tqdm import tqdm # Pro progress bar v terminálu
9
+
10
# --- Workaround for Windows + ModernBERT: make torch.compile/dynamo fall
# back to eager execution instead of crashing on unsupported ops.
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Page-wide Streamlit configuration (must run before other st.* calls).
st.set_page_config(page_title="CTI NER Analyzer", page_icon="🛡️", layout="wide")

st.title("🛡️ CTI NER Analyzer")
st.markdown("Detekce entit v textu pomocí modelu **attack-vector/SecureModernBERT-NER**.")
19
+
20
+ # --- Funkce ---
21
+
22
@st.cache_resource
def load_model():
    """Load the NER pipeline once and cache it for the whole session.

    Uses aggregation_strategy="simple"; the heavier stitching of split
    entities is done later by merge_close_entities().

    Returns:
        A transformers token-classification pipeline, placed on GPU when
        CUDA is available, otherwise on CPU.
    """
    model_name = "attack-vector/SecureModernBERT-NER"
    use_gpu = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)

    # batch_size is deliberately left to the call site, not fixed here.
    return pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if use_gpu else -1,
    )
43
+
44
def extract_text_from_pdf(uploaded_file):
    """Return the concatenated text of every page of the uploaded PDF.

    Pages with extractable text are separated by a blank line. On any
    read/parse error an error message is shown in the UI and an empty
    string is returned (best-effort, never raises).
    """
    try:
        reader = pypdf.PdfReader(uploaded_file)
        pages = []
        for page in reader.pages:
            content = page.extract_text()
            if content:
                pages.append(content + "\n\n")
        return "".join(pages)
    except Exception as e:
        st.error(f"Chyba při čtení PDF: {e}")
        return ""
55
+
56
def analyze_long_text_batched(pipeline, text, chunk_size=4000, batch_size=8):
    """Run NER over arbitrarily long text by splitting it into chunks.

    The text is cut into fixed-size character chunks, the chunks are fed
    to the pipeline in batches (keeps the GPU busy), and each entity's
    character offsets are shifted back into the coordinate space of the
    full original text.

    Args:
        pipeline: HF token-classification pipeline (or any callable with
            the same (texts, batch_size=...) -> list-of-entity-lists
            signature).
        text: full document text.
        chunk_size: chunk length in characters. NOTE: an entity that
            straddles a chunk boundary gets split; merge_close_entities()
            downstream repairs most of these.
        batch_size: number of chunks per forward pass.

    Returns:
        Flat list of entity dicts with 'start'/'end' relative to `text`.
    """
    # 1. Build chunks and remember where each one starts in the full text.
    chunks = []
    offsets = []
    for pos in range(0, len(text), chunk_size):
        piece = text[pos:pos + chunk_size]
        if piece.strip():
            chunks.append(piece)
            offsets.append(pos)

    # Guard: empty/whitespace-only input produced no chunks. Calling the
    # pipeline with an empty list is pointless and can raise in some
    # transformers versions, so bail out early.
    if not chunks:
        return []

    # 2. Batched inference; 3. shift offsets back to full-text coordinates.
    results = []
    for offset, chunk_entities in zip(offsets, pipeline(chunks, batch_size=batch_size)):
        for entity in chunk_entities:
            entity['start'] += offset
            entity['end'] += offset
            results.append(entity)

    return results
85
+
86
def merge_close_entities(results, original_text, max_char_distance=2):
    """Glue together entity fragments the model split apart
    (e.g. 'Cozy' + 'Bear' -> 'Cozy Bear').

    Two consecutive entities are merged when they share the same
    entity_group, the text between them is at most `max_char_distance`
    characters long, and that gap contains no period (sentence break).
    A merged entity keeps the higher of the two scores. Input dicts are
    copied, never mutated; `results` is assumed sorted by 'start'.
    """
    if not results:
        return []

    merged = []
    current = results[0].copy()

    for candidate in results[1:]:
        lo, hi = current['end'], candidate['start']
        if lo > hi:  # overlapping spans -> treat the gap as empty
            lo = hi
        between = original_text[lo:hi]

        same_group = current['entity_group'] == candidate['entity_group']
        close_enough = len(between) <= max_char_distance and "." not in between

        if same_group and close_enough:
            # Extend the current entity to swallow the candidate.
            current['end'] = candidate['end']
            current['score'] = float(max(current['score'], candidate['score']))
        else:
            merged.append(current)
            current = candidate.copy()

    merged.append(current)
    return merged
115
+
116
# --- Model loading (cached by @st.cache_resource, so only slow once) ---
with st.spinner('Načítám model...'):
    try:
        nlp_pipeline = load_model()
    except Exception as e:
        # Without a model the rest of the app is useless -> halt the script.
        st.error(f"Chyba: {e}")
        st.stop()
123
+
124
# --- UI layout: left column = data input, right column = results ---
col1, col2 = st.columns([1, 2])

with col1:
    st.subheader("📂 Vstup dat")
    uploaded_file = st.file_uploader("Nahrajte PDF", type=["pdf"])
    # Manual entry is disabled as soon as a PDF is uploaded, so exactly
    # one input source is active at a time.
    manual_text = st.text_area("Vložte text:", height=300, disabled=(uploaded_file is not None))

    # An uploaded PDF takes precedence over pasted text.
    text_to_analyze = ""
    if uploaded_file:
        with st.spinner("Čtu PDF..."):
            text_to_analyze = extract_text_from_pdf(uploaded_file)
        if text_to_analyze: st.success(f"PDF načteno: {len(text_to_analyze)} znaků.")
    else:
        text_to_analyze = manual_text

    analyze_button = st.button("Analyzovat", type="primary")

    if torch.cuda.is_available():
        st.caption(f"🚀 GPU Akcelerace aktivní: {torch.cuda.get_device_name(0)}")
144
+
145
# --- Analysis (right column) ---
with col2:
    if analyze_button and text_to_analyze.strip():
        progress_bar = st.progress(0, text="Zahajuji analýzu...")

        try:
            # 1. Batched NER inference over the whole document.
            progress_bar.progress(20, text="Běží AI model (Batch Processing)...")
            raw_results = analyze_long_text_batched(nlp_pipeline, text_to_analyze, batch_size=8)

            # 2. Re-join entity fragments split across tokens/chunks.
            progress_bar.progress(80, text="Čištění výsledků...")
            results = merge_close_entities(raw_results, text_to_analyze)

            progress_bar.progress(100, text="Hotovo!")
            progress_bar.empty()

            if not results:
                st.info("Nic nenalezeno.")
            else:
                st.subheader("📝 Výsledky")

                # --- Highlighted-text preview ---
                # Only the first display_limit characters are rendered as
                # HTML; the full entity list is in the table below.
                display_limit = 5000
                st.caption(f"🎨 Náhled barevného textu (prvních {display_limit} znaků):")

                # Entity-type -> highlight colour. Hoisted out of the loop:
                # the dict is loop-invariant, so there is no reason to
                # rebuild it for every entity (fix: was recreated per
                # iteration).
                color_map = {
                    "MALWARE": "#ff4b4b", "ACTOR": "#ffa421", "THREAT-ACTOR": "#ffa421",
                    "TOOL": "#1c83e1", "MITRE-TACTIC": "#800080", "INDICATOR": "#21c354",
                    "FILEPATH": "#6c757d", "DOMAIN": "#21c354", "IP": "#21c354"
                }

                visible_results = [r for r in results if r['end'] < display_limit]
                html_string = "<div style='line-height: 2.0; font-family: sans-serif;'>"
                last_idx = 0

                for entity in visible_results:
                    start = entity['start']
                    end = entity['end']
                    label = entity['entity_group']
                    word = text_to_analyze[start:end]

                    # Plain text between the previous entity and this one.
                    html_string += text_to_analyze[last_idx:start].replace("\n", "<br>")

                    color = color_map.get(label, "#6c757d")

                    # NOTE(review): `word` is inserted into the HTML
                    # unescaped; a '<' in the source document could break
                    # the markup — consider html.escape() here.
                    html_string += f"<mark style='background-color: {color}; color: white; border-radius: 4px; padding: 2px 4px;'>{word} <sub style='font-size: 0.6em'>{label}</sub></mark>"
                    last_idx = end

                # Remaining plain text up to the preview limit.
                html_string += text_to_analyze[last_idx:display_limit].replace("\n", "<br>")
                if len(text_to_analyze) > display_limit:
                    html_string += "<br><br><i>... (zbytek textu je v tabulce níže) ...</i>"
                html_string += "</div>"

                with st.expander("Rozbalit barevný náhled", expanded=True):
                    st.markdown(html_string, unsafe_allow_html=True)

                st.divider()

                # --- Table and export ---
                st.subheader("📊 Kompletní přehled nalezených entit")

                unique_entities = {}   # (entity text, type) -> best confidence
                full_export_data = []  # every occurrence, for the JSON export

                for res in results:
                    raw_word = text_to_analyze[res['start']:res['end']]
                    # Trim punctuation the tokenizer tends to drag along.
                    clean_word = raw_word.strip(" .,;:)('\"")
                    if len(clean_word) < 2: continue

                    score_float = float(res['score'])

                    # 1. Deduplicate: keep the best score per (word, type).
                    key = (clean_word, res['entity_group'])
                    if key not in unique_entities:
                        unique_entities[key] = score_float
                    else:
                        unique_entities[key] = max(unique_entities[key], score_float)

                    # 2. Full occurrence record for the JSON export.
                    full_export_data.append({
                        "Entity": clean_word,
                        "Type": res['entity_group'],
                        "Confidence": score_float,
                        "Start_Char": int(res['start']),
                        "End_Char": int(res['end'])
                    })

                # Unique-entity table, sorted for stable presentation.
                table_data = [
                    {"Entity": k[0], "Type": k[1], "Confidence": v}
                    for k, v in unique_entities.items()
                ]
                df_unique = pd.DataFrame(table_data).sort_values(by=["Type", "Entity"])

                # Display copy only: df_unique stays numeric for the CSV export.
                df_display = df_unique.copy()
                df_display["Confidence"] = df_display["Confidence"].apply(lambda x: f"{x:.2%}")
                st.dataframe(df_display, use_container_width=True)

                # Export buttons (CSV of unique entities, JSON of all hits).
                col_exp1, col_exp2 = st.columns(2)

                with col_exp1:
                    csv = df_unique.to_csv(index=False).encode('utf-8')
                    st.download_button(
                        label="📥 Stáhnout CSV (Excel)",
                        data=csv,
                        file_name='cti_analyza.csv',
                        mime='text/csv',
                    )

                with col_exp2:
                    json_str = json.dumps(full_export_data, indent=4)
                    st.download_button(
                        label="📥 Stáhnout JSON",
                        data=json_str,
                        file_name='cti_analyza_full.json',
                        mime='application/json',
                    )

        except Exception as e:
            st.error(f"Chyba při analýze: {e}")