# PolyGuard — FastAPI code-vulnerability scanner (Hugging Face Space entry point)
import os, re, random
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── Load model from HF Hub ─────────────────────────────────
MODEL_ID = "MUHAMMADSAADAMIN/PolyGuard"
print(f"Loading model: {MODEL_ID}")

# Sequence-classification checkpoint; analyze_code() below reads the two
# softmax outputs as probs[0]=clean, probs[1]=vulnerable — assumed label
# order, TODO confirm against the model's config.id2label.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference only: disable dropout etc.
print("✓ Model ready")
# ── Vulnerability rules ────────────────────────────────────
# Per-language list of (regex, human-readable finding) pairs. Patterns are
# applied case-insensitively by analyze_code(); message wording doubles as
# the category signal for SMART_TIPS (keywords like "sql", "eval", "md5").
VULN_RULES = {
    "python": [
        (r"execute\s*\(\s*[f'\"].*?\{", "Use parameterized queries instead of building SQL strings manually."),
        (r"execute\s*\(\s*\".*?%", "Use parameterized queries instead of building SQL strings manually."),
        (r"eval\s*\(", "Avoid eval() — it executes arbitrary code and is a critical security risk."),
        (r"exec\s*\(", "Avoid exec() — it executes arbitrary code and is a critical security risk."),
        (r"pickle\.loads?\s*\(", "Avoid pickle.load() on untrusted data — it can execute arbitrary code."),
        (r"subprocess.*shell\s*=\s*True", "Never use shell=True in subprocess — use a list of arguments instead."),
        (r"os\.system\s*\(", "Avoid os.system() — use subprocess with a list of arguments instead."),
        (r"hashlib\.md5\s*\(", "MD5 is cryptographically broken — use SHA-256 or stronger."),
        (r"hashlib\.sha1\s*\(", "SHA-1 is weak — use SHA-256 or stronger."),
        (r"random\.(random|randint|choice)\s*\(", "Use secrets module instead of random for security-sensitive operations."),
        (r"open\s*\(.*['\"]w['\"]", "Validate file paths before writing to prevent path traversal attacks."),
        (r"request\.(args|form|json)\[", "Validate and sanitize all user input before use."),
        (r"render_template_string\s*\(", "Avoid render_template_string with user input — use template files instead."),
        (r"yaml\.load\s*\(", "Use yaml.safe_load() instead of yaml.load() to prevent code execution."),
        (r"SSL_VERIFY\s*=\s*False|verify\s*=\s*False", "Never disable SSL verification in production."),
        (r"password\s*=\s*['\"][^'\"]{1,20}['\"]", "Hardcoded password detected — use environment variables instead."),
        (r"secret\s*=\s*['\"][^'\"]{1,20}['\"]", "Hardcoded secret detected — use environment variables instead."),
        (r"api_key\s*=\s*['\"][^'\"]+['\"]", "Hardcoded API key detected — use environment variables instead."),
    ],
    "javascript": [
        (r"eval\s*\(", "Avoid eval() — it executes arbitrary code and is a critical security risk."),
        (r"innerHTML\s*=", "Avoid innerHTML — use textContent or DOMPurify to prevent XSS."),
        (r"document\.write\s*\(", "Avoid document.write() — it can lead to XSS vulnerabilities."),
        (r"dangerouslySetInnerHTML", "Avoid dangerouslySetInnerHTML — sanitize content with DOMPurify first."),
        (r"localStorage\.setItem.*password", "Never store passwords or secrets in localStorage."),
        (r"Math\.random\s*\(", "Use crypto.getRandomValues() instead of Math.random() for security tokens."),
        (r"http://", "Use HTTPS instead of HTTP for all external requests."),
        (r"password\s*=\s*['\"][^'\"]+['\"]", "Hardcoded password detected — use environment variables instead."),
    ],
    "sql": [
        (r"'\s*\+\s*", "String concatenation in SQL is vulnerable to injection — use parameterized queries."),
        (r"GRANT ALL", "Avoid GRANT ALL — apply principle of least privilege."),
        (r"DROP TABLE", "Dangerous DDL statement detected — ensure proper access controls."),
        (r"xp_cmdshell", "xp_cmdshell is a critical security risk — disable it on the server."),
    ],
    "php": [
        (r"mysql_query\s*\(", "mysql_* functions are deprecated — use PDO or mysqli with prepared statements."),
        (r"\$_GET\[|\$_POST\[|\$_REQUEST\[", "Sanitize all user input from $_GET/$_POST/$_REQUEST before use."),
        (r"eval\s*\(", "Avoid eval() — it executes arbitrary code."),
        (r"system\s*\(|exec\s*\(", "Avoid system()/exec() with user input — use escapeshellarg()."),
        (r"md5\s*\(", "MD5 is not suitable for password hashing — use password_hash() instead."),
    ],
}
# Generic per-language quality tips, used to pad the "tips" list up to
# three entries when few vulnerability categories matched.
CODE_TIPS = {
    "python": [
        "Use list comprehensions instead of for loops.",
        "Use f-strings for string formatting.",
        "Use with open() for file handling.",
        "Add type hints to function signatures.",
        "Use logging instead of print() in production.",
        "Use dataclasses or Pydantic instead of plain dicts.",
    ],
    "javascript": [
        "Use const and let instead of var.",
        "Use async/await instead of callback chains.",
        "Use strict equality (===) instead of ==.",
        "Prefer arrow functions for concise syntax.",
    ],
    "sql": [
        "Always use parameterized queries.",
        "Add indexes on frequently queried columns.",
        "Use EXPLAIN to analyze query performance.",
    ],
    "php": [
        "Use Composer for dependency management.",
        "Enable strict_types=1 at the top of files.",
        "Use prepared statements for all database queries.",
    ],
}
# One targeted remediation tip per vulnerability category detected by the
# static rule scan; keys mirror the categories built in analyze_code().
SMART_TIPS = dict(
    sql_injection="Use parameterized queries e.g. cursor.execute(query, params) to prevent SQL injection.",
    code_execution="Replace eval()/exec() with ast.literal_eval() for safe expression parsing.",
    hardcoded_secrets="Store secrets in environment variables and load with os.environ.get(KEY).",
    weak_crypto="Replace MD5/SHA1 with hashlib.sha256() or bcrypt for password hashing.",
    xss="Sanitize output with DOMPurify or use textContent instead of innerHTML.",
    command_injection="Pass arguments as a list to subprocess.run() and never use shell=True.",
)
# Languages advertised by the API: those with static rules, plus extras
# that are handled by the model alone (no regex rules defined).
SUPPORTED_LANGUAGES = [*VULN_RULES, "java", "c", "cpp", "go", "ruby", "rust"]
def _scan_rules(code: str, language: str) -> tuple:
    """Run the static regex rules for `language` over `code`.

    Returns (findings, categories): deduplicated finding messages in the
    order their patterns matched, plus the set of SMART_TIPS category keys
    inferred from keywords in those messages.
    """
    findings: list = []
    categories: set = set()
    for pattern, message in VULN_RULES.get(language, []):
        if not re.search(pattern, code, re.IGNORECASE):
            continue
        if message not in findings:
            findings.append(message)
        text = message.lower()
        # Map the human-readable message onto a remediation category.
        if "sql" in text or "query" in text:
            categories.add("sql_injection")
        elif "eval" in text or "exec" in text:
            categories.add("code_execution")
        elif any(w in text for w in ("password", "secret", "api_key")):
            categories.add("hardcoded_secrets")
        elif any(w in text for w in ("md5", "sha1", "hash")):
            categories.add("weak_crypto")
        elif "xss" in text or "innerhtml" in text:
            categories.add("xss")
        elif "shell" in text or "subprocess" in text:
            categories.add("command_injection")
    return findings, categories


