# polyguard-api / app.py
# Author: MUHAMMADSAADAMIN
# Commit: 7a50c62 (verified) β€” "done"
import os, re, random
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── Load model from HF Hub ─────────────────────────────────
# Downloaded once at import time; both objects are module-level globals
# consumed by analyze_code() below. Startup blocks until the download
# (or local cache hit) completes.
MODEL_ID = "MUHAMMADSAADAMIN/PolyGuard"
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Sequence classifier; presumably 2 labels (clean vs vulnerable) β€” the
# scoring code below assumes index 0 = clean, 1 = vulnerable. TODO confirm
# against the model's config.id2label.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode: disables dropout etc.
print("βœ“ Model ready")
# ── Vulnerability rules ────────────────────────────────────
# Per-language static heuristics: {language: [(regex_pattern, advice), ...]}.
# analyze_code() matches each pattern case-insensitively against the raw
# snippet and surfaces the advice string verbatim in "findings".
VULN_RULES = {
    "python": [
        # SQL built via f-string / %-interpolation inside execute()
        (r"execute\s*\(\s*[f'\"].*?\{", "Use parameterized queries instead of building SQL strings manually."),
        (r"execute\s*\(\s*\".*?%", "Use parameterized queries instead of building SQL strings manually."),
        # Arbitrary code execution sinks
        (r"eval\s*\(", "Avoid eval() β€” it executes arbitrary code and is a critical security risk."),
        (r"exec\s*\(", "Avoid exec() β€” it executes arbitrary code and is a critical security risk."),
        (r"pickle\.loads?\s*\(", "Avoid pickle.load() on untrusted data β€” it can execute arbitrary code."),
        # Shell / command injection
        (r"subprocess.*shell\s*=\s*True", "Never use shell=True in subprocess β€” use a list of arguments instead."),
        (r"os\.system\s*\(", "Avoid os.system() β€” use subprocess with a list of arguments instead."),
        # Weak cryptography / weak randomness
        (r"hashlib\.md5\s*\(", "MD5 is cryptographically broken β€” use SHA-256 or stronger."),
        (r"hashlib\.sha1\s*\(", "SHA-1 is weak β€” use SHA-256 or stronger."),
        (r"random\.(random|randint|choice)\s*\(", "Use secrets module instead of random for security-sensitive operations."),
        # Unsafe file / input / template handling
        (r"open\s*\(.*['\"]w['\"]", "Validate file paths before writing to prevent path traversal attacks."),
        (r"request\.(args|form|json)\[", "Validate and sanitize all user input before use."),
        (r"render_template_string\s*\(", "Avoid render_template_string with user input β€” use template files instead."),
        (r"yaml\.load\s*\(", "Use yaml.safe_load() instead of yaml.load() to prevent code execution."),
        (r"SSL_VERIFY\s*=\s*False|verify\s*=\s*False", "Never disable SSL verification in production."),
        # Hardcoded credentials / secrets
        (r"password\s*=\s*['\"][^'\"]{1,20}['\"]", "Hardcoded password detected β€” use environment variables instead."),
        (r"secret\s*=\s*['\"][^'\"]{1,20}['\"]", "Hardcoded secret detected β€” use environment variables instead."),
        (r"api_key\s*=\s*['\"][^'\"]+['\"]", "Hardcoded API key detected β€” use environment variables instead."),
    ],
    "javascript": [
        (r"eval\s*\(", "Avoid eval() β€” it executes arbitrary code and is a critical security risk."),
        (r"innerHTML\s*=", "Avoid innerHTML β€” use textContent or DOMPurify to prevent XSS."),
        (r"document\.write\s*\(", "Avoid document.write() β€” it can lead to XSS vulnerabilities."),
        (r"dangerouslySetInnerHTML", "Avoid dangerouslySetInnerHTML β€” sanitize content with DOMPurify first."),
        (r"localStorage\.setItem.*password", "Never store passwords or secrets in localStorage."),
        (r"Math\.random\s*\(", "Use crypto.getRandomValues() instead of Math.random() for security tokens."),
        (r"http://", "Use HTTPS instead of HTTP for all external requests."),
        (r"password\s*=\s*['\"][^'\"]+['\"]", "Hardcoded password detected β€” use environment variables instead."),
    ],
    "sql": [
        (r"'\s*\+\s*", "String concatenation in SQL is vulnerable to injection β€” use parameterized queries."),
        (r"GRANT ALL", "Avoid GRANT ALL β€” apply principle of least privilege."),
        (r"DROP TABLE", "Dangerous DDL statement detected β€” ensure proper access controls."),
        (r"xp_cmdshell", "xp_cmdshell is a critical security risk β€” disable it on the server."),
    ],
    "php": [
        (r"mysql_query\s*\(", "mysql_* functions are deprecated β€” use PDO or mysqli with prepared statements."),
        (r"\$_GET\[|\$_POST\[|\$_REQUEST\[", "Sanitize all user input from $_GET/$_POST/$_REQUEST before use."),
        (r"eval\s*\(", "Avoid eval() β€” it executes arbitrary code."),
        (r"system\s*\(|exec\s*\(", "Avoid system()/exec() with user input β€” use escapeshellarg()."),
        (r"md5\s*\(", "MD5 is not suitable for password hashing β€” use password_hash() instead."),
    ],
}
# Generic, non-security quality tips per language. analyze_code() samples
# from these at random to pad the "tips" list up to three entries when few
# vulnerability categories were matched.
CODE_TIPS = {
    "python": ["Use list comprehensions instead of for loops.", "Use f-strings for string formatting.", "Use with open() for file handling.", "Add type hints to function signatures.", "Use logging instead of print() in production.", "Use dataclasses or Pydantic instead of plain dicts."],
    "javascript": ["Use const and let instead of var.", "Use async/await instead of callback chains.", "Use strict equality (===) instead of ==.", "Prefer arrow functions for concise syntax."],
    "sql": ["Always use parameterized queries.", "Add indexes on frequently queried columns.", "Use EXPLAIN to analyze query performance."],
    "php": ["Use Composer for dependency management.", "Enable strict_types=1 at the top of files.", "Use prepared statements for all database queries."],
}
# Targeted remediation advice keyed by vulnerability category. Categories
# are inferred in analyze_code() from keywords in matched rule messages.
SMART_TIPS = {
    "sql_injection": "Use parameterized queries e.g. cursor.execute(query, params) to prevent SQL injection.",
    "code_execution": "Replace eval()/exec() with ast.literal_eval() for safe expression parsing.",
    "hardcoded_secrets": "Store secrets in environment variables and load with os.environ.get(KEY).",
    "weak_crypto": "Replace MD5/SHA1 with hashlib.sha256() or bcrypt for password hashing.",
    "xss": "Sanitize output with DOMPurify or use textContent instead of innerHTML.",
    "command_injection": "Pass arguments as a list to subprocess.run() and never use shell=True.",
}
# Languages advertised by the API. Only the VULN_RULES keys get a static
# rule scan; the extra languages rely on the model score alone.
SUPPORTED_LANGUAGES = list(VULN_RULES.keys()) + ["java", "c", "cpp", "go", "ruby", "rust"]
def analyze_code(code: str, language: str) -> dict:
    """Score a code snippet for vulnerabilities and return a report dict.

    Combines two signals:
      1. The fine-tuned sequence classifier's clean/vulnerable probabilities.
      2. A regex rule scan (VULN_RULES) for the requested language.

    Args:
        code: Source code to analyze.
        language: Language name, matched case-insensitively against VULN_RULES.

    Returns:
        dict with quality score (10 = best, 0 = worst), grade, risk label,
        verdict, model confidences, rule findings, up to three tips, the
        normalized language, and the number of lines analyzed.
    """
    language = language.lower().strip()
    # ── Model inference ────────────────────────────────────
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1).squeeze().tolist()
    if not isinstance(probs, list):
        # Edge case: squeeze() collapsed to a scalar (single logit).
        probs = [probs, 1 - probs]
    # NOTE(review): assumes label 0 = clean, label 1 = vulnerable β€” confirm
    # against the model's config.id2label.
    clean_conf = round(probs[0] * 100, 1)
    vuln_conf = round(probs[1] * 100, 1)
    # ── Static rule scan ───────────────────────────────────
    findings, matched_categories = [], set()
    for pattern, message in VULN_RULES.get(language, []):
        if re.search(pattern, code, re.IGNORECASE):
            if message not in findings:
                findings.append(message)
                msg = message.lower()  # hoisted: lower() once per message
                # Map the finding to a smart-tip category via keywords.
                if "sql" in msg or "query" in msg:
                    matched_categories.add("sql_injection")
                elif "eval" in msg or "exec" in msg:
                    matched_categories.add("code_execution")
                # BUGFIX: rule messages say "API key" (with a space), so the
                # old keyword "api_key" never matched and API-key findings
                # got no hardcoded_secrets tip; check both spellings.
                elif any(w in msg for w in ["password", "secret", "api_key", "api key"]):
                    matched_categories.add("hardcoded_secrets")
                elif any(w in msg for w in ["md5", "sha1", "hash"]):
                    matched_categories.add("weak_crypto")
                elif "xss" in msg or "innerhtml" in msg:
                    matched_categories.add("xss")
                elif "shell" in msg or "subprocess" in msg:
                    matched_categories.add("command_injection")
    # ── Code quality score (10 = perfect, 0 = critical) ───
    # Start at 10, deduct for each issue found. If the model can't separate
    # the classes by at least 15 points, ignore it and let the rule findings
    # drive the score.
    model_uncertain = abs(vuln_conf - clean_conf) < 15.0
    # Deductions: model contributes up to 4 points, findings up to 6.
    model_penalty = (vuln_conf / 100) * 4.0 if not model_uncertain else 0.0
    findings_penalty = min(len(findings) * 1.2, 6.0)
    total_deduction = round(min(model_penalty + findings_penalty, 10.0), 1)
    quality_score = round(max(10.0 - total_deduction, 0.0), 1)
    # ── Grade label ────────────────────────────────────────
    if quality_score >= 9.0:
        grade, verdict = "A", "EXCELLENT"
    elif quality_score >= 7.5:
        grade, verdict = "B", "GOOD"
    elif quality_score >= 6.0:
        grade, verdict = "C", "FAIR"
    elif quality_score >= 4.0:
        grade, verdict = "D", "POOR"
    elif quality_score >= 2.0:
        grade, verdict = "F", "VULNERABLE"
    else:
        grade, verdict = "F-", "CRITICAL"
    # ── Risk label (kept for API compatibility) ────────────
    if quality_score <= 3.0: risk = "critical"
    elif quality_score <= 5.0: risk = "high"
    elif quality_score <= 6.5: risk = "medium"
    elif quality_score <= 8.5: risk = "low"
    else: risk = "safe"
    # ── Smart tips ─────────────────────────────────────────
    # sorted(): set iteration order is non-deterministic, which made the
    # tip selection (and the tips[:3] truncation) vary between runs.
    tips = [SMART_TIPS[c] for c in sorted(matched_categories) if c in SMART_TIPS]
    general = CODE_TIPS.get(language, ["Follow security best practices."])
    tips += random.sample(general, min(max(0, 3 - len(tips)), len(general)))
    return {
        "score": quality_score,  # 10 = best, 0 = worst
        "grade": grade,
        "risk": risk,
        "verdict": verdict,
        "clean_confidence": clean_conf,
        "vuln_confidence": vuln_conf,
        "findings": findings,
        "tips": tips[:3],
        "language": language,
        "lines_analyzed": len(code.splitlines()),
    }
