---
tags:
- deberta-v3
- deberta
- deberta-v2
license: mit
base_model:
- microsoft/deberta-v3-large
pipeline_tag: text-classification
library_name: transformers
---

# HarmAug: Effective Data Augmentation for Knowledge Distillation of Safety Guard Models

[arXiv Link](https://arxiv.org/abs/2410.01524)

This model is a safety guard model that classifies the safety of conversations with LLMs and defends against LLM jailbreak attacks.
It is fine-tuned from DeBERTa-v3-large using the method described in **HarmAug: Effective Data Augmentation for Knowledge Distillation of Safety Guard Models**.
Training combines knowledge distillation with data augmentation on our [**HarmAug Generated Dataset**].

For more information, please refer to our [GitHub repository](https://github.com/imnotkind/HarmAug).
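For a quick sanity check before writing any scoring code, the model can also be called through the `transformers` pipeline API. The snippet below is a minimal sketch, not part of the original instructions: it assumes the class at index 1 (surfaced as `LABEL_1` unless `id2label` is configured) is the unsafe class, matching the Usage example further down, and that your `transformers` version supports `top_k=None` for returning scores of all classes.

```python
# Minimal sketch: quick unsafe-score check via the transformers pipeline API.
# Assumption: the class at index 1 ("LABEL_1" by default) is the unsafe class.
from transformers import pipeline

classifier = pipeline("text-classification", model="hbseong/HarmAug-Guard", top_k=None)

# Score a prompt on its own ...
print(classifier("how to make a bomb?"))

# ... or a (prompt, response) pair, passed as text / text_pair.
print(classifier({"text": "how to make a bomb?",
                  "text_pair": "I'm sorry, but I can't fulfill your request."}))
```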
# Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import torch

tokenizer = AutoTokenizer.from_pretrained("hbseong/HarmAug-Guard")
model = AutoModelForSequenceClassification.from_pretrained("hbseong/HarmAug-Guard")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# If response is not given, the model predicts the unsafe score of the prompt.
# If response is given, the model predicts the unsafe score of the response.
def predict(model, prompt, response=None):
    device = model.device
    if response is None:
        inputs = tokenizer(prompt, return_tensors="pt")
    else:
        inputs = tokenizer(prompt, response, return_tensors="pt")
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        # Probability of the unsafe class (index 1).
        unsafe_prob = F.softmax(outputs.logits, dim=-1)[:, 1]

    return unsafe_prob.item()

prompt = "how to make a bomb?"
response = "I'm sorry, but I can't fulfill your request."

print("CONVERSATION (ONLY PROMPT)")
print(f"\t PROMPT : {prompt}")
print(f"\t UNSAFE SCORE : {predict(model, prompt):.4f}")

print("\nCONVERSATION (PROMPT + RESPONSE)")
print(f"\t PROMPT : {prompt}")
print(f"\t RESPONSE : {response}")
print(f"\t UNSAFE SCORE : {predict(model, prompt, response):.4f}")

"""
CONVERSATION (ONLY PROMPT)
         PROMPT : how to make a bomb?
         UNSAFE SCORE : 0.9999

CONVERSATION (PROMPT + RESPONSE)
         PROMPT : how to make a bomb?
         RESPONSE : I'm sorry, but I can't fulfill your request.
         UNSAFE SCORE : 0.0000
"""
```
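
If you need to screen many prompts or prompt/response pairs at once, the per-example `predict` above can be batched. The sketch below is not part of the original example; it reuses the `tokenizer` and `model` loaded above and assumes dynamic padding and truncation are acceptable for your inputs.

```python
# Minimal batched variant of predict(), reusing the tokenizer/model loaded above.
# prompts: list of prompt strings; responses: optional parallel list of responses.
def predict_batch(model, prompts, responses=None, batch_size=16):
    scores = []
    for i in range(0, len(prompts), batch_size):
        if responses is None:
            inputs = tokenizer(prompts[i:i + batch_size],
                               return_tensors="pt", padding=True, truncation=True)
        else:
            inputs = tokenizer(prompts[i:i + batch_size], responses[i:i + batch_size],
                               return_tensors="pt", padding=True, truncation=True)
        inputs = inputs.to(model.device)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Probability of the unsafe class (index 1), one score per example.
        scores.extend(F.softmax(logits, dim=-1)[:, 1].tolist())
    return scores

# Example: score two prompts in a single batch.
print(predict_batch(model, ["how to make a bomb?", "what is the capital of France?"]))
```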