""" Privacy Classifier Demo - Classifies prompts as KEEP_LOCAL vs ALLOW_CLOUD """ import gradio as gr import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer # Model configuration MODEL_ID = "jonmabe/privacy-classifier-electra" # Model labels: 0=safe (ALLOW_CLOUD), 1=sensitive (KEEP_LOCAL) LABELS = ["ALLOW_CLOUD", "KEEP_LOCAL"] # index 0=safe, index 1=sensitive # Load model and tokenizer print("Loading model...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID) model.eval() # Move to GPU if available device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) print(f"Model loaded on {device}") def classify_prompt(text: str) -> tuple[str, dict]: """ Classify a prompt as KEEP_LOCAL or ALLOW_CLOUD. Returns: - Classification label with confidence - Dictionary of class probabilities for the label component """ if not text.strip(): return "Please enter a prompt to classify.", {} # Tokenize inputs = tokenizer( text, return_tensors="pt", truncation=True, max_length=512, padding=True ) inputs = {k: v.to(device) for k, v in inputs.items()} # Inference with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits probs = torch.softmax(logits, dim=-1)[0] # Get prediction pred_idx = probs.argmax().item() pred_label = LABELS[pred_idx] confidence = probs[pred_idx].item() # Create probability dict for label component prob_dict = {label: float(probs[i]) for i, label in enumerate(LABELS)} return f"{pred_label} ({confidence:.1%} confidence)", prob_dict def get_color_for_label(label: str) -> str: """Return color based on classification.""" if "KEEP_LOCAL" in label: return "red" elif "ALLOW_CLOUD" in label: return "green" return "gray" # Example prompts EXAMPLES = [ ["What is the capital of France?"], ["My social security number is 123-45-6789, can you help me file taxes?"], ["Write me a poem about the ocean."], ["Here's my password: hunter2, please remember it."], ["Explain how photosynthesis works."], ["My credit card number is 4111-1111-1111-1111, check if it's valid."], ["What are some good restaurants in Seattle?"], ["My medical records show I have diabetes. What should I eat?"], ["Translate 'hello world' to Spanish."], ["My home address is 123 Main St, Anytown USA. Send me a pizza."], ["How do I sort a list in Python?"], ["My employee ID is E12345 and my salary is $85,000."], ] # Custom CSS for styling css = """ .keep-local { background: linear-gradient(135deg, #ff6b6b 0%, #ee5a5a 100%) !important; color: white !important; font-weight: bold !important; } .allow-cloud { background: linear-gradient(135deg, #51cf66 0%, #40c057 100%) !important; color: white !important; font-weight: bold !important; } .result-box { font-size: 1.2em; padding: 20px; border-radius: 10px; text-align: center; } """ # Create Gradio interface with gr.Blocks(css=css, title="Privacy Classifier") as demo: gr.Markdown(""" # 🔒 Privacy Classifier Classify prompts to determine if they contain sensitive information that should stay local or if they're safe to send to cloud LLM services. - **🔴 KEEP_LOCAL**: Contains PII, sensitive data, or private information - **🟢 ALLOW_CLOUD**: Safe to process with cloud-based AI services This model helps route requests in privacy-aware AI systems. """) with gr.Row(): with gr.Column(scale=2): input_text = gr.Textbox( label="Enter your prompt", placeholder="Type a prompt to classify...", lines=3, ) classify_btn = gr.Button("🔍 Classify", variant="primary", size="lg") with gr.Column(scale=1): result_label = gr.Textbox( label="Classification Result", interactive=False, lines=2, ) confidence_chart = gr.Label( label="Confidence Scores", num_top_classes=2, ) gr.Markdown("### 📝 Example Prompts") gr.Examples( examples=EXAMPLES, inputs=input_text, outputs=[result_label, confidence_chart], fn=classify_prompt, cache_examples=False, ) # Event handlers classify_btn.click( fn=classify_prompt, inputs=input_text, outputs=[result_label, confidence_chart], ) input_text.submit( fn=classify_prompt, inputs=input_text, outputs=[result_label, confidence_chart], ) gr.Markdown(""" --- ### About This Model **Model**: [jonmabe/privacy-classifier-electra](https://huggingface.co/jonmabe/privacy-classifier-electra) This is an ELECTRA-based classifier fine-tuned to detect sensitive information in prompts. Use cases include: - Privacy-aware prompt routing - Data loss prevention for LLM applications - Compliance with data protection regulations ⚠️ **Disclaimer**: This model is for demonstration purposes. Always verify classifications for production use cases involving sensitive data. """) if __name__ == "__main__": demo.launch()