# cti-ner-modernbert-gui / src/streamlit_app.py
# FilipL009 — Update src/streamlit_app.py (commit 7013831, verified)
# Standard library
import html
import json
import os

# Third-party
import pandas as pd
import pypdf
import streamlit as st
import torch
import torch._dynamo
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

# --- Windows fix (harmless on Linux): suppress torch.compile/dynamo errors ---
torch._dynamo.config.suppress_errors = True
# Page configuration — must be the first Streamlit call in the script.
st.set_page_config(page_title="CTI NER Analyzer", page_icon="🛡️", layout="wide")
st.title("🛡️ CTI NER Analyzer")
st.markdown("Detekce entit v textu pomocí modelu **attack-vector/SecureModernBERT-NER**.")
# --- Funkce ---
@st.cache_resource
def load_model():
    """Build and cache the NER pipeline, pinned to CPU.

    Returns a Hugging Face token-classification pipeline for the
    attack-vector/SecureModernBERT-NER checkpoint, using simple entity
    aggregation. Cached via st.cache_resource so the model is loaded
    only once per process.
    """
    checkpoint = "attack-vector/SecureModernBERT-NER"

    # Free-tier Spaces have no GPU: device -1 forces CPU execution.
    cpu_device = -1

    tok = AutoTokenizer.from_pretrained(checkpoint)

    # Eager attention + full float32 precision avoid crashes on CPU
    # caused by GPU-oriented optimisations (e.g. Flash Attention, fp16).
    mdl = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        attn_implementation="eager",
        torch_dtype=torch.float32,
    )

    return pipeline(
        "ner",
        model=mdl,
        tokenizer=tok,
        aggregation_strategy="simple",
        device=cpu_device,
    )
def extract_text_from_pdf(uploaded_file):
    """Return the concatenated text of every page of an uploaded PDF.

    Pages that yield no extractable text are skipped; each extracted
    page is followed by a blank line. On any read/parse failure an
    error is shown in the UI and an empty string is returned.
    """
    try:
        reader = pypdf.PdfReader(uploaded_file)
        pieces = []
        for page in reader.pages:
            content = page.extract_text()
            if content:
                pieces.append(content + "\n\n")
        return "".join(pieces)
    except Exception as e:
        st.error(f"Chyba při čtení PDF: {e}")
        return ""
def analyze_long_text_batched(pipeline, text, chunk_size=3000, batch_size=1):
    """Run a NER pipeline over arbitrarily long text in fixed-size chunks.

    The text is split into windows of ``chunk_size`` characters,
    whitespace-only windows are dropped, the remaining windows are fed
    to the pipeline, and every entity's ``start``/``end`` offsets are
    shifted back into the coordinate system of the full ``text``.

    Args:
        pipeline: callable NER pipeline; invoked as
            ``pipeline(chunks, batch_size=batch_size)`` and expected to
            yield one list of entity dicts per chunk.
        text: the full document text.
        chunk_size: window length in characters — small values keep
            memory usage low on CPU-only hosts.
        batch_size: batch size passed through to the pipeline;
            1 is the safest choice on CPU.

    Returns:
        A flat list of entity dicts with offsets rebased onto ``text``.
    """
    chunks = []
    offsets = []
    for pos in range(0, len(text), chunk_size):
        window = text[pos:pos + chunk_size]
        if window.strip():
            chunks.append(window)
            offsets.append(pos)

    # FIX: empty or whitespace-only input leaves no chunks — return early
    # instead of invoking the pipeline on an empty batch.
    if not chunks:
        return []

    results = []
    for idx, chunk_entities in enumerate(pipeline(chunks, batch_size=batch_size)):
        base = offsets[idx]
        for entity in chunk_entities:
            # Rebase chunk-local offsets onto the full document.
            entity['start'] += base
            entity['end'] += base
            results.append(entity)
    return results
def merge_close_entities(results, original_text, max_char_distance=2):
    """Fuse consecutive entities of the same type separated by a tiny gap.

    Two neighbouring entities are merged when they share an
    ``entity_group``, the text between them is at most
    ``max_char_distance`` characters long, and that gap contains no
    ``'.'`` (so sentence boundaries never merge). The merged span keeps
    the higher of the two confidence scores. Input is assumed to be in
    document order; the input dicts are not mutated.
    """
    if not results:
        return []

    merged = []
    pending = results[0].copy()
    for candidate in results[1:]:
        lo = pending['end']
        hi = candidate['start']
        # Overlapping spans: clamp so the gap is treated as empty.
        lo = min(lo, hi)
        gap = original_text[lo:hi]

        same_type = pending['entity_group'] == candidate['entity_group']
        if same_type and len(gap) <= max_char_distance and "." not in gap:
            # Extend the pending span and keep the best score.
            pending['end'] = candidate['end']
            pending['score'] = float(max(pending['score'], candidate['score']))
        else:
            merged.append(pending)
            pending = candidate.copy()

    merged.append(pending)
    return merged
# --- Model loading (cached across reruns by st.cache_resource) ---
with st.spinner('Načítám model (může trvat minutu)...'):
    try:
        nlp_pipeline = load_model()
    except Exception as e:
        # A failed download/load is fatal for this app: report it in the
        # UI and halt the script run.
        st.error(f"Chyba při načítání modelu: {e}")
        st.stop()
