File size: 1,904 Bytes
e54e4ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st 

import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download 


import os



# Hugging Face repository and file holding the ONNX classifier.
MODEL_REPO_ID = "bakhil-aissa/anti_prompt_injection"
MODEL_FILENAME = "model_f16.onnx"


@st.cache_resource
def _load_tokenizer():
    """Load the ModernBERT-large tokenizer.

    Streamlit re-executes this whole script on every widget interaction;
    caching avoids re-loading the tokenizer on each rerun.
    """
    return AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')


@st.cache_resource
def _load_session(model_path):
    """Create a CPU ONNX Runtime session for *model_path*, cached per path."""
    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


tokenizer = _load_tokenizer()

# Reuse a copy of the model in the working directory if one exists; otherwise
# download it into the working directory so this check succeeds next time.
if os.path.exists(MODEL_FILENAME):
    st.write("Model already downloaded.")
    # Previously model_path was left undefined on this branch (NameError below).
    model_path = MODEL_FILENAME
else:
    st.write("Downloading model...")
    # `local_dir="."` writes the file to ./model_f16.onnx. The former
    # `local_dir_use_symlinks=False` is deprecated and only applied when
    # `local_dir` is given, so it has been dropped.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        local_dir=".",
    )

st.title("Anti Prompt Injection Detection")

# Load the ONNX model (cached across Streamlit reruns).
sess = _load_session(model_path)
def softmax(logits):
    """Return softmax probabilities along the last axis of *logits*.

    The values are promoted to float64 and shifted by their per-row maximum
    before exponentiating. Without this, large logits (or float16 model
    outputs, whose exp overflows above ~11) become ``inf`` and the division
    yields ``nan`` probabilities.
    """
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max(axis=-1, keepdims=True)
    exp = np.exp(z)
    return exp / exp.sum(axis=-1, keepdims=True)


def predict(text, max_length=2048):
    """Score *text* with the ONNX prompt-injection classifier.

    Args:
        text: Input string to classify.
        max_length: Maximum number of tokens; longer inputs are truncated
            by the tokenizer.

    Returns:
        NumPy array of class probabilities, expected shape
        ``(1, num_classes)`` since a single text is batched.
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=max_length)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]
    return softmax(logits)

st.subheader("Enter your text to check for prompt injection:")
text_input = st.text_area("Text Input", height=200)
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
if st.button("Check"):
    # Whitespace-only input carries nothing to classify; treat it as empty.
    if not text_input.strip():
        st.warning("Please enter some text to check.")
    else:
        try:
            with st.spinner("Processing..."):
                probs = predict(text_input)
            # NOTE(review): assumes class index 1 is the "jailbreak" label —
            # confirm against the model's label mapping.
            jailbreak_prob = float(probs[0][1])  # index into batch
            is_jailbreak = jailbreak_prob >= confidence_threshold

            # A detected injection should not be shown in a green "success" box.
            report = st.error if is_jailbreak else st.success
            report(f"Is Jailbreak: {is_jailbreak}")
            st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
        except Exception as e:  # UI boundary: surface any failure to the user
            st.error(f"Error: {e}")