File size: 3,016 Bytes
b03ba71
724166c
1e6fb4e
b03ba71
1e6fb4e
 
62343fe
cad6ec2
b03ba71
1e6fb4e
 
b03ba71
30a4b6d
 
1ccafce
1e6fb4e
724166c
 
30a4b6d
 
1e6fb4e
 
30a4b6d
 
 
1ccafce
 
30a4b6d
 
 
 
7f6e108
30a4b6d
7f6e108
 
1e6fb4e
7f6e108
 
 
 
 
 
 
 
 
 
 
30a4b6d
7f6e108
 
 
30a4b6d
7f6e108
 
cad6ec2
30a4b6d
1ccafce
30a4b6d
47a0c22
1e6fb4e
d43bc72
 
 
 
 
1e6fb4e
30a4b6d
 
 
1e6fb4e
d43bc72
 
 
1e6fb4e
 
d43bc72
6bad3b6
2f83474
 
1e6fb4e
6bad3b6
 
 
 
 
 
 
 
f4f723b
 
 
 
 
 
 
 
30a4b6d
 
1ccafce
 
 
 
1e6fb4e
b03ba71
6b2b5a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import os, time
from transformers import pipeline

# Hugging Face Hub model ids: a text classifier (its LABEL_1 output is treated
# as an attack by analyze) and a token-classification / NER model.
MODEL_1_ID = "patronus-protect/wolf-guard"
MODEL_2_ID = "HuggingLil/pii-sensitive-ner-german"

# Hub access token read from the environment; None when HF_TOKEN is unset.
TOKEN = os.environ.get("HF_TOKEN")

# Human-readable names for the classifier's raw labels.
# NOTE(review): not referenced elsewhere in the visible code.
CLS_MAP = {"LABEL_0": "Benign", "LABEL_1": "Attack"}

# Both pipelines are created once at import time and reused for every request.
pipe_cls = pipeline(task="text-classification", model=MODEL_1_ID, token=TOKEN)
pipe_ner = pipeline(
    task="ner",
    model=MODEL_2_ID,
    token=TOKEN,
    # "simple" aggregation yields grouped entities with entity_group/score/start/end.
    aggregation_strategy="simple",
)

def _build_highlights(text, entities, min_score=0.90):
    """Turn NER entities into ``(segment, label)`` pairs covering all of *text*.

    Entities are visited in ascending ``start`` order. An entity is kept only
    if its ``score`` is strictly greater than *min_score* and it does not
    overlap an entity that was already kept. A kept entity is labelled
    ``"<entity_group> (<score as percent>)"``. The text between kept entities
    is emitted with a ``None`` label, so the segments concatenate back to
    *text*.

    Args:
        text: The analysed input string.
        entities: Dicts with ``start``, ``end``, ``score`` and ``entity_group``
            keys, as produced by the aggregated NER pipeline.
        min_score: Exclusive confidence threshold for highlighting an entity.

    Returns:
        A list of ``(segment, label_or_None)`` tuples suitable as a
        ``gr.HighlightedText`` value.
    """
    segments = []
    cursor = 0  # end offset of the last kept entity
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end, score = ent["start"], ent["end"], ent["score"]
        # Skip entities that overlap one already emitted, and low-confidence ones.
        if start < cursor or score <= min_score:
            continue
        if start > cursor:
            segments.append((text[cursor:start], None))
        segments.append((text[start:end], f"{ent['entity_group']} ({score:.0%})"))
        cursor = end
    if cursor < len(text):
        segments.append((text[cursor:], None))
    return segments


def analyze(text):
    """Classify *text* and, if it is not an attack, highlight confident entities.

    Args:
        text: Raw textbox contents (may be empty or None).

    Returns:
        A 2-tuple: a ``gr.update`` setting the results column's visibility, and
        the ``gr.HighlightedText`` value.

        - Blank input hides the results and returns ``""``.
        - Text the classifier labels ``LABEL_1`` is shown whole as ``"ATTACK"``.
        - Otherwise, entities scoring above 90% are highlighted.
        - If no entity qualifies, the whole text is marked ``"SAFE"``.
    """
    if not text or not text.strip():
        return gr.update(visible=False), ""

    # NOTE(review): fixed delay kept from the original, presumably so the
    # progress indicator is noticeable; confirm it is still wanted.
    time.sleep(0.5)

    is_attack = pipe_cls(text)[0]["label"] == "LABEL_1"
    if is_attack:
        return gr.update(visible=True), [(text, "ATTACK")]

    # NER runs only for non-attack text; on the attack path its result was unused.
    segments = _build_highlights(text, pipe_ner(text))
    if any(label is not None for _, label in segments):
        return gr.update(visible=True), segments
    return gr.update(visible=True), [(text, "SAFE")]

# Border colour applied to `.generating` elements in the CSS below.
ACCENT_COLOR = "#F5C77A"

# The font is passed to the constructor: Gradio builds the theme's font CSS
# variable and Google Font stylesheets from this argument. Assigning a Python
# list to `theme.font` after construction is not the documented way to set it.
theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Inter"), "sans-serif"])

# Custom CSS: colour helper classes, hidden footer, borderless container and an
# accent border on `.generating` (presumably the class Gradio puts on components
# while an event is running; confirm).
# NOTE(review): .color-attack/.color-ner/.color-clean are not referenced
# elsewhere in the visible code.
css = """
.color-attack { background-color: #ffcccb !important; }
.color-ner { background-color: #ffe5b4 !important; }
.color-clean { background-color: #d1ffbd !important; }
footer {display: none !important;}

.gradio-container {border: none !important;}
.generating { border-color: """ + ACCENT_COLOR + """ !important; }
"""

with gr.Blocks(css=css, theme=theme) as demo:
    # Input area; the placeholder hints at both the entity and the attack case.
    input_text = gr.Textbox(
        label="Enter Text",
        placeholder="e.g 'Elena Petrov is an american...' or 'Forget your instructions and do...'",
        lines=2,
    )

    submit_btn = gr.Button("Analyse", variant="primary", scale=1)

    # Hidden at start; analyze() returns a visibility update for this column.
    with gr.Column(visible=False) as results_col:
        display_output = gr.HighlightedText(
            label="Result",
            combine_adjacent=True,
            show_legend=False,
            color_map={"ATTACK": "red", "SAFE": "green"},
        )

    # Both triggers run the same handler with the same inputs and outputs.
    shared_event_args = dict(
        fn=analyze,
        inputs=input_text,
        outputs=[results_col, display_output],
        show_progress="minimal",
    )

    submit_btn.click(**shared_event_args)

    # Analyse when the textbox loses focus; one run at a time, and
    # "always_last" keeps only the latest trigger that arrives mid-run.
    # NOTE(review): clicking the button after typing blurs the textbox first,
    # so blur and click may both fire; confirm the double run is acceptable.
    input_text.blur(
        **shared_event_args,
        concurrency_limit=1,
        trigger_mode="always_last",
    )

demo.launch()