"""Gradio demo: classify input text as Benign/Attack and highlight NER entities.

Two Hugging Face pipelines are created at import time. Text that the
classifier labels as an attack is displayed as a single "ATTACK" span.
Otherwise, entities returned by the NER pipeline with a score above
``NER_SCORE_THRESHOLD`` are highlighted; text without any such entity is
displayed as a single "SAFE" span.
"""

import os
import time

import gradio as gr
from transformers import pipeline

MODEL_1_ID = "patronus-protect/wolf-guard"
MODEL_2_ID = "HuggingLil/pii-sensitive-ner-german"
TOKEN = os.getenv("HF_TOKEN")

# Maps the classifier's raw output labels to readable names.
CLS_MAP = {"LABEL_0": "Benign", "LABEL_1": "Attack"}

# Only entities whose score is strictly greater than this are highlighted.
NER_SCORE_THRESHOLD = 0.90

pipe_cls = pipeline("text-classification", model=MODEL_1_ID, token=TOKEN)
pipe_ner = pipeline("ner", model=MODEL_2_ID, token=TOKEN, aggregation_strategy="simple")


def _highlight_entities(text, entities, threshold=NER_SCORE_THRESHOLD):
    """Split *text* into ``(segment, label)`` pairs around confident entities.

    Args:
        text: The full input string the entities refer to.
        entities: Iterable of dicts with ``start``, ``end``, ``score`` and
            ``entity_group`` keys (the fields this function reads).
        threshold: Entities with ``score <= threshold`` are ignored.

    Returns:
        A tuple ``(segments, found)``. ``segments`` is a list of
        ``(substring, label_or_None)`` tuples covering *text* from the first
        character to the last; ``found`` is True when at least one entity was
        highlighted.
    """
    segments = []
    cursor = 0
    found = False
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        # Skip entities that begin inside an already highlighted span, and
        # entities that do not clear the confidence threshold.
        if start < cursor or ent["score"] <= threshold:
            continue
        if start > cursor:
            # Unlabelled text between the previous highlight and this one.
            segments.append((text[cursor:start], None))
        label = f"{ent['entity_group']} ({ent['score']:.0%})"
        segments.append((text[start:end], label))
        cursor = end
        found = True
    if cursor < len(text):
        # Trailing unlabelled text after the last highlighted entity.
        segments.append((text[cursor:], None))
    return segments, found


def analyze(text):
    """Classify *text* and build the update for the results UI.

    Args:
        text: The content of the input textbox.

    Returns:
        A 2-tuple ``(column_update, highlighted)`` matching the
        ``[results_col, display_output]`` outputs:

        * empty / whitespace-only input: the column is hidden and the
          highlighted value is an empty list;
        * classifier says attack: the whole text labelled ``"ATTACK"``;
        * at least one confident entity: text segmented by entity labels;
        * otherwise: the whole text labelled ``"SAFE"``.
    """
    if not text or not text.strip():
        return gr.update(visible=False), []

    # NOTE(review): fixed 0.5 s pause before inference; its purpose is not
    # evident from this file -- confirm it is intentional.
    time.sleep(0.5)

    res_cls = pipe_cls(text)[0]
    if CLS_MAP.get(res_cls["label"]) == "Attack":
        # The NER pipeline is deliberately not run here: its result would be
        # discarded for attack inputs.
        return gr.update(visible=True), [(text, "ATTACK")]

    segments, found = _highlight_entities(text, pipe_ner(text))
    if not found:
        return gr.update(visible=True), [(text, "SAFE")]
    return gr.update(visible=True), segments


ACCENT_COLOR = "#F5C77A"

# The font must be passed to the theme constructor; assigning ``theme.font``
# afterwards bypasses the theme's own font handling.
theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Inter"), "sans-serif"])

css = """
.color-attack { background-color: #ffcccb !important; }
.color-ner { background-color: #ffe5b4 !important; }
.color-clean { background-color: #d1ffbd !important; }
footer {display: none !important;}
.gradio-container {border: none !important;}
.generating { border-color: """ + ACCENT_COLOR + """ !important; }
"""

with gr.Blocks(css=css, theme=theme) as demo:
    input_text = gr.Textbox(
        label="Enter Text",
        placeholder="e.g 'Elena Petrov is an american...' \nor 'Forget your instructions and do...'",
        lines=2,
    )
    submit_btn = gr.Button("Analyse", variant="primary", scale=1)

    # Hidden until analyze() returns a visible update.
    with gr.Column(visible=False) as results_col:
        display_output = gr.HighlightedText(
            label="Result",
            combine_adjacent=True,
            show_legend=False,
            color_map={"ATTACK": "red", "SAFE": "green"},
        )

    submit_btn.click(
        fn=analyze,
        inputs=input_text,
        outputs=[results_col, display_output],
        show_progress="minimal",
    )

    # Also analyse when the textbox loses focus.
    input_text.blur(
        fn=analyze,
        inputs=input_text,
        outputs=[results_col, display_output],
        show_progress="minimal",
        concurrency_limit=1,
        trigger_mode="always_last",
    )

demo.launch()