| import textwrap |
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| |
# Hugging Face Hub repository that provides both the tokenizer and the weights.
model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"


def _load_model_and_tokenizer(repo_id):
    """Load the tokenizer and causal-LM model for *repo_id*.

    The tokenizer is loaded first, then the model with float32 weights placed
    on the CPU via ``device_map="cpu"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(repo_id)
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        dtype=torch.float32,
        device_map="cpu",
    )
    return tok, lm


# Loaded once at import time; fix_code() reuses these module-level objects.
tokenizer, model = _load_model_and_tokenizer(model_id)
|
|
# Instruction sent as the chat "system" turn on every request: the model is asked
# to answer with the corrected code only, without any explanation.
SYSTEM_MESSAGE = " ".join(
    (
        "You are a code security expert. Given vulnerable source code,",
        "output ONLY the fixed version of the code with the vulnerability repaired.",
        "Do not include explanations, just the corrected code.",
    )
)
|
|
| |
def _extract_code(text):
    """Return the code inside a Markdown code fence in *text*, or *text* itself.

    The extracted span starts on the line after the first ``` fence, so an
    info string such as ``python`` or ``cpp`` on the opening fence line is not
    included, and ends at the last ``` fence. If there is no closing fence
    after the opening one (e.g. the reply was cut off at ``max_new_tokens``),
    everything after the opening fence line is returned rather than an empty
    string. Text without any fence is returned unchanged apart from stripping.
    Leading and trailing whitespace is stripped in every case.
    """
    start = text.find("```")
    if start == -1:
        return text.strip()

    # Skip the rest of the opening fence line (the optional language tag).
    line_end = text.find("\n", start + 3)
    body_start = start + 3 if line_end == -1 else line_end + 1

    end = text.rfind("```")
    if end < body_start:
        # Only an opening fence: keep everything that follows it.
        end = len(text)
    return text[body_start:end].strip()


def fix_code(language, vulnerable_code):
    """Ask the model to repair *vulnerable_code* and return the fixed code.

    Args:
        language: Language name inserted into the user prompt; the UI dropdown
            supplies "python", "java" or "cpp".
        vulnerable_code: Source code to repair. ``None``, empty, or
            whitespace-only input is rejected without calling the model.

    Returns:
        The model's reply with any Markdown code fence (including its language
        tag) removed, or a short message asking for input when none was given.
    """
    # Gradio normally sends "" for an empty textbox, but API callers may send None.
    if not vulnerable_code or not vulnerable_code.strip():
        return "Please enter the code you want to fix."

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}"},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered template already contains whatever special tokens it defines
    # (e.g. BOS). Re-tokenizing it with the default add_special_tokens=True can
    # insert them a second time, which the Transformers chat-template docs warn
    # against.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.2,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.15,
        )

    # generate() returns prompt + completion; keep only the newly generated tokens.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    result = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _extract_code(result)
|
|
# Deliberately vulnerable sample programs offered as one-click examples in the UI,
# keyed by the value the language dropdown uses. Each snippet is indented to fit
# this file's layout and dedented when EXAMPLES is built.
_EXAMPLE_SNIPPETS = {
    # Command injection: a query-string value is passed straight to os.popen().
    "python": """\
        import os
        from flask import Flask, request

        app = Flask(__name__)

        @app.route("/run")
        def run():
            cmd = request.args.get("cmd")
            return os.popen(cmd).read()

        if __name__ == "__main__":
            app.run(debug=False)""",
    # SQL injection: a request parameter is concatenated into the query string.
    "java": """\
        import java.sql.*;
        import javax.servlet.http.*;

        public class UserServlet extends HttpServlet {
            public void doGet(HttpServletRequest req, HttpServletResponse res) {
                try {
                    String id = req.getParameter("id");
                    Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
                    Statement stmt = conn.createStatement();
                    ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }""",
    # Buffer overflow: argv[1] is strcpy'd into an 8-byte stack array.
    "cpp": """\
        #include <iostream>
        #include <cstring>

        void login(char *input) {
            char password[8];
            strcpy(password, input);
        }

        int main(int argc, char *argv[]) {
            if (argc > 1) {
                login(argv[1]);
            }
            return 0;
        }""",
}

# gr.Examples rows: [language, code] pairs in dropdown order.
EXAMPLES = [
    [lang, textwrap.dedent(snippet)]
    for lang, snippet in _EXAMPLE_SNIPPETS.items()
]
|
|
| |
# Both code panes share the same height limits so they line up side by side.
_CODE_BOX_SIZE = {"lines": 15, "max_lines": 30}

with gr.Blocks(title="PyJavaCPP Vuln-Fixer") as demo:
    gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
    gr.Markdown(
        "Select your language, paste your code, and get a secured version of your code!"
    )

    with gr.Row():
        # Left column: language picker, code input and the submit button.
        with gr.Column():
            lang_input = gr.Dropdown(
                label="Target Language",
                choices=["python", "java", "cpp"],
                value="python",
            )
            code_input = gr.Textbox(
                label="Vulnerable Code",
                placeholder="Paste your vulnerable code here...",
                **_CODE_BOX_SIZE,
            )
            submit_btn = gr.Button("Secure My Code ✨", variant="primary")

        # Right column: read-only pane that receives fix_code()'s result.
        with gr.Column():
            code_output = gr.Textbox(
                label="Fixed Code",
                interactive=False,
                **_CODE_BOX_SIZE,
            )

    # No fn is passed, so selecting an example row only fills the two inputs.
    gr.Examples(inputs=[lang_input, code_input], examples=EXAMPLES)

    submit_btn.click(
        fn=fix_code,
        inputs=[lang_input, code_input],
        outputs=code_output,
    )
|
|
# Start the web server only when this file is executed as a script, so importing
# the module (from tests, or via Gradio's reload mode, which launches `demo`
# itself) does not block on a running server. ssr_mode=False turns off Gradio's
# server-side rendering.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)