---
language:
- en
- code
tags:
- security
- vulnerability-detection
- codebert
- classification
license: mit
---

# codebert_vulnerability_scanner

## Overview

`codebert_vulnerability_scanner` is a fine-tuned version of Microsoft's CodeBERT (a RoBERTa-based encoder for source code) that detects potential security vulnerabilities in source code snippets. It treats vulnerability detection as a binary classification task, labeling each snippet as either `SAFE` or `VULNERABLE`.

## Model Architecture

This model uses the `RobertaForSequenceClassification` architecture. The base model was pre-trained on the CodeSearchNet dataset (a large collection of function-level code across multiple programming languages). It was then fine-tuned on a curated dataset of C and C++ functions labeled with Common Weakness Enumeration (CWE) categories, covering weaknesses such as buffer overflows and memory leaks.

- **Base Model:** `microsoft/codebert-base`
- **Head:** A classification head (dense layer, tanh activation, and output projection) applied to the representation of the first (`<s>`) token.
- **Input:** Tokenized source code functions.
- **Output:** Logits for two classes: SAFE (0) and VULNERABLE (1).
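
The two output logits are unnormalized scores. A minimal sketch of turning them into a probability for the VULNERABLE class is shown here, assuming the class order listed above (index 0 is SAFE, index 1 is VULNERABLE). The helper names `softmax` and `vulnerable_probability` and the example logit values are illustrative assumptions, not part of the model's API.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def vulnerable_probability(logits):
    """Probability of class 1 (VULNERABLE), assuming [SAFE, VULNERABLE] order."""
    return softmax(logits)[1]

# Illustrative logits only; real values come from the model's forward pass.
example_logits = [-1.2, 2.3]
print(f"P(VULNERABLE) = {vulnerable_probability(example_logits):.3f}")
```

When working with PyTorch tensors, `torch.softmax(logits, dim=-1)` computes the same quantity for a whole batch at once.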

## Intended Use

This model is intended primarily for DevSecOps workflows and static analysis research.

- **Automated Code Review:** Scanning pull requests for high-risk code patterns before merging.
- **Security Auditing:** Quickly analyzing large legacy codebases to prioritize manual security reviews.
- **Research:** Benchmarking against traditional static analysis security testing (SAST) tools.
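
For the auditing use case, one way to prioritize manual review is to sort functions by their predicted VULNERABLE probability and flag those at or above a threshold. The sketch below is hypothetical: `score_fn` stands in for any callable that returns that probability for a code string (for example, a wrapper around this model), and the `triage` name, the `toy_score` stand-in, and the `0.5` threshold are assumptions rather than recommended settings.

```python
from typing import Callable, Dict, List, Tuple

def triage(
    functions: Dict[str, str],
    score_fn: Callable[[str], float],
    threshold: float = 0.5,
) -> List[Tuple[str, float]]:
    """Return (name, score) pairs at or above threshold, highest risk first."""
    scored = [(name, score_fn(code)) for name, code in functions.items()]
    flagged = [(name, s) for name, s in scored if s >= threshold]
    return sorted(flagged, key=lambda pair: pair[1], reverse=True)

# Stand-in scorer for demonstration only; a real scorer would call the model.
def toy_score(code: str) -> float:
    return 0.9 if "strcpy" in code else 0.1

functions = {
    "copy_name": "void copy_name(char *s) { char b[8]; strcpy(b, s); }",
    "add": "int add(int a, int b) { return a + b; }",
}
for name, score in triage(functions, toy_score):
    print(f"{name}: {score:.2f}")
```

Keeping the scorer as a parameter lets the same ranking logic be tested without loading the model and reused with a different classifier.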

### How to use

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "your_username/codebert_vulnerability_scanner"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Example C function snippet
code_snippet = """
void vulnerable_function(char *user_input) {
    char buffer[64];
    strcpy(buffer, user_input); // Potential buffer overflow
}
"""

# Tokenize, truncating inputs longer than 512 tokens
inputs = tokenizer(code_snippet, return_tensors="pt", truncation=True, max_length=512)

# Run inference without tracking gradients
with torch.no_grad():
    logits = model(**inputs).logits

# Pick the higher-scoring class and map its index to a label name
predicted_class_id = logits.argmax(dim=-1).item()
labels = model.config.id2label
print(f"Prediction: {labels[predicted_class_id]}")
# Expected output: VULNERABLE