# jonmabe's picture
# Initial Gradio demo upload
# bd30e03 verified
"""
Privacy Classifier Demo - Classifies prompts as KEEP_LOCAL vs ALLOW_CLOUD
"""
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Model configuration
MODEL_ID = "jonmabe/privacy-classifier-electra"
# Model labels: 0=safe (ALLOW_CLOUD), 1=sensitive (KEEP_LOCAL)
# The order of this list must match the model's output-logit indices.
LABELS = ["ALLOW_CLOUD", "KEEP_LOCAL"] # index 0=safe, index 1=sensitive
# Load model and tokenizer once at module import time; on first run this
# downloads the weights from the Hugging Face Hub.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only: disables dropout / training-mode layers
# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded on {device}")
def classify_prompt(text: str) -> tuple[str, dict]:
    """Classify *text* as KEEP_LOCAL or ALLOW_CLOUD.

    Returns:
    - A human-readable result string ("LABEL (xx.x% confidence)").
    - A per-class probability mapping for the gr.Label component.
    """
    # Guard clause: empty / whitespace-only input gets a prompt back.
    if not text.strip():
        return "Please enter a prompt to classify.", {}

    # Encode the prompt for the model (clipped to the 512-token limit).
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Forward pass without gradient tracking; softmax over the two classes,
    # [0] drops the batch dimension.
    with torch.no_grad():
        class_probs = torch.softmax(model(**encoded).logits, dim=-1)[0]

    # Winning class and its probability.
    best = class_probs.argmax().item()
    pred_label = LABELS[best]
    confidence = class_probs[best].item()

    # gr.Label expects {label: probability}.
    prob_dict = dict(zip(LABELS, (float(p) for p in class_probs)))
    return f"{pred_label} ({confidence:.1%} confidence)", prob_dict
def get_color_for_label(label: str) -> str:
    """Map a classification result string to a display color.

    Red flags sensitive (KEEP_LOCAL) results, green marks safe
    (ALLOW_CLOUD) ones, and gray is the neutral fallback.
    """
    # Checked in order: KEEP_LOCAL takes precedence if both markers appear.
    for marker, color in (("KEEP_LOCAL", "red"), ("ALLOW_CLOUD", "green")):
        if marker in label:
            return color
    return "gray"
# Example prompts
# Each inner list is one row for gr.Examples (single input column). The rows
# alternate between general-knowledge prompts and prompts that contain
# PII/secrets (SSN, password, card number, medical, address, salary).
EXAMPLES = [
["What is the capital of France?"],
["My social security number is 123-45-6789, can you help me file taxes?"],
["Write me a poem about the ocean."],
["Here's my password: hunter2, please remember it."],
["Explain how photosynthesis works."],
["My credit card number is 4111-1111-1111-1111, check if it's valid."],
["What are some good restaurants in Seattle?"],
["My medical records show I have diabetes. What should I eat?"],
["Translate 'hello world' to Spanish."],
["My home address is 123 Main St, Anytown USA. Send me a pizza."],
["How do I sort a list in Python?"],
["My employee ID is E12345 and my salary is $85,000."],
]
# Custom CSS for styling
# NOTE(review): the .keep-local / .allow-cloud / .result-box classes are
# defined here but no component in the layout below attaches them via
# elem_classes — presumably left in for future use; confirm before removing.
css = """
.keep-local {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a5a 100%) !important;
color: white !important;
font-weight: bold !important;
}
.allow-cloud {
background: linear-gradient(135deg, #51cf66 0%, #40c057 100%) !important;
color: white !important;
font-weight: bold !important;
}
.result-box {
font-size: 1.2em;
padding: 20px;
border-radius: 10px;
text-align: center;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, title="Privacy Classifier") as demo:
    # Header / instructions shown above the controls.
    gr.Markdown("""
# πŸ”’ Privacy Classifier
Classify prompts to determine if they contain sensitive information that should stay local
or if they're safe to send to cloud LLM services.
- **πŸ”΄ KEEP_LOCAL**: Contains PII, sensitive data, or private information
- **🟒 ALLOW_CLOUD**: Safe to process with cloud-based AI services
This model helps route requests in privacy-aware AI systems.
""")
    # Two-column layout: input on the left (wider), results on the right.
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type a prompt to classify...",
                lines=3,
            )
            classify_btn = gr.Button("πŸ” Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            result_label = gr.Textbox(
                label="Classification Result",
                interactive=False,
                lines=2,
            )
            # gr.Label renders the probability dict as ranked confidence bars.
            confidence_chart = gr.Label(
                label="Confidence Scores",
                num_top_classes=2,
            )
    gr.Markdown("### πŸ“ Example Prompts")
    # Clickable example rows; caching disabled so each click runs the model live.
    gr.Examples(
        examples=EXAMPLES,
        inputs=input_text,
        outputs=[result_label, confidence_chart],
        fn=classify_prompt,
        cache_examples=False,
    )
    # Event handlers: button click and pressing Enter in the textbox both
    # run the same classifier into the same outputs.
    classify_btn.click(
        fn=classify_prompt,
        inputs=input_text,
        outputs=[result_label, confidence_chart],
    )
    input_text.submit(
        fn=classify_prompt,
        inputs=input_text,
        outputs=[result_label, confidence_chart],
    )
    # Footer: model card link and disclaimer.
    gr.Markdown("""
---
### About This Model
**Model**: [jonmabe/privacy-classifier-electra](https://huggingface.co/jonmabe/privacy-classifier-electra)
This is an ELECTRA-based classifier fine-tuned to detect sensitive information in prompts.
Use cases include:
- Privacy-aware prompt routing
- Data loss prevention for LLM applications
- Compliance with data protection regulations
⚠️ **Disclaimer**: This model is for demonstration purposes. Always verify classifications
for production use cases involving sensitive data.
""")
if __name__ == "__main__":
    # Launch the Gradio server (blocking) when run as a script.
    demo.launch()