# ── FastAPI app ────────────────────────────────────────────
app = FastAPI(
    title="PolyGuard API",
    description="Code vulnerability scanner powered by CodeBERT fine-tuned on CrossVUL.",
    version="3.0.0",
)

# Allow browser clients from any origin to call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# normally rejected by browsers (the CORS spec forbids wildcard origins with
# credentials) β€” confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class AnalyzeRequest(BaseModel):
    # Request body for POST /analyze.
    code: str      # source code to scan
    language: str  # language name, matched case-insensitively
class SnippetItem(BaseModel):
    # One snippet inside a POST /analyze_batch request.
    code: str      # source code to scan
    language: str  # language name, matched case-insensitively
class BatchRequest(BaseModel):
    # Request body for POST /analyze_batch: a list of snippets to scan.
    snippets: List[SnippetItem]
@app.get("/")
def index():
    """Root endpoint: service metadata and endpoint directory."""
    return dict(
        service="PolyGuard β€” Code Vulnerability Scanner",
        version="3.0.0",
        model=MODEL_ID,
        status="running",
        docs="/docs",
        supported_languages=SUPPORTED_LANGUAGES,
        endpoints=["/health", "/analyze", "/analyze_batch"],
    )
@app.get("/health")
def health():
    """Liveness probe: reports status, model id, and supported languages."""
    payload = {"status": "ok", "model": MODEL_ID}
    payload["supported_languages"] = SUPPORTED_LANGUAGES
    return payload
@app.post("/analyze")
def analyze(req: AnalyzeRequest):
    """Analyze one code snippet and return its vulnerability report."""
    report = analyze_code(req.code, req.language)
    return report
@app.post("/analyze_batch")
def analyze_batch(req: BatchRequest):
    """Analyze every snippet in the batch and return all reports."""
    reports = [analyze_code(item.code, item.language) for item in req.snippets]
    return {"count": len(reports), "results": reports}