File size: 4,304 Bytes
21ff8b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from transformers import AutoTokenizer, RobertaForSequenceClassification
import logging
import warnings

# Keep logs readable: drop transformers' model-loading chatter to ERROR (we
# load with ignore_mismatched_sizes below, which is otherwise noisy) and
# silence blanket FutureWarnings from the HF/torch stack.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", category=FutureWarning)

class MLEngine:
    """GraphCodeBERT-based vulnerability scanner.

    Loads the tokenizer and classification head once at construction, then
    scans arbitrarily large source strings in fixed-size token windows so
    memory use stays bounded.
    """

    def __init__(self, logger_callback=None):
        """Load model weights (happens once; may download on first run).

        Args:
            logger_callback: optional callable accepting a single string;
                defaults to ``print`` when omitted.

        Sets ``self.is_ready`` — callers must check it before scanning.
        """
        self.log = logger_callback if logger_callback else print
        self.model_name = "microsoft/graphcodebert-base"

        self.log("🧠 [ML ENGINE] Initializing Neural Engine...")
        self.log("   └── Loading Tensor Weights (This happens once)...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = RobertaForSequenceClassification.from_pretrained(
                self.model_name,
                use_safetensors=True,
                ignore_mismatched_sizes=True
            )

            # Use CPU for stability on free tier (GPU runs out of VRAM too fast)
            self.device = torch.device("cpu")
            self.model.to(self.device)
            # FIX: put the model in inference mode. Without eval(), dropout
            # layers stay active and per-window scores are nondeterministic.
            self.model.eval()

            self.log(f"   └── ✅ Model Ready. Running on: {self.device.type.upper()}")
            self.is_ready = True

        except Exception as e:
            # Best-effort: degrade to "not ready" instead of crashing the
            # caller; predict_vulnerability() then returns an empty result.
            self.log(f"⚠️ [ML ENGINE FAILURE] {str(e)}")
            self.is_ready = False

    def predict_vulnerability(self, code_content):
        """Sliding-window scan of ``code_content`` for vulnerable code.

        Tokenizes the whole string once, splits the ids into 510-token
        windows, classifies each window independently and aggregates.
        Never loads more than one window onto the model at a time.

        Args:
            code_content: source text to scan (str). Falsy input is a no-op.

        Returns:
            Tuple ``(is_vulnerable, confidence_percent, bad_snippets)``:
            ``confidence_percent`` is the highest per-window score as a
            percentage rounded to 2 decimals; ``bad_snippets`` holds the
            decoded text of every window that crossed the 0.50 threshold.
        """
        if not self.is_ready or not code_content:
            return False, 0.0, []

        # 1. Tokenize the WHOLE file once, WITHOUT special tokens — we wrap
        # each window with <s>/</s> below so every window matches the input
        # format the model expects. (FIX: the previous version sliced one
        # wrapped sequence, so middle windows had no special tokens at all.)
        # max_length caps pathological inputs such as binary garbage.
        inputs = self.tokenizer(
            code_content,
            return_tensors="pt",
            truncation=True,
            max_length=100000,
            padding=False,
            add_special_tokens=False,
        )

        input_ids = inputs['input_ids'][0]

        # 2. 510 content tokens + the 2 special tokens = the 512 limit.
        window_size = 510
        stride = 510  # no overlap: 2x faster, sufficient for classification

        total_tokens = len(input_ids)
        chunks = []

        # 3. Create the windows, each framed as <s> ... </s>.
        cls_id = torch.tensor([self.tokenizer.cls_token_id])
        sep_id = torch.tensor([self.tokenizer.sep_token_id])
        for i in range(0, total_tokens, stride):
            chunk = input_ids[i : i + window_size]
            if len(chunk) < 10:
                continue  # skip tiny trailing fragments — too short to classify
            chunks.append(torch.cat([cls_id, chunk, sep_id]))

        highest_confidence = 0.0
        is_vulnerable = False
        bad_snippets = []

        # 4. Process each window sequentially so the UI never looks "stuck".
        for chunk_ids in chunks:
            try:
                # Add the batch dimension the model expects.
                chunk_ids = chunk_ids.unsqueeze(0).to(self.device)

                with torch.no_grad():
                    outputs = self.model(input_ids=chunk_ids)
                    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

                    # Index 1 is assumed to be the "unsafe" class
                    # (2-label head: [safe_score, unsafe_score]).
                    vuln_score = probs[0][1].item()

                    if vuln_score > 0.50:  # Threshold
                        is_vulnerable = True
                        if vuln_score > highest_confidence:
                            highest_confidence = vuln_score

                        # Decode back to text so the user sees EXACTLY what
                        # tripped the classifier.
                        decoded_snippet = self.tokenizer.decode(chunk_ids[0], skip_special_tokens=True)
                        bad_snippets.append(decoded_snippet)

                        # Optimization: a high-confidence hit is enough proof;
                        # stop scanning this file to save time.
                        if vuln_score > 0.85:
                            break

            except Exception:
                # Best-effort per window: a bad chunk must not abort the scan.
                continue

        return is_vulnerable, round(highest_confidence * 100, 2), bad_snippets