File size: 12,136 Bytes
bec7d28
 
2883685
 
 
bec7d28
 
 
2883685
bec7d28
 
 
 
 
2883685
bec7d28
2883685
bec7d28
 
2883685
bec7d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2883685
 
bec7d28
 
 
 
 
 
 
 
2883685
bec7d28
 
 
 
 
2883685
bec7d28
 
 
 
 
 
2883685
 
 
bec7d28
 
 
 
 
 
2883685
bec7d28
 
 
 
 
 
 
 
2883685
bec7d28
2883685
bec7d28
 
 
2883685
bec7d28
 
 
 
 
 
 
 
2883685
7a50c62
bec7d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2883685
7a50c62
 
bec7d28
7a50c62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2883685
7a50c62
2883685
7a50c62
 
 
 
 
 
2883685
7a50c62
bec7d28
 
 
2883685
 
7a50c62
 
2883685
bec7d28
 
 
2883685
bec7d28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215

import os, re, random
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── Load model from HF Hub ─────────────────────────────────
# The model id can be overridden via an environment variable so a deployment
# can point at a different fine-tune without a code change; the default is
# unchanged.  NOTE: from_pretrained downloads weights on first run — this
# module import is slow and requires network access.
MODEL_ID  = os.environ.get("POLYGUARD_MODEL_ID", "MUHAMMADSAADAMIN/PolyGuard")
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model     = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only: disable dropout etc.
print("βœ“ Model ready")

# ── Vulnerability rules ────────────────────────────────────
# Static scan rules: {language: [(regex_pattern, remediation_message), ...]}.
# Patterns are applied case-insensitively to the raw snippet in analyze_code.
# Messages are de-duplicated per scan, and their wording matters: category
# detection (sql_injection, code_execution, ...) keys off substrings of the
# message text (e.g. "sql", "password", "md5"), so edit wording carefully.
VULN_RULES = {
    "python": [
        (r"execute\s*\(\s*[f'\"].*?\{",          "Use parameterized queries instead of building SQL strings manually."),
        (r"execute\s*\(\s*\".*?%",                  "Use parameterized queries instead of building SQL strings manually."),
        (r"eval\s*\(",                                 "Avoid eval() β€” it executes arbitrary code and is a critical security risk."),
        (r"exec\s*\(",                                 "Avoid exec() β€” it executes arbitrary code and is a critical security risk."),
        (r"pickle\.loads?\s*\(",                      "Avoid pickle.load() on untrusted data β€” it can execute arbitrary code."),
        (r"subprocess.*shell\s*=\s*True",              "Never use shell=True in subprocess β€” use a list of arguments instead."),
        (r"os\.system\s*\(",                          "Avoid os.system() β€” use subprocess with a list of arguments instead."),
        (r"hashlib\.md5\s*\(",                        "MD5 is cryptographically broken β€” use SHA-256 or stronger."),
        (r"hashlib\.sha1\s*\(",                       "SHA-1 is weak β€” use SHA-256 or stronger."),
        (r"random\.(random|randint|choice)\s*\(",     "Use secrets module instead of random for security-sensitive operations."),
        (r"open\s*\(.*['\"]w['\"]",               "Validate file paths before writing to prevent path traversal attacks."),
        (r"request\.(args|form|json)\[",               "Validate and sanitize all user input before use."),
        (r"render_template_string\s*\(",               "Avoid render_template_string with user input β€” use template files instead."),
        (r"yaml\.load\s*\(",                          "Use yaml.safe_load() instead of yaml.load() to prevent code execution."),
        (r"SSL_VERIFY\s*=\s*False|verify\s*=\s*False", "Never disable SSL verification in production."),
        (r"password\s*=\s*['\"][^'\"]{1,20}['\"]", "Hardcoded password detected β€” use environment variables instead."),
        (r"secret\s*=\s*['\"][^'\"]{1,20}['\"]",   "Hardcoded secret detected β€” use environment variables instead."),
        (r"api_key\s*=\s*['\"][^'\"]+['\"]",        "Hardcoded API key detected β€” use environment variables instead."),
    ],
    "javascript": [
        (r"eval\s*\(",                                 "Avoid eval() β€” it executes arbitrary code and is a critical security risk."),
        (r"innerHTML\s*=",                              "Avoid innerHTML β€” use textContent or DOMPurify to prevent XSS."),
        (r"document\.write\s*\(",                     "Avoid document.write() β€” it can lead to XSS vulnerabilities."),
        (r"dangerouslySetInnerHTML",                     "Avoid dangerouslySetInnerHTML β€” sanitize content with DOMPurify first."),
        (r"localStorage\.setItem.*password",            "Never store passwords or secrets in localStorage."),
        (r"Math\.random\s*\(",                        "Use crypto.getRandomValues() instead of Math.random() for security tokens."),
        (r"http://",                                     "Use HTTPS instead of HTTP for all external requests."),
        (r"password\s*=\s*['\"][^'\"]+['\"]",   "Hardcoded password detected β€” use environment variables instead."),
    ],
    "sql": [
        (r"'\s*\+\s*",                               "String concatenation in SQL is vulnerable to injection β€” use parameterized queries."),
        (r"GRANT ALL",                                   "Avoid GRANT ALL β€” apply principle of least privilege."),
        (r"DROP TABLE",                                  "Dangerous DDL statement detected β€” ensure proper access controls."),
        (r"xp_cmdshell",                                 "xp_cmdshell is a critical security risk β€” disable it on the server."),
    ],
    "php": [
        (r"mysql_query\s*\(",                          "mysql_* functions are deprecated β€” use PDO or mysqli with prepared statements."),
        (r"\$_GET\[|\$_POST\[|\$_REQUEST\[",      "Sanitize all user input from $_GET/$_POST/$_REQUEST before use."),
        (r"eval\s*\(",                                 "Avoid eval() β€” it executes arbitrary code."),
        (r"system\s*\(|exec\s*\(",                   "Avoid system()/exec() with user input β€” use escapeshellarg()."),
        (r"md5\s*\(",                                  "MD5 is not suitable for password hashing β€” use password_hash() instead."),
    ],
}