# --- UI: left column collects the input (PDF upload or pasted text) ---
col1, col2 = st.columns([1, 2])
with col1:
    st.subheader("📂 Vstup dat")
    uploaded_file = st.file_uploader("Nahrajte PDF", type=["pdf"])
    # Manual entry is disabled while a PDF is uploaded — the PDF wins.
    manual_text = st.text_area("Vložte text:", height=300, disabled=(uploaded_file is not None))
    text_to_analyze = ""
    if uploaded_file:
        with st.spinner("Čtu PDF..."):
            text_to_analyze = extract_text_from_pdf(uploaded_file)
        # extract_text_from_pdf returns "" on failure, so a truthy value
        # means the PDF was read successfully.
        if text_to_analyze: st.success(f"PDF načteno: {len(text_to_analyze)} znaků.")
    else:
        text_to_analyze = manual_text
    analyze_button = st.button("Analyzovat", type="primary")
# --- Analysis: right column runs the model and renders the results ---
with col2:
    if analyze_button and text_to_analyze.strip():
        progress_bar = st.progress(0, text="Zahajuji analýzu...")
        try:
            # 1. Run the NER model chunk-by-chunk over the whole text.
            #    batch_size=1 for CPU stability on the free tier.
            progress_bar.progress(10, text="Běží AI model (bude to chvíli trvat)...")
            raw_results = analyze_long_text_batched(nlp_pipeline, text_to_analyze, batch_size=1)

            # 2. Fuse adjacent fragments of the same entity.
            progress_bar.progress(90, text="Čištění výsledků...")
            results = merge_close_entities(raw_results, text_to_analyze)

            progress_bar.progress(100, text="Hotovo!")
            progress_bar.empty()

            if not results:
                st.info("Nic nenalezeno.")
            else:
                st.subheader("📝 Výsledky")

                # --- Highlighted-text preview (first display_limit chars) ---
                display_limit = 5000
                st.caption(f"🎨 Náhled barevného textu (prvních {display_limit} znaků):")
                visible_results = [r for r in results if r['end'] < display_limit]

                # FIX: the colour palette is loop-invariant — build it once
                # instead of re-creating the dict for every entity.
                color_map = {
                    "MALWARE": "#ff4b4b", "ACTOR": "#ffa421", "THREAT-ACTOR": "#ffa421",
                    "TOOL": "#1c83e1", "MITRE-TACTIC": "#800080", "INDICATOR": "#21c354",
                    "FILEPATH": "#6c757d", "DOMAIN": "#21c354", "IP": "#21c354"
                }

                html_string = "<div style='line-height: 2.0; font-family: sans-serif;'>"
                last_idx = 0
                for entity in visible_results:
                    start = entity['start']
                    end = entity['end']
                    label = entity['entity_group']
                    # FIX: escape document text before splicing it into raw
                    # HTML — '<', '>' or '&' in a report would otherwise break
                    # the markup (and allow markup injection, since the result
                    # is rendered with unsafe_allow_html=True).
                    word = html.escape(text_to_analyze[start:end])
                    html_string += html.escape(text_to_analyze[last_idx:start]).replace("\n", "<br>")
                    color = color_map.get(label, "#6c757d")
                    html_string += f"<mark style='background-color: {color}; color: white; border-radius: 4px; padding: 2px 4px;'>{word} <sub style='font-size: 0.6em'>{label}</sub></mark>"
                    last_idx = end
                html_string += html.escape(text_to_analyze[last_idx:display_limit]).replace("\n", "<br>")
                if len(text_to_analyze) > display_limit:
                    html_string += "<br><br><i>... (zbytek textu je v tabulce níže) ...</i>"
                html_string += "</div>"
                with st.expander("Rozbalit barevný náhled", expanded=True):
                    st.markdown(html_string, unsafe_allow_html=True)

                st.divider()

                # --- Entity table: deduplicated, best confidence per entity ---
                st.subheader("📊 Kompletní přehled nalezených entit")
                unique_entities = {}   # (text, type) -> best confidence seen
                full_export_data = []  # every occurrence, for the JSON export
                for res in results:
                    raw_word = text_to_analyze[res['start']:res['end']]
                    clean_word = raw_word.strip(" .,;:)('\"")
                    # Skip fragments too short to be meaningful entities.
                    if len(clean_word) < 2: continue
                    score_float = float(res['score'])
                    key = (clean_word, res['entity_group'])
                    if key not in unique_entities:
                        unique_entities[key] = score_float
                    else:
                        unique_entities[key] = max(unique_entities[key], score_float)
                    full_export_data.append({
                        "Entity": clean_word,
                        "Type": res['entity_group'],
                        "Confidence": score_float,
                        "Start_Char": int(res['start']),
                        "End_Char": int(res['end'])
                    })

                table_data = [
                    {"Entity": k[0], "Type": k[1], "Confidence": v}
                    for k, v in unique_entities.items()
                ]
                df_unique = pd.DataFrame(table_data).sort_values(by=["Type", "Entity"])
                df_display = df_unique.copy()
                df_display["Confidence"] = df_display["Confidence"].apply(lambda x: f"{x:.2%}")
                st.dataframe(df_display, use_container_width=True)

                # --- Export buttons (CSV = deduplicated table, JSON = full) ---
                col_exp1, col_exp2 = st.columns(2)
                with col_exp1:
                    csv = df_unique.to_csv(index=False).encode('utf-8')
                    st.download_button(
                        label="📥 Stáhnout CSV",
                        data=csv,
                        file_name='cti_analyza.csv',
                        mime='text/csv',
                    )
                with col_exp2:
                    json_str = json.dumps(full_export_data, indent=4)
                    st.download_button(
                        label="📥 Stáhnout JSON",
                        data=json_str,
                        file_name='cti_analyza_full.json',
                        mime='application/json',
                    )
        except Exception as e:
            # Top-level boundary: surface the error in the UI instead of crashing.
            st.error(f"Chyba při analýze: {e}")