File size: 4,634 Bytes
1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 55be89a 1fff7d6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | import textwrap
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model Configuration
# Hugging Face Hub repository holding the fine-tuned vulnerability-fixing model.
model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"

# Load tokenizer and model once, at import time. Weights are kept in float32
# and placed on the CPU.
# NOTE(review): `dtype=` is the newer spelling of `torch_dtype=`; confirm the
# pinned transformers release honours it.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", dtype=torch.float32)
# System-turn instruction sent with every chat: ask for fixed code only, no prose.
SYSTEM_MESSAGE = " ".join((
    "You are a code security expert. Given vulnerable source code,",
    "output ONLY the fixed version of the code with the vulnerability repaired.",
    "Do not include explanations, just the corrected code.",
))
# Prediction
def _extract_fixed_code(text):
    """Return the code inside the model reply's outermost ``` fence pair.

    Everything between the first and the last fence is kept. A leading fence
    info string such as ``python`` or ``cpp`` is dropped so it does not end up
    in the returned code. With fewer than two fences (no fenced block at all,
    or e.g. a reply cut off after the opening fence) the whole stripped text is
    returned instead of an empty string.
    """
    import re

    start = text.find("```")
    end = text.rfind("```")
    # end <= start covers both "no fence" (-1, -1) and "exactly one fence"
    # (start == end); slicing between them would yield an empty string.
    if start == -1 or end <= start:
        return text.strip()

    body = text[start + 3 : end]
    first_line, newline, rest = body.partition("\n")
    # The opening fence line may carry a language tag ("python", "c++", ...)
    # or nothing at all; drop it, but keep real code that starts on that line.
    if newline and re.fullmatch(r"[\w+#.-]*", first_line.strip()):
        body = rest
    return body.strip()


def fix_code(language, vulnerable_code):
    """Ask the model for a repaired version of *vulnerable_code*.

    Args:
        language: Language name inserted into the user prompt, e.g. ``"python"``.
        vulnerable_code: Source code to repair. ``None`` or whitespace-only
            input is rejected with a short message instead of running the model.

    Returns:
        The fixed code taken from the model's reply (the contents of its
        fenced code block when there is one, otherwise the whole reply), or a
        message asking the user to enter code.
    """
    if not vulnerable_code or not vulnerable_code.strip():
        return "Please enter the code you want to fix."

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}"},
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize and move the tensors to the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.15,
        )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _extract_fixed_code(result)
# Sample inputs for gr.Examples: each entry is [language, vulnerable source].
# Every snippet contains a deliberate flaw for the model to repair:
#   python - request parameter passed straight to os.popen (command injection)
#   java   - request parameter concatenated into a SQL query (SQL injection)
#   cpp    - unbounded strcpy into an 8-byte stack buffer (buffer overflow)
# The snippets are indented to match this file and textwrap.dedent strips the
# common margin, so each one is valid, properly indented source.
EXAMPLES = [
    [
        "python",
        textwrap.dedent("""\
            import os
            from flask import Flask, request

            app = Flask(__name__)

            @app.route("/run")
            def run():
                cmd = request.args.get("cmd")
                return os.popen(cmd).read()

            if __name__ == "__main__":
                app.run(debug=False)"""),
    ],
    [
        "java",
        textwrap.dedent("""\
            import java.sql.*;
            import javax.servlet.http.*;

            public class UserServlet extends HttpServlet {
                public void doGet(HttpServletRequest req, HttpServletResponse res) {
                    try {
                        String id = req.getParameter("id");
                        Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
                        Statement stmt = conn.createStatement();
                        ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }"""),
    ],
    [
        "cpp",
        textwrap.dedent("""\
            #include <iostream>
            #include <cstring>

            void login(char *input) {
                char password[8];
                strcpy(password, input);
            }

            int main(int argc, char *argv[]) {
                if (argc > 1) {
                    login(argv[1]);
                }
                return 0;
            }"""),
    ],
]
# UI Layout
# Two-column Gradio app: language, code and submit button on the left, the
# model's answer on the right, with clickable examples underneath.
with gr.Blocks(title="PyJavaCPP Vuln-Fixer") as demo:
    gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
    gr.Markdown("Select your language, paste your code, and get a secured version of your code!")

    with gr.Row():
        with gr.Column():
            lang_input = gr.Dropdown(
                label="Target Language",
                choices=["python", "java", "cpp"],
                value="python",
            )
            code_input = gr.Textbox(
                label="Vulnerable Code",
                placeholder="Paste your vulnerable code here...",
                lines=15,
                max_lines=30,
            )
            submit_btn = gr.Button("Secure My Code ✨", variant="primary")

        with gr.Column():
            # Read-only box that displays whatever fix_code returns.
            code_output = gr.Textbox(
                label="Fixed Code",
                interactive=False,
                lines=15,
                max_lines=30,
            )

    # No fn is given, so picking an example only fills the two inputs.
    gr.Examples(examples=EXAMPLES, inputs=[lang_input, code_input])

    submit_btn.click(fn=fix_code, inputs=[lang_input, code_input], outputs=code_output)

# Start the server with server-side rendering turned off.
demo.launch(ssr_mode=False)