# Generic per-language quality tips used to pad the tip list when fewer
# than three category-specific (SMART_TIPS) tips were triggered.
CODE_TIPS = {
    "python":     ["Use list comprehensions instead of for loops.", "Use f-strings for string formatting.", "Use with open() for file handling.", "Add type hints to function signatures.", "Use logging instead of print() in production.", "Use dataclasses or Pydantic instead of plain dicts."],
    "javascript": ["Use const and let instead of var.", "Use async/await instead of callback chains.", "Use strict equality (===) instead of ==.", "Prefer arrow functions for concise syntax."],
    "sql":        ["Always use parameterized queries.", "Add indexes on frequently queried columns.", "Use EXPLAIN to analyze query performance."],
    "php":        ["Use Composer for dependency management.", "Enable strict_types=1 at the top of files.", "Use prepared statements for all database queries."],
}

# Category-specific remediation tips.  The keys must match the category
# names produced by the finding classification in analyze_code.
SMART_TIPS = {
    "sql_injection":     "Use parameterized queries e.g. cursor.execute(query, params) to prevent SQL injection.",
    "code_execution":    "Replace eval()/exec() with ast.literal_eval() for safe expression parsing.",
    "hardcoded_secrets": "Store secrets in environment variables and load with os.environ.get(KEY).",
    "weak_crypto":       "Replace MD5/SHA1 with hashlib.sha256() or bcrypt for password hashing.",
    "xss":               "Sanitize output with DOMPurify or use textContent instead of innerHTML.",
    "command_injection": "Pass arguments as a list to subprocess.run() and never use shell=True.",
}

# Languages advertised by the API.  Only the VULN_RULES keys get a static
# rule scan; the additional languages fall back to model scoring only.
SUPPORTED_LANGUAGES = list(VULN_RULES.keys()) + ["java", "c", "cpp", "go", "ruby", "rust"]

def _categorize_finding(message: str):
    """Map a rule message to its SMART_TIPS category key, or None.

    Matching is keyword-based on the human-readable message text, so the
    keywords here must stay in sync with the wording used in VULN_RULES.
    """
    msg = message.lower()
    if "sql" in msg or "query" in msg:
        return "sql_injection"
    if "eval" in msg or "exec" in msg:  # note: "executes" also contains "exec"
        return "code_execution"
    if any(word in msg for word in ("password", "secret", "api_key")):
        return "hardcoded_secrets"
    if any(word in msg for word in ("md5", "sha1", "hash")):
        return "weak_crypto"
    if "xss" in msg or "innerhtml" in msg:
        return "xss"
    if "shell" in msg or "subprocess" in msg:
        return "command_injection"
    return None


def _scan_rules(code: str, language: str):
    """Run the static regex rules for *language* over *code*.

    Returns (findings, categories): the de-duplicated remediation messages
    in rule order, and the set of matched SMART_TIPS category keys.
    Unknown languages yield no rules and therefore no findings.
    """
    findings, categories = [], set()
    for pattern, message in VULN_RULES.get(language, []):
        if re.search(pattern, code, re.IGNORECASE) and message not in findings:
            findings.append(message)
            category = _categorize_finding(message)
            if category is not None:
                categories.add(category)
    return findings, categories


def _grade_for(score: float):
    """Translate a 0-10 quality score into a (grade, verdict) label pair."""
    if score >= 9.0:
        return "A", "EXCELLENT"
    if score >= 7.5:
        return "B", "GOOD"
    if score >= 6.0:
        return "C", "FAIR"
    if score >= 4.0:
        return "D", "POOR"
    if score >= 2.0:
        return "F", "VULNERABLE"
    return "F-", "CRITICAL"


def _risk_for(score: float):
    """Coarse risk label derived from the score (kept for API compatibility)."""
    if score <= 3.0:
        return "critical"
    if score <= 5.0:
        return "high"
    if score <= 6.5:
        return "medium"
    if score <= 8.5:
        return "low"
    return "safe"


def analyze_code(code: str, language: str):
    """Score *code* for vulnerabilities and return a report dict.

    Combines two signals: the fine-tuned classifier's clean/vulnerable
    confidences and the static regex scan in VULN_RULES.  The returned
    dict contains: score (10 = best, 0 = worst), grade, risk, verdict,
    clean/vuln confidences, findings, up to three tips, the normalized
    language, and the analyzed line count.
    """
    language = language.lower().strip()

    # ── Model inference ────────────────────────────────────
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=1).squeeze().tolist()
    if not isinstance(probs, list):
        # A single-logit head squeezes to a scalar; rebuild a 2-class list.
        probs = [probs, 1 - probs]

    clean_conf = round(probs[0] * 100, 1)
    vuln_conf  = round(probs[1] * 100, 1)

    # ── Static rule scan ───────────────────────────────────
    findings, matched_categories = _scan_rules(code, language)

    # ── Code quality score (10 = perfect, 0 = critical) ───
    # Ignore the model signal entirely when it is close to a coin flip.
    model_uncertain = abs(vuln_conf - clean_conf) < 15.0

    # Deductions: up to 4 points from the model, up to 6 from rule findings.
    model_penalty    = (vuln_conf / 100) * 4.0 if not model_uncertain else 0.0
    findings_penalty = min(len(findings) * 1.2, 6.0)
    total_deduction  = round(min(model_penalty + findings_penalty, 10.0), 1)

    quality_score = round(max(10.0 - total_deduction, 0.0), 1)

    grade, verdict = _grade_for(quality_score)
    risk = _risk_for(quality_score)

    # ── Tips: category-specific first, padded with general language tips ──
    tips    = [SMART_TIPS[c] for c in matched_categories if c in SMART_TIPS]
    general = CODE_TIPS.get(language, ["Follow security best practices."])
    # random.sample intentionally varies the padding between calls.
    tips   += random.sample(general, min(max(0, 3 - len(tips)), len(general)))

    return {
        "score":            quality_score,   # now 10 = best, 0 = worst
        "grade":            grade,
        "risk":             risk,
        "verdict":          verdict,
        "clean_confidence": clean_conf,
        "vuln_confidence":  vuln_conf,
        "findings":         findings,
        "tips":             tips[:3],
        "language":         language,
        "lines_analyzed":   len(code.splitlines()),
    }

# ── FastAPI app ────────────────────────────────────────────
# Application object; /docs and /redoc are served automatically by FastAPI.
app = FastAPI(
    title       = "PolyGuard API",
    description = "Code vulnerability scanner powered by CodeBERT fine-tuned on CrossVUL.",
    version     = "3.0.0",
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory under the CORS spec (browsers refuse credentialed requests
# to a wildcard origin) — confirm whether credentials are actually needed,
# or pin the allowed origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins     = ["*"],
    allow_credentials = True,
    allow_methods     = ["*"],
    allow_headers     = ["*"],
)

class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""
    code:     str  # raw source snippet to scan
    language: str  # language name; lowercased/stripped before use

class SnippetItem(BaseModel):
    """One snippet inside a POST /analyze_batch request."""
    code:     str  # raw source snippet to scan
    language: str  # language name; lowercased/stripped before use

class BatchRequest(BaseModel):
    """Request body for POST /analyze_batch: snippets analyzed in order."""
    snippets: List[SnippetItem]

@app.get("/")
def index():
    """Root endpoint: service metadata, docs link and endpoint listing."""
    return dict(
        service="PolyGuard β€” Code Vulnerability Scanner",
        version="3.0.0",
        model=MODEL_ID,
        status="running",
        docs="/docs",
        supported_languages=SUPPORTED_LANGUAGES,
        endpoints=["/health", "/analyze", "/analyze_batch"],
    )

@app.get("/health")
def health():
    """Lightweight liveness probe with model and language info."""
    payload = {"status": "ok", "model": MODEL_ID}
    payload["supported_languages"] = SUPPORTED_LANGUAGES
    return payload

@app.post("/analyze")
def analyze(req: AnalyzeRequest):
    """Scan a single snippet and return its full report."""
    report = analyze_code(req.code, req.language)
    return report

@app.post("/analyze_batch")
def analyze_batch(req: BatchRequest):
    """Scan every snippet in request order and return all reports."""
    results = []
    for snippet in req.snippets:
        results.append(analyze_code(snippet.code, snippet.language))
    return {"count": len(results), "results": results}