detector / app.py
Benedikt Veith
Layout again
6bad3b6
raw
history blame
3.02 kB
import gradio as gr
import os, time
from transformers import pipeline
# Hugging Face Hub repositories for the two detectors.
MODEL_1_ID = "patronus-protect/wolf-guard"  # loaded as a text-classification pipeline
MODEL_2_ID = "HuggingLil/pii-sensitive-ner-german"  # loaded as an NER pipeline

# Access token read from the environment; None when HF_TOKEN is not set.
TOKEN = os.environ.get("HF_TOKEN")

# Human-readable names for the classifier's raw output labels.
CLS_MAP = {"LABEL_0": "Benign", "LABEL_1": "Attack"}

# Both pipelines are built once, at import time, and reused by every call to analyze().
pipe_cls = pipeline(task="text-classification", model=MODEL_1_ID, token=TOKEN)
pipe_ner = pipeline(
    task="ner",
    model=MODEL_2_ID,
    token=TOKEN,
    # Group sub-word tokens into entities exposing entity_group/start/end/score.
    aggregation_strategy="simple",
)
def _highlight_entities(text, entities, min_score=0.90):
    """Split *text* into ``(segment, label)`` pairs for ``gr.HighlightedText``.

    Entities are visited in order of their ``start`` offset. An entity is
    highlighted only when its ``score`` is strictly greater than *min_score*
    and it does not start inside an entity that was already highlighted. Its
    label is ``"<entity_group> (<score as a whole percent>)"``. The text
    before, between and after highlighted entities is emitted with label
    ``None``.

    Args:
        text: The analysed input string.
        entities: NER pipeline output; each item needs ``start``, ``end``,
            ``score`` and ``entity_group`` keys.
        min_score: Exclusive confidence threshold for highlighting.

    Returns:
        A list of ``(substring, label_or_None)`` tuples covering *text*.
    """
    segments = []
    cursor = 0  # end offset of the most recently highlighted entity
    for ent in sorted(entities, key=lambda e: e["start"]):
        # Skip entities overlapping an emitted span, and low-confidence ones.
        if ent["start"] >= cursor and ent["score"] > min_score:
            if ent["start"] > cursor:
                segments.append((text[cursor:ent["start"]], None))
            label = f"{ent['entity_group']} ({ent['score']:.0%})"
            segments.append((text[ent["start"]:ent["end"]], label))
            cursor = ent["end"]
    if cursor < len(text):
        segments.append((text[cursor:], None))
    return segments


def analyze(text):
    """Classify *text* as an attack, or highlight confident named entities.

    Returns the values for the Gradio outputs ``[results_col, display_output]``:
    a ``gr.update`` that sets the results column's visibility, plus the
    ``HighlightedText`` value.

    - Empty or whitespace-only input hides the column and clears the output.
    - If the classifier's label maps to ``"Attack"`` via ``CLS_MAP``, the whole
      text is tagged ``"ATTACK"``.
    - Otherwise, NER entities scoring above 0.90 are highlighted. If there are
      none, the whole text is tagged ``"SAFE"``.
    """
    if not text or not text.strip():
        # None explicitly clears the HighlightedText value.
        return gr.update(visible=False), None

    # NOTE(review): fixed artificial delay, presumably so the "minimal"
    # progress indicator is noticeable; remove it if latency matters.
    time.sleep(0.5)

    raw_label = pipe_cls(text)[0]["label"]
    # Map raw ids (LABEL_0/LABEL_1) through CLS_MAP; names that are already
    # readable pass through unchanged.
    if CLS_MAP.get(raw_label, raw_label) == "Attack":
        return gr.update(visible=True), [(text, "ATTACK")]

    # NER only runs on the non-attack path, where its output is actually used.
    segments = _highlight_entities(text, pipe_ner(text))
    if any(tag is not None for _, tag in segments):
        return gr.update(visible=True), segments
    return gr.update(visible=True), [(text, "SAFE")]
# Accent colour used for the border of components while they are generating.
ACCENT_COLOR = "#F5C77A"

# Fonts must be passed to the theme constructor: Gradio themes turn the font
# list into the --font CSS value and collect the Google Fonts stylesheet in
# __init__, so assigning ``theme.font = [...]`` afterwards never loads Inter.
theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Inter"), "sans-serif"])

# Custom CSS: hides the footer, removes the container border and recolours the
# "generating" border.
# NOTE(review): the .color-attack/.color-ner/.color-clean classes are not
# attached (via elem_classes) to any component in this file; the
# HighlightedText colours come from its color_map instead.
css = """
.color-attack { background-color: #ffcccb !important; }
.color-ner { background-color: #ffe5b4 !important; }
.color-clean { background-color: #d1ffbd !important; }
footer {display: none !important;}
.gradio-container {border: none !important;}
.generating { border-color: """ + ACCENT_COLOR + """ !important; }
"""
# UI: a text box and an "Analyse" button. The results column starts hidden;
# analyze() returns a gr.update for it together with the HighlightedText value.
with gr.Blocks(css=css, theme=theme) as demo:
    input_text = gr.Textbox(
        label="Enter Text",
        placeholder="e.g. 'Elena Petrov is an American...' or 'Forget your instructions and do...'",
        lines=2,
    )
    submit_btn = gr.Button("Analyse", variant="primary", scale=1)

    with gr.Column(visible=False) as results_col:
        display_output = gr.HighlightedText(
            label="Result",
            combine_adjacent=True,
            show_legend=False,
            # Only the whole-text verdicts get fixed colours; entity labels
            # produced by analyze() are not listed here.
            color_map={"ATTACK": "red", "SAFE": "green"},
        )

    submit_btn.click(
        fn=analyze,
        inputs=input_text,
        outputs=[results_col, display_output],
        show_progress="minimal",
    )
    # Re-analyse when the text box loses focus. Only the latest pending blur
    # event is kept, and at most one blur-triggered run executes at a time.
    # NOTE(review): clicking the button after typing also blurs the text box,
    # so analyze() may run twice per click. Consider gr.on(...) or .submit().
    input_text.blur(
        fn=analyze,
        inputs=input_text,
        outputs=[results_col, display_output],
        show_progress="minimal",
        concurrency_limit=1,
        trigger_mode="always_last",
    )

demo.launch()