def _grade(score: float) -> tuple:
    """Map a 0-10 quality score to a (letter grade, verdict label) pair."""
    if score >= 9.0:
        return "A", "EXCELLENT"
    if score >= 7.5:
        return "B", "GOOD"
    if score >= 6.0:
        return "C", "FAIR"
    if score >= 4.0:
        return "D", "POOR"
    if score >= 2.0:
        return "F", "VULNERABLE"
    return "F-", "CRITICAL"


def _risk(score: float) -> str:
    """Map a 0-10 quality score to the legacy risk label (kept for API compatibility)."""
    if score <= 3.0:
        return "critical"
    if score <= 5.0:
        return "high"
    if score <= 6.5:
        return "medium"
    if score <= 8.5:
        return "low"
    return "safe"


def analyze_code(code: str, language: str) -> dict:
    """Analyze `code` with the fine-tuned model plus static regex rules.

    Blends the model's vulnerable-class confidence with rule-based findings
    into a 0-10 quality score (10 = best), a letter grade, a legacy risk
    label, and up to three remediation tips.
    """
    language = language.lower().strip()
    # ── Model inference ────────────────────────────────────
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1).squeeze().tolist()
    if not isinstance(probs, list):  # squeeze() collapsed to a scalar
        probs = [probs, 1 - probs]
    # Assumes label 0 = clean, label 1 = vulnerable — TODO confirm against
    # the model's config.id2label.
    clean_conf = round(probs[0] * 100, 1)
    vuln_conf = round(probs[1] * 100, 1)
    # ── Static rule scan ───────────────────────────────────
    findings, matched_categories = _scan_rules(code, language)
    # ── Code quality score (10 = perfect, 0 = critical) ────
    # Ignore the model signal when its two class confidences are within
    # 15 points of each other (too uncertain to penalize).
    model_uncertain = abs(vuln_conf - clean_conf) < 15.0
    model_penalty = 0.0 if model_uncertain else (vuln_conf / 100) * 4.0
    findings_penalty = min(len(findings) * 1.2, 6.0)  # 1.2 per finding, capped at 6
    total_deduction = round(min(model_penalty + findings_penalty, 10.0), 1)
    quality_score = round(max(10.0 - total_deduction, 0.0), 1)
    grade, verdict = _grade(quality_score)
    risk = _risk(quality_score)
    # ── Smart tips ─────────────────────────────────────────
    # sorted() makes tip order deterministic across runs; iterating the raw
    # set varied with string-hash randomization, so tips[:3] could differ
    # between identical requests.
    tips = [SMART_TIPS[c] for c in sorted(matched_categories) if c in SMART_TIPS]
    general = CODE_TIPS.get(language, ["Follow security best practices."])
    tips += random.sample(general, min(max(0, 3 - len(tips)), len(general)))
    return {
        "score": quality_score,  # 10 = best, 0 = worst
        "grade": grade,
        "risk": risk,
        "verdict": verdict,
        "clean_confidence": clean_conf,
        "vuln_confidence": vuln_conf,
        "findings": findings,
        "tips": tips[:3],
        "language": language,
        "lines_analyzed": len(code.splitlines()),
    }
# ── FastAPI app ────────────────────────────────────────────
app = FastAPI(
    title = "PolyGuard API",
    description = "Code vulnerability scanner powered by CodeBERT fine-tuned on CrossVUL.",
    version = "3.0.0",
)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory — browsers reject a wildcard origin on credentialed requests,
# and Starlette cannot echo the caller's origin here. Confirm whether
# credentialed cross-origin calls are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins = ["*"],
    allow_credentials = True,
    allow_methods = ["*"],
    allow_headers = ["*"],
)
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze: one code snippet plus its language."""
    code: str      # raw source code to scan
    language: str  # e.g. "python" — lowercased/stripped by analyze_code()
class SnippetItem(BaseModel):
    """One snippet inside a POST /analyze_batch request."""
    code: str      # raw source code to scan
    language: str  # e.g. "python" — lowercased/stripped by analyze_code()
class BatchRequest(BaseModel):
    """Request body for POST /analyze_batch: a list of snippets to score."""
    snippets: List[SnippetItem]
@app.get("/")
def index():
    """Root endpoint: service metadata and a directory of endpoints."""
    return {
        "service": "PolyGuard — Code Vulnerability Scanner",
        "version": "3.0.0",
        "model": MODEL_ID,
        "status": "running",
        "docs": "/docs",
        "supported_languages": SUPPORTED_LANGUAGES,
        "endpoints": ["/health", "/analyze", "/analyze_batch"],
    }
@app.get("/health")
def health():
    """Liveness probe: reports the loaded model id and supported languages."""
    payload = {
        "status": "ok",
        "model": MODEL_ID,
        "supported_languages": SUPPORTED_LANGUAGES,
    }
    return payload
@app.post("/analyze")
def analyze(req: AnalyzeRequest):
    """Run the full model + static-rule pipeline on a single snippet."""
    report = analyze_code(req.code, req.language)
    return report
@app.post("/analyze_batch")
def analyze_batch(req: BatchRequest):
    """Analyze every snippet in the batch; returns a count plus per-snippet reports."""
    reports = [analyze_code(item.code, item.language) for item in req.snippets]
    return {"count": len(reports), "results": reports}