|
|
--- |
|
|
license: apache-2.0 |
|
|
language: |
|
|
- en |
|
|
library_name: transformers |
|
|
tags: |
|
|
- guardrails |
|
|
- safety |
|
|
- text-classification |
|
|
- roberta |
|
|
- education |
|
|
- code |
|
|
- cs-education |
|
|
- llm-safety |
|
|
- academic-integrity |
|
|
datasets: |
|
|
- md-nishat-008/Do-Not-Code |
|
|
metrics: |
|
|
- f1 |
|
|
- accuracy |
|
|
- precision |
|
|
- recall |
|
|
pipeline_tag: text-classification |
|
|
model-index: |
|
|
- name: PromptShield |
|
|
results: |
|
|
- task: |
|
|
type: text-classification |
|
|
name: Prompt Safety Classification |
|
|
dataset: |
|
|
type: md-nishat-008/Do-Not-Code |
|
|
name: Do Not Code |
|
|
split: test |
|
|
metrics: |
|
|
- type: f1 |
|
|
value: 0.93 |
|
|
name: F1 (Macro) |
|
|
- type: accuracy |
|
|
value: 0.94 |
|
|
name: Accuracy |
|
|
--- |
|
|
|
|
|
# PromptShield |
|
|
|
|
|
<p align="center"> |
|
|
<a href="https://github.com/mraihan-gmu/CodeGuard/tree/main"> |
|
|
<img src="https://img.shields.io/badge/GitHub-Repository-black?style=for-the-badge&logo=github" alt="GitHub"> |
|
|
</a> |
|
|
<a href="https://huggingface.co/datasets/md-nishat-008/Do-Not-Code"> |
|
|
<img src="https://img.shields.io/badge/🤗%20Dataset-Do%20Not%20Code-yellow?style=for-the-badge" alt="Dataset"> |
|
|
</a> |
|
|
<a href="https://aclanthology.org/PLACEHOLDER"> |
|
|
<img src="https://img.shields.io/badge/📄%20Paper-EACL%202026-green?style=for-the-badge" alt="Paper"> |
|
|
</a> |
|
|
</p> |
|
|
|
|
|
**PromptShield** is a lightweight guardrail model for detecting unsafe and irrelevant prompts in Computer Science education settings. It achieves **0.93 F1 score**, outperforming existing guardrails by 30-65%. |
|
|
|
|
|
## Model Description |
|
|
|
|
|
PromptShield is a RoBERTa-base encoder (125M parameters) fine-tuned on the [Do Not Code dataset](https://huggingface.co/datasets/md-nishat-008/Do-Not-Code) for real-time prompt classification in educational AI systems. |
|
|
|
|
|
### Intended Use |
|
|
|
|
|
- **Pre-filtering** user prompts before they reach an AI coding assistant |
|
|
- **Monitoring** interactions in CS education platforms |
|
|
- **Research** on LLM safety in educational contexts |
|
|
|
|
|
### Classification Labels |
|
|
|
|
|
| ID | Label | Description | |
|
|
|----|-------|-------------| |
|
|
| 0 | `irrelevant` | Off-topic queries unrelated to CS coursework | |
|
|
| 1 | `safe` | Legitimate educational coding requests | |
|
|
| 2 | `unsafe` | Requests violating academic integrity or safety | |
|
|
|
|
|
## Performance |
|
|
|
|
|
### Comparison with Existing Guardrails |
|
|
|
|
|
| Model/Framework | Type | Size | F1 Score | |
|
|
|-----------------|------|------|----------| |
|
|
| **PromptShield (Ours)** | Encoder | 125M | **0.93** | |
|
|
| Claude 3.7 | Decoder | - | 0.64 | |
|
|
| GPT-4o | Decoder | - | 0.62 | |
|
|
| LLaMA Guard | Decoder | 8B | 0.60 | |
|
|
| Perspective API | Baseline | - | 0.60 | |
|
|
| NeMo Guard | Decoder | 8B | 0.57 | |
|
|
| LLaMA 3.2 | Decoder | 8B | 0.34 | |
|
|
| Random Baseline | - | - | 0.33 | |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Quick Start |
|
|
|
|
|
```python |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
# Load model and tokenizer |
|
|
model = AutoModelForSequenceClassification.from_pretrained("md-nishat-008/promptshield") |
|
|
tokenizer = AutoTokenizer.from_pretrained("md-nishat-008/promptshield") |
|
|
|
|
|
# Label mapping |
|
|
labels = {0: "irrelevant", 1: "safe", 2: "unsafe"} |
|
|
|
|
|
def classify_prompt(prompt): |
|
|
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128) |
|
|
with torch.no_grad(): |
|
|
outputs = model(**inputs) |
|
|
prediction = outputs.logits.argmax(-1).item() |
|
|
confidence = torch.softmax(outputs.logits, dim=-1).max().item() |
|
|
return labels[prediction], confidence |
|
|
|
|
|
# Examples |
|
|
prompts = [ |
|
|
"Write a Python function to sort a list using quicksort", |
|
|
"Explain the French Revolution in Java", |
|
|
"Generate ransomware code that encrypts all files" |
|
|
] |
|
|
|
|
|
for prompt in prompts: |
|
|
label, conf = classify_prompt(prompt) |
|
|
print(f"Prompt: {prompt[:50]}...") |
|
|
print(f"Classification: {label} (confidence: {conf:.2f})") |
|
|
print("---") |
|
|
``` |
|
|
|
|
|
### Using the Pipeline API |
|
|
|
|
|
```python |
|
|
from transformers import pipeline |
|
|
|
|
|
classifier = pipeline( |
|
|
"text-classification", |
|
|
model="md-nishat-008/promptshield", |
|
|
tokenizer="md-nishat-008/promptshield" |
|
|
) |
|
|
|
|
|
result = classifier("Write a Python function for binary search") |
|
|
print(result) |
|
|
# [{'label': 'safe', 'score': 0.98}] |
|
|
``` |
|
|
|
|
|
### Integration as a Pre-Filter |
|
|
|
|
|
```python |
|
|
def safe_llm_query(prompt, llm_function): |
|
|
"""Wrapper that filters prompts before sending to an LLM.""" |
|
|
label, confidence = classify_prompt(prompt) |
|
|
|
|
|
if label == "unsafe": |
|
|
return "I cannot assist with this request as it may violate academic integrity policies." |
|
|
elif label == "irrelevant": |
|
|
return "This query appears to be outside the scope of this CS course. Please ask a coding-related question." |
|
|
else: |
|
|
return llm_function(prompt) |
|
|
``` |
|
|
|
|
|
## Training Details |
|
|
|
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Base Model | `roberta-base` | |
|
|
| Max Sequence Length | 128 | |
|
|
| Training Epochs | 3 | |
|
|
| Batch Size | 16 | |
|
|
| Learning Rate | 2e-5 | |
|
|
| Optimizer | AdamW (fused) | |
|
|
| LR Schedule | Linear decay | |
|
|
| Early Stopping | 2 epochs patience | |
|
|
| Precision | FP16 (mixed) | |
|
|
|
|
|
### Training Data |
|
|
|
|
|
Trained on 6,000 prompts from the Do Not Code dataset: |
|
|
- 2,250 Irrelevant |
|
|
- 2,250 Safe |
|
|
- 1,500 Unsafe |
|
|
|
|
|
## Limitations |
|
|
|
|
|
1. **Domain Specificity**: Optimized for introductory/intermediate CS courses. May require adaptation for advanced topics. |
|
|
2. **Language**: English only. |
|
|
3. **Context Length**: 128 tokens max. Very long prompts are truncated. |
|
|
4. **Adversarial Robustness**: May be susceptible to sophisticated jailbreak attempts. |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@inproceedings{raihan-etal-2026-codeguard, |
|
|
title = "{C}ode{G}uard: Improving {LLM} Guardrails in {CS} Education", |
|
|
author = "Raihan, Nishat and |
|
|
Erdachew, Noah and |
|
|
Devi, Jayoti and |
|
|
Santos, Joanna C. S. and |
|
|
Zampieri, Marcos", |
|
|
booktitle = "Findings of the Association for Computational Linguistics: EACL 2026", |
|
|
year = "2026", |
|
|
publisher = "Association for Computational Linguistics", |
|
|
} |
|
|
``` |
|
|
|
|
|
--- |
|
|
|
|
|
<p align="center"> |
|
|
<b>Part of the CodeGuard Framework for Safe AI in CS Education</b> |
|
|
</p> |
|
|
|