# Source: Hugging Face Space by bakhil-aissa (commit e54e4ba, "Upload 2 files", 1.9 kB)
# NOTE: the lines above were file-viewer page chrome captured by the scrape,
# collapsed into this comment so the module is valid Python.
import streamlit as st
import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
# --- Model setup: tokenizer + ONNX classifier -------------------------------
# Tokenizer matches the backbone the classifier was fine-tuned from.
tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')

MODEL_FILENAME = "model_f16.onnx"

if os.path.exists(MODEL_FILENAME):
    st.write("Model already downloaded.")
    # BUGFIX: the original only assigned model_path in the else-branch, so a
    # pre-existing local model file caused a NameError at session creation.
    model_path = MODEL_FILENAME
else:
    st.write("Downloading model...")
    # local_dir="." places the file in the working directory so the
    # os.path.exists() check above actually finds it on the next rerun.
    # (The original passed only local_dir_use_symlinks=False, which is
    # deprecated and a no-op without local_dir — the download went to the HF
    # cache and the existence check could never succeed.)
    model_path = hf_hub_download(
        repo_id="bakhil-aissa/anti_prompt_injection",
        filename=MODEL_FILENAME,
        local_dir=".",
    )

st.title("Anti Prompt Injection Detection")

# CPU-only inference session over the downloaded fp16 ONNX graph.
sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Inference helper: tokenize the text and run it through the ONNX classifier
def predict(text, _tokenizer=None, _session=None):
    """Classify *text* with the ONNX model and return class probabilities.

    Args:
        text: Input string to score (truncated to 2048 tokens).
        _tokenizer: Optional tokenizer override; defaults to the module-level
            ``tokenizer``. Exposed mainly for testing.
        _session: Optional ``InferenceSession`` override; defaults to the
            module-level ``sess``. Exposed mainly for testing.

    Returns:
        np.ndarray of shape (1, num_classes) with softmax probabilities.
        NOTE(review): index 1 is treated as the "jailbreak" class by the
        caller — confirm against the model card.
    """
    tok = _tokenizer if _tokenizer is not None else tokenizer
    run = _session if _session is not None else sess

    enc = tok([text], return_tensors="np", truncation=True, max_length=2048)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = run.run(["logits"], inputs)[0]

    # BUGFIX: numerically stable softmax. The original np.exp(logits) can
    # overflow to inf for large logits, yielding nan probabilities;
    # subtracting the row max is mathematically equivalent and safe.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)  # shape (1, num_classes)
# --- UI: text input, threshold slider, and on-click classification ----------
st.subheader("Enter your text to check for prompt injection:")
text_input = st.text_area("Text Input", height=200)
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

if st.button("Check"):
    if not text_input:
        # Nothing to classify — prompt the user instead of running the model.
        st.warning("Please enter some text to check.")
    else:
        try:
            with st.spinner("Processing..."):
                scores = predict(text_input)
                # Class index 1 of the single-item batch is the jailbreak score.
                injection_score = float(scores[0][1])
                flagged = injection_score >= confidence_threshold
                st.success(f"Is Jailbreak: {flagged}")
                st.info(f"Jailbreak Probability: {injection_score:.4f}")
        except Exception as e:
            # Top-level UI boundary: surface any failure to the user.
            st.error(f"Error: {str(e)}")