| """Inference module for vulnerability detection |
| Load trained models and make predictions""" |
|
|
| import torch |
| from transformers import RobertaTokenizer |
| from pathlib import Path |
| import sys |
| sys.path.append(str(Path(__file__).parent.parent.parent)) |
|
|
| from src.model import VulnerabilityCodeT5 |
|
|
class VulnerabilityDetector:
    """Binary safe/vulnerable classifier built on ``VulnerabilityCodeT5``.

    The constructor loads the tokenizer and the trained weights once; after
    that ``predict`` and ``analyze_batch`` can be called repeatedly. Inference
    always runs on the CPU.
    """

    def __init__(self, model_path="models/best_model_clean.pt",
                 model_name="Salesforce/codet5-base", max_length=256):
        """Load the tokenizer and the trained model weights.

        Args:
            model_path: Path of the checkpoint holding the model ``state_dict``.
            model_name: Identifier passed to both
                ``RobertaTokenizer.from_pretrained`` and ``VulnerabilityCodeT5``.
            max_length: Token length every snippet is padded/truncated to
                when it is tokenized in ``predict``.
        """
        self.device = torch.device('cpu')
        self.max_length = max_length

        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = VulnerabilityCodeT5(model_name=model_name, num_labels=2)

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; on
        # PyTorch versions that support it, consider weights_only=True.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Switch to evaluation mode before serving predictions.
        self.model.eval()

        print("Model Loaded Successfully")

        # Human-readable name for each class index returned by the model.
        self.labels = {
            0: "Safe Code",
            1: "Vulnerable Code"
        }

    def predict(self, code_snippet: str) -> dict:
        """Predict whether a single code snippet is vulnerable.

        Args:
            code_snippet: String containing the source code to classify.

        Returns:
            dict with the keys:
                ``prediction``    - predicted class index (a key of ``self.labels``),
                ``label``         - the matching entry of ``self.labels``,
                ``confidence``    - probability of the predicted class,
                ``probabilities`` - dict with the ``safe`` (index 0) and
                                    ``vulnerable`` (index 1) probabilities.
        """
        inputs = self.tokenizer(
            code_snippet,
            # Honor the length configured in __init__ instead of a fixed 256,
            # so a detector built with a custom max_length actually uses it.
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            # Assumes model.predict returns (predictions, probs) with the
            # batch dimension first - TODO confirm against src.model.
            predictions, probs = self.model.predict(input_ids, attention_mask)

        # A single snippet was tokenized, so only batch entry 0 is read.
        pred_label = predictions[0].item()
        confidence = probs[0][pred_label].item()

        return {
            'prediction': pred_label,
            'label': self.labels[pred_label],
            'confidence': confidence,
            'probabilities': {
                'safe': probs[0][0].item(),
                'vulnerable': probs[0][1].item()
            }
        }

    def analyze_batch(self, code_snippets) -> list:
        """Run ``predict`` on each snippet and return the results in order.

        Args:
            code_snippets: Iterable of source-code strings. Pass a list even
                for one snippet: a bare ``str`` would be iterated character
                by character.

        Returns:
            list of result dicts, one per snippet, as produced by ``predict``.
        """
        return [self.predict(code) for code in code_snippets]
| |
def test_inference():
    """Smoke-test the detector on a few hand-written C snippets.

    Builds a ``VulnerabilityDetector`` with its default arguments, runs every
    snippet through ``predict`` and prints the label, the confidence and both
    class probabilities. Nothing is asserted; the output is meant to be read
    by a person.
    """
    # (title, C source) pairs: three safe snippets followed by three
    # vulnerable ones, going by their titles.
    # NOTE(review): the snippet text is kept exactly as found in this file;
    # re-indenting it would change the string handed to the tokenizer.
    samples = (
        ("Safe Bounded Copy", """void copy_input(const char *input) {
char buffer[32];
strncpy(buffer, input, sizeof(buffer) - 1);
buffer[sizeof(buffer) - 1] = '\\0';
}"""),
        ("Safe fgets Input", """void read_input() {
char buffer[64];
if (fgets(buffer, sizeof(buffer), stdin) != NULL) {
printf("%s", buffer);
}
}"""),
        ("Safe malloc usage", """void allocate() {
char *buf = (char *)malloc(128);
if (buf == NULL) {
return;
}
strcpy(buf, "safe");
free(buf);
}"""),
        ("Stack Buffer Overflow", """void copy_input(char *input) {
char buffer[8];
strcpy(buffer, input);
}"""),
        ("Integer Overflow", """void allocate(int size) {
char *buf = (char *)malloc(size * sizeof(char));
if (buf == NULL) return;
memset(buf, 'A', size + 10);
}"""),
        ("Use After Free", """void uaf() {
char *buf = (char *)malloc(16);
free(buf);
strcpy(buf, "UAF");
}"""),
    )

    scanner = VulnerabilityDetector()

    rule = "=" * 60
    print("\n" + rule)
    print("Testing Vulnerability Detection")
    print(rule)

    for title, snippet in samples:
        print(f"\nTest: {title}")
        # Show only the first 60 characters so the report stays compact.
        print(f"Code: {snippet[:60]}...")

        outcome = scanner.predict(snippet)
        class_probs = outcome['probabilities']

        print(f"Prediction: {outcome['label']}")
        print(f"Confidence: {outcome['confidence']:.2%}")
        print(f" - Safe: {class_probs['safe']:.2%}")
        print(f" - Vulnerable: {class_probs['vulnerable']:.2%}")
|
|
if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script, so
    # importing the module does not load the model.
    test_inference()