jugalgajjar commited on
Commit
55be89a
·
verified ·
1 Parent(s): b8fe0f6

create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ # Model Configuration
6
+ model_id = "jugalgajjar/PyJavaCPP-Vuln-Fixer"
7
+
8
+ # Load tokenizer and model
9
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_id,
12
+ dtype=torch.float32,
13
+ device_map="cpu"
14
+ )
15
+
16
+ SYSTEM_MESSAGE = (
17
+ "You are a code security expert. Given vulnerable source code, "
18
+ "output ONLY the fixed version of the code with the vulnerability repaired. "
19
+ "Do not include explanations, just the corrected code."
20
+ )
21
+
22
+ # Prediction
23
+ def fix_code(language, vulnerable_code):
24
+ if not vulnerable_code.strip():
25
+ return "Please enter the code you want to fix."
26
+
27
+ messages = [
28
+ {"role": "system", "content": SYSTEM_MESSAGE},
29
+ {"role": "user", "content": f"Fix the below given vulnerable {language} code:\n{vulnerable_code}"},
30
+ ]
31
+
32
+ prompt = tokenizer.apply_chat_template(
33
+ messages,
34
+ tokenize=False,
35
+ add_generation_prompt=True,
36
+ )
37
+
38
+ # Encode inputs to CPU
39
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
40
+
41
+ with torch.no_grad():
42
+ outputs = model.generate(
43
+ **inputs,
44
+ max_new_tokens=1024,
45
+ temperature=0.2,
46
+ top_p=0.95,
47
+ do_sample=True,
48
+ repetition_penalty=1.15,
49
+ )
50
+
51
+ # Decode only the new tokens
52
+ new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
53
+ result = tokenizer.decode(new_tokens, skip_special_tokens=True)
54
+
55
+ # Extract the fixed code inside ```
56
+ start_idx = result.find("```")
57
+ end_idx = result.rfind("```")
58
+ if start_idx != -1 and end_idx != -1:
59
+ fixed_code = result[start_idx + 3 : end_idx]
60
+ return fixed_code.strip()
61
+ else:
62
+ return result.strip()
63
+
64
+ # UI Layout
65
+ with gr.Blocks(title="PyJavaCPP Vuln-Fixer", theme=gr.themes.Base()) as demo:
66
+ gr.Markdown("# 🛡️ PyJavaCPP Vulnerability Fixer (CPU)")
67
+ gr.Markdown("Select your language, paste your code, and get a secured version of your code!")
68
+
69
+ with gr.Row():
70
+ with gr.Column():
71
+ lang_input = gr.Dropdown(
72
+ choices=["python", "java", "cpp"],
73
+ value="python",
74
+ label="Target Language"
75
+ )
76
+ code_input = gr.Code(
77
+ label="Vulnerable Code",
78
+ language="python",
79
+ lines=12
80
+ )
81
+ submit_btn = gr.Button("Secure My Code ✨", variant="primary")
82
+
83
+ with gr.Column():
84
+ code_output = gr.Code(
85
+ label="Fixed Code",
86
+ language="python",
87
+ lines=12,
88
+ interactive=False
89
+ )
90
+
91
+ # Example Snippets for quick testing
92
+ gr.Examples(
93
+ examples=[
94
+ ["python", r"""import os
95
+ from flask import Flask, request
96
+
97
+ app = Flask(__name__)
98
+
99
+ @app.route("/run")
100
+ def run():
101
+ cmd = request.args.get("cmd")
102
+ # Vulnerable: Command Injection
103
+
104
+ return os.popen(cmd).read()
105
+
106
+ if __name__ == "__main__":
107
+ app.run(debug=False)"""],
108
+
109
+ ["java", r"""import java.sql.*;
110
+ import javax.servlet.http.*;
111
+
112
+ public class UserServlet extends HttpServlet {
113
+ public void doGet(HttpServletRequest req, HttpServletResponse res) {
114
+ try {
115
+ String id = req.getParameter("id");
116
+ Connection conn = DriverManager.getConnection("jdbc:mysql://localhost/db", "user", "pass");
117
+ Statement stmt = conn.createStatement();
118
+ // Vulnerable: SQL Injection
119
+ ResultSet rs = stmt.executeQuery("SELECT * FROM users WHERE id='" + id + "'");
120
+ } catch (Exception e) {
121
+ e.printStackTrace();
122
+ }
123
+ }
124
+
125
+ public static void main(String[] args) {
126
+ System.out.println("Servlet loaded.");
127
+ }
128
+ }"""],
129
+
130
+ ["cpp", r"""#include <iostream>
131
+ #include <cstring>
132
+
133
+ void login(char *input) {
134
+ char password[8];
135
+ // Vulnerable: Buffer Overflow
136
+ strcpy(password, input);
137
+ }
138
+
139
+ int main(int argc, char *argv[]) {
140
+ if (argc > 1) {
141
+ login(argv[1]);
142
+ }
143
+
144
+ return 0;
145
+ }"""]
146
+ ],
147
+ inputs=[lang_input, code_input]
148
+ )
149
+
150
+ # Update syntax highlighting based on dropdown
151
+ def update_syntax(lang):
152
+ return gr.update(language=lang), gr.update(language=lang)
153
+
154
+ lang_input.change(update_syntax, lang_input, [code_input, code_output])
155
+ submit_btn.click(fix_code, [lang_input, code_input], code_output)
156
+
157
+ demo.launch()