DarthBihan committed on
Commit
f6d50d1
·
verified ·
1 Parent(s): 61db20b

Upload 7 files

Browse files
Files changed (6) hide show
  1. DockerFile +17 -0
  2. app.py +42 -32
  3. extensions.py +7 -0
  4. model.py +337 -264
  5. requirements.txt +28 -34
  6. schemas.py +34 -2
DockerFile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt requirements.txt
6
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ RUN useradd -m -u 1000 user
11
+ USER user
12
+ ENV HOME=/home/user \
13
+ PATH=/home/user/.local/bin:$PATH
14
+
15
+ EXPOSE 7860
16
+
17
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "app:app", "--timeout", "120"]
app.py CHANGED
@@ -9,13 +9,12 @@ import os
9
  import subprocess
10
  import json
11
  import time
 
12
  from pymongo import MongoClient
13
  from routes.reviews import reviews_bp
14
- from flask_limiter import Limiter
15
- from flask_limiter.util import get_remote_address
16
  from datetime import datetime
17
  from datetime import timedelta
18
- from bson import ObjectId
19
  from models.reviews import reviews_collection
20
  from schemas import ScanRequest
21
  from pydantic import ValidationError
@@ -26,9 +25,13 @@ load_dotenv()
26
 
27
  app = Flask(__name__)
28
  Compress(app)
29
- CORS(app)
30
 
31
- # Load environment variables
 
 
 
 
 
32
  env = os.getenv('FLASK_ENV', 'development')
33
  if env == 'development':
34
  load_dotenv('.env.development')
@@ -47,9 +50,7 @@ else:
47
  "CACHE_DEFAULT_TIMEOUT": 3600
48
  })
49
 
50
- cache.init_app(app)
51
-
52
- cache.init_app(app)
53
 
54
  def files_hash(files):
55
  h = hashlib.sha256()
@@ -75,10 +76,6 @@ scan_history = db["scan_history"]
75
 
76
  app.register_blueprint(auth_bp, url_prefix="/api")
77
 
78
- limiter = Limiter(
79
- key_func=get_remote_address,
80
- default_limits=["20 per minute"]
81
- )
82
  limiter.init_app(app)
83
 
84
  @app.route('/api/scan', methods=['POST'])
@@ -87,7 +84,6 @@ limiter.init_app(app)
87
  def scan_code():
88
  try:
89
  data = request.get_json()
90
- app.logger.info(f"Incoming request: {data}")
91
 
92
  try:
93
  req = ScanRequest(**data)
@@ -95,11 +91,13 @@ def scan_code():
95
  app.logger.error(f"Validation error: {e.errors()}")
96
  return jsonify({"error": e.errors()}), 400
97
 
98
- files = [f.dict() for f in req.files]
99
  language = req.language.lower()
 
 
100
  username = get_jwt_identity()
101
 
102
- key = f"scan:{username}:{files_hash(files)}"
103
  cached = cache.get(key)
104
  if cached:
105
  return jsonify({"result": cached, "cached": True})
@@ -111,22 +109,28 @@ def scan_code():
111
  code_file.write(f["content"])
112
 
113
  if language == "python":
114
- scan_command = ["python", "-m", "bandit", "-r", temp_dir, "-f", "json"]
115
  elif language == "javascript":
116
- scan_command = ["python", "-m", "semgrep", "--config=p/javascript", "--json", temp_dir]
117
  else:
118
  return jsonify({"error": "Unsupported language"}), 400
119
 
120
  app.logger.info(f"Running: {' '.join(scan_command)}")
121
  result = subprocess.run(scan_command, capture_output=True, text=True)
122
 
123
- app.logger.info(f"stdout: {result.stdout[:500]}")
124
  app.logger.info(f"stderr: {result.stderr}")
125
 
126
  if result.returncode not in (0, 1, 2):
127
  return jsonify({"error": result.stderr}), 500
 
 
128
 
129
  try:
 
 
 
 
 
130
  output_json = json.loads(result.stdout)
131
  except Exception as e:
132
  return jsonify({"error": f"JSON parse failed: {str(e)}", "raw": result.stdout}), 500
@@ -147,7 +151,6 @@ def scan_code():
147
  return jsonify({"error": str(e)}), 500
148
 
149
 
150
-
151
  @app.route("/api/health")
152
  def health():
153
  return jsonify({"status": "ok"})
@@ -168,10 +171,8 @@ def enhance():
168
  if not code.strip():
169
  return jsonify({"error": "No code provided"}), 400
170
 
171
- # 🔹 New format (returns dict)
172
  result = enhance_code(code, language)
173
 
174
- # Save to history (with candidates + explanations)
175
  enhance_history.insert_one({
176
  "username": username,
177
  "code": code,
@@ -195,11 +196,9 @@ def history():
195
  try:
196
  username = get_jwt_identity()
197
 
198
- # Fetch both histories
199
  enhance_records = list(enhance_history.find({"username": username}).sort("timestamp", -1))
200
  scan_records = list(scan_history.find({"username": username}).sort("timestamp", -1))
201
 
202
- # Convert ObjectId to string & return only relevant fields
203
  def clean(record, record_type):
204
  return {
205
  "id": str(record.get("_id")),
@@ -207,8 +206,8 @@ def history():
207
  "code": record.get("code"),
208
  "enhanced_code": record.get("enhanced_code"),
209
  "diff": record.get("diff"),
210
- "candidates": record.get("candidates", []), # ✅ added
211
- "explanations": record.get("explanations", []), # ✅ added
212
  "result": record.get("result") if record_type == "scan" else None,
213
  "timestamp": record.get("timestamp"),
214
  }
@@ -257,21 +256,22 @@ def submit_review():
257
 
258
  except Exception as e:
259
  return jsonify({"error": str(e)}), 500
260
-
261
 
262
  @app.route("/api/enhance-stream", methods=["POST"])
 
263
  @jwt_required()
264
  def enhance_stream():
265
  data = request.get_json()
266
  code = data.get("code", "")
267
  language = data.get("language", "python")
 
268
 
269
  if not code.strip():
270
  return jsonify({"error": "No code"}), 400
271
 
272
  def generate():
273
  try:
274
- # 1️⃣ starting
275
  yield json.dumps({
276
  "type": "progress",
277
  "progress": 5
@@ -279,16 +279,27 @@ def enhance_stream():
279
 
280
  time.sleep(0.5)
281
 
282
- # 2️⃣ preprocessing
283
  yield json.dumps({
284
  "type": "progress",
285
  "progress": 20
286
  }) + "\n"
287
 
288
- # 3️⃣ heavy AI call
289
  result = enhance_code(code, language)
290
 
291
- # 4️⃣ done
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  yield json.dumps({
293
  "type": "result",
294
  "data": result
@@ -306,6 +317,5 @@ def enhance_stream():
306
  )
307
 
308
 
309
-
310
  if __name__ == '__main__':
311
- app.run(host="0.0.0.0", port=5000, debug=True)
 
9
  import subprocess
10
  import json
11
  import time
12
+ import sys
13
  from pymongo import MongoClient
14
  from routes.reviews import reviews_bp
15
+ from extensions import limiter
 
16
  from datetime import datetime
17
  from datetime import timedelta
 
18
  from models.reviews import reviews_collection
19
  from schemas import ScanRequest
20
  from pydantic import ValidationError
 
25
 
26
  app = Flask(__name__)
27
  Compress(app)
 
28
 
29
+ FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:5173")
30
+ allowed_origins = list({FRONTEND_URL, "http://localhost:5173", "http://localhost:3000", "http://localhost:8080"})
31
+ CORS(app, origins=allowed_origins)
32
+
33
+ app.config['MAX_CONTENT_LENGTH'] = 1 * 1024 * 1024
34
+
35
  env = os.getenv('FLASK_ENV', 'development')
36
  if env == 'development':
37
  load_dotenv('.env.development')
 
50
  "CACHE_DEFAULT_TIMEOUT": 3600
51
  })
52
 
53
+ cache.init_app(app)
 
 
54
 
55
  def files_hash(files):
56
  h = hashlib.sha256()
 
76
 
77
  app.register_blueprint(auth_bp, url_prefix="/api")
78
 
 
 
 
 
79
  limiter.init_app(app)
80
 
81
  @app.route('/api/scan', methods=['POST'])
 
84
  def scan_code():
85
  try:
86
  data = request.get_json()
 
87
 
88
  try:
89
  req = ScanRequest(**data)
 
91
  app.logger.error(f"Validation error: {e.errors()}")
92
  return jsonify({"error": e.errors()}), 400
93
 
94
+ files = [f.model_dump() for f in req.files]
95
  language = req.language.lower()
96
+ app.logger.info(f"Scan request: {len(files)} files, language={language}")
97
+
98
  username = get_jwt_identity()
99
 
100
+ key = f"scan:{username}:{language}:{files_hash(files)}"
101
  cached = cache.get(key)
102
  if cached:
103
  return jsonify({"result": cached, "cached": True})
 
109
  code_file.write(f["content"])
110
 
111
  if language == "python":
112
+ scan_command = [sys.executable, "-m", "bandit", "-r", temp_dir, "-f", "json"]
113
  elif language == "javascript":
114
+ scan_command = [sys.executable, "-m", "semgrep", "--config=p/javascript", "--json", temp_dir]
115
  else:
116
  return jsonify({"error": "Unsupported language"}), 400
117
 
118
  app.logger.info(f"Running: {' '.join(scan_command)}")
119
  result = subprocess.run(scan_command, capture_output=True, text=True)
120
 
 
121
  app.logger.info(f"stderr: {result.stderr}")
122
 
123
  if result.returncode not in (0, 1, 2):
124
  return jsonify({"error": result.stderr}), 500
125
+
126
+ app.logger.info("Scan completed successfully")
127
 
128
  try:
129
+ if not result.stdout.strip():
130
+ return jsonify({
131
+ "error": "Scanner produced no output",
132
+ "stderr": result.stderr
133
+ }), 500
134
  output_json = json.loads(result.stdout)
135
  except Exception as e:
136
  return jsonify({"error": f"JSON parse failed: {str(e)}", "raw": result.stdout}), 500
 
151
  return jsonify({"error": str(e)}), 500
152
 
153
 
 
154
  @app.route("/api/health")
155
  def health():
156
  return jsonify({"status": "ok"})
 
171
  if not code.strip():
172
  return jsonify({"error": "No code provided"}), 400
173
 
 
174
  result = enhance_code(code, language)
175
 
 
176
  enhance_history.insert_one({
177
  "username": username,
178
  "code": code,
 
196
  try:
197
  username = get_jwt_identity()
198
 
 
199
  enhance_records = list(enhance_history.find({"username": username}).sort("timestamp", -1))
200
  scan_records = list(scan_history.find({"username": username}).sort("timestamp", -1))
201
 
 
202
  def clean(record, record_type):
203
  return {
204
  "id": str(record.get("_id")),
 
206
  "code": record.get("code"),
207
  "enhanced_code": record.get("enhanced_code"),
208
  "diff": record.get("diff"),
209
+ "candidates": record.get("candidates", []),
210
+ "explanations": record.get("explanations", []),
211
  "result": record.get("result") if record_type == "scan" else None,
212
  "timestamp": record.get("timestamp"),
213
  }
 
256
 
257
  except Exception as e:
258
  return jsonify({"error": str(e)}), 500
259
+
260
 
261
  @app.route("/api/enhance-stream", methods=["POST"])
262
+ @limiter.limit("5/minute")
263
  @jwt_required()
264
  def enhance_stream():
265
  data = request.get_json()
266
  code = data.get("code", "")
267
  language = data.get("language", "python")
268
+ username = get_jwt_identity()
269
 
270
  if not code.strip():
271
  return jsonify({"error": "No code"}), 400
272
 
273
  def generate():
274
  try:
 
275
  yield json.dumps({
276
  "type": "progress",
277
  "progress": 5
 
279
 
280
  time.sleep(0.5)
281
 
 
282
  yield json.dumps({
283
  "type": "progress",
284
  "progress": 20
285
  }) + "\n"
286
 
 
287
  result = enhance_code(code, language)
288
 
289
+ try:
290
+ enhance_history.insert_one({
291
+ "username": username,
292
+ "code": code,
293
+ "language": language,
294
+ "enhanced_code": result.get("enhanced_code", ""),
295
+ "diff": result.get("diff", []),
296
+ "candidates": result.get("candidates", []),
297
+ "explanations": result.get("explanations", []),
298
+ "timestamp": datetime.utcnow().isoformat()
299
+ })
300
+ except Exception:
301
+ pass
302
+
303
  yield json.dumps({
304
  "type": "result",
305
  "data": result
 
317
  )
318
 
319
 
 
320
  if __name__ == '__main__':
321
+ app.run(host="0.0.0.0", port=5000, debug=True)
extensions.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from flask_limiter import Limiter
2
+ from flask_limiter.util import get_remote_address
3
+
4
+ limiter = Limiter(
5
+ key_func=get_remote_address,
6
+ default_limits=["20 per minute"]
7
+ )
model.py CHANGED
@@ -1,264 +1,337 @@
1
- import torch
2
- import difflib
3
- import re
4
- from transformers import (
5
- AutoTokenizer,
6
- AutoModelForSeq2SeqLM,
7
- AutoModelForCausalLM
8
- )
9
-
10
- # ----------------------------
11
- # Performance Settings
12
- # ----------------------------
13
-
14
- torch.set_num_threads(2)
15
- DEVICE = "cpu"
16
-
17
- # ----------------------------
18
- # Models and their types
19
- # ----------------------------
20
-
21
- MODEL_CONFIGS = {
22
- "Salesforce/codet5-base": "seq2seq", # CodeT5
23
- #"EleutherAI/gpt-neo-1.3B": "causal", # GPT-Neo (disabled due to free hosting for now; enable on local hosting)
24
- "microsoft/CodeGPT-small-py": "causal", # CodeGPT-small (Python)
25
- }
26
-
27
- # ----------------------------
28
- # Load tokenizers and models
29
- # ----------------------------
30
-
31
- tokenizers = {}
32
- models = {}
33
-
34
- print("🔹 Loading models...")
35
-
36
- for name, mtype in MODEL_CONFIGS.items():
37
- print(f"Loading {name} ...")
38
-
39
- tokenizers[name] = AutoTokenizer.from_pretrained(name)
40
-
41
- if mtype == "seq2seq":
42
- model = AutoModelForSeq2SeqLM.from_pretrained(name)
43
- else:
44
- model = AutoModelForCausalLM.from_pretrained(name)
45
-
46
- model.to(DEVICE)
47
- model.eval()
48
- models[name] = model
49
-
50
- print("✅ All models loaded")
51
-
52
- # ----------------------------
53
- # Rule-based fixes
54
- # ----------------------------
55
-
56
- SECURE_REPLACEMENTS = {
57
- "hashlib.md5": ("hashlib.sha256", "MD5 is weak, replaced with SHA-256."),
58
- "hashlib.sha1": ("hashlib.sha256", "SHA1 is weak, replaced with SHA-256."),
59
- "eval(": ("ast.literal_eval(", "Unsafe eval removed, replaced with safe literal_eval."),
60
- "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe, consider json/safe loaders."),
61
- }
62
-
63
- def rule_based_patch(code: str):
64
- explanations = []
65
- patched = code
66
-
67
- for bad, (good, reason) in SECURE_REPLACEMENTS.items():
68
- if bad in patched:
69
- patched = patched.replace(bad, good)
70
- explanations.append({
71
- "change": f"{bad} {good}",
72
- "reason": reason
73
- })
74
-
75
- return patched, explanations
76
-
77
- # ----------------------------
78
- # Structure preservation
79
- # ----------------------------
80
-
81
- def preserve_structure(original: str, enhanced: str):
82
- final_code = enhanced
83
-
84
- original_imports = [
85
- l for l in original.splitlines()
86
- if l.strip().startswith(("import ", "from "))
87
- ]
88
-
89
- for imp in original_imports:
90
- if imp not in final_code:
91
- final_code = imp + "\n" + final_code
92
-
93
- original_defs = [
94
- l for l in original.splitlines()
95
- if l.strip().startswith("def ")
96
- ]
97
-
98
- for d in original_defs:
99
- if d.split("(")[0] not in final_code:
100
- final_code = (
101
- d +
102
- "\n # [!] Function body missing, please review\n" +
103
- final_code
104
- )
105
-
106
- return final_code
107
-
108
- # ----------------------------
109
- # Diff creation
110
- # ----------------------------
111
-
112
- def create_diff(original: str, enhanced: str):
113
- diff_lines = difflib.unified_diff(
114
- original.splitlines(),
115
- enhanced.splitlines(),
116
- lineterm=""
117
- )
118
-
119
- formatted = []
120
-
121
- for line in diff_lines:
122
- if line.startswith("+") and not line.startswith("+++"):
123
-
124
- formatted.append({
125
- "type": "add",
126
- "content": line[1:]
127
- })
128
-
129
- elif line.startswith("-") and not line.startswith("---"):
130
-
131
- formatted.append({
132
- "type": "remove",
133
- "content": line[1:]
134
- })
135
-
136
- elif not line.startswith("@@"):
137
-
138
- formatted.append({
139
- "type": "context",
140
- "content": line
141
- })
142
-
143
- return formatted
144
-
145
- # ----------------------------
146
- # Postprocess output
147
- # ----------------------------
148
-
149
- def postprocess_code(code: str):
150
- code = re.sub(r'^"""|"""$', '', code.strip())
151
- lines = code.splitlines()
152
- return "\n".join(
153
- l.replace("\t", " ").rstrip()
154
- for l in lines
155
- )
156
-
157
- # ----------------------------
158
- # Run one model
159
- # ----------------------------
160
-
161
- def run_model(model_name, code, language):
162
-
163
- tokenizer = tokenizers[model_name]
164
- model = models[model_name]
165
- mtype = MODEL_CONFIGS[model_name]
166
-
167
- prompt = f"Fix security issues in this {language} code:\n{code}"
168
-
169
- if mtype == "seq2seq":
170
-
171
- inputs = tokenizer(
172
- prompt,
173
- return_tensors="pt",
174
- truncation=True,
175
- max_length=512
176
- ).to(DEVICE)
177
-
178
- outputs = model.generate(
179
- **inputs,
180
- max_new_tokens=512,
181
- num_beams=4
182
- )
183
-
184
- else:
185
-
186
- inputs = tokenizer(
187
- prompt,
188
- return_tensors="pt",
189
- truncation=True,
190
- max_length=512
191
- ).to(DEVICE)
192
-
193
- outputs = model.generate(
194
- **inputs,
195
- max_new_tokens=256,
196
- temperature=0.3,
197
- top_p=0.95,
198
- do_sample=False
199
- )
200
-
201
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
202
-
203
- # ----------------------------
204
- # Main enhancer
205
- # ----------------------------
206
-
207
- def enhance_code(code: str, language: str):
208
-
209
- with torch.no_grad():
210
-
211
- try:
212
- # 1️⃣ Rule-based fixes
213
- patched_code, rule_explanations = rule_based_patch(code)
214
-
215
- # 2️⃣ Model ensemble
216
- candidates = []
217
-
218
- for m in MODEL_CONFIGS.keys():
219
- try:
220
- enhanced = run_model(m, patched_code, language)
221
- enhanced = postprocess_code(enhanced)
222
- enhanced = preserve_structure(code, enhanced)
223
-
224
- candidates.append({
225
- "model": m,
226
- "code": enhanced
227
- })
228
-
229
- except Exception as e:
230
- candidates.append({
231
- "model": m,
232
- "code": f"# [!] Failed: {str(e)}"
233
- })
234
-
235
- # 3️⃣ Choose longest output as best
236
- best = max(candidates, key=lambda c: len(c["code"]))
237
-
238
- diff = create_diff(code, best["code"])
239
-
240
- explanations = rule_explanations + [{
241
- "change": "Model ensemble",
242
- "reason": "Best candidate selected from multiple models"
243
- }]
244
-
245
- return {
246
- "enhanced_code": best["code"],
247
- "diff": diff,
248
- "candidates": candidates[:3],
249
- "explanations": explanations
250
- }
251
-
252
- except Exception as e:
253
-
254
- fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
255
-
256
- return {
257
- "enhanced_code": fallback,
258
- "diff": create_diff(code, fallback),
259
- "candidates": [],
260
- "explanations": [{
261
- "change": "Error",
262
- "reason": str(e)
263
- }]
264
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import difflib
3
+ import re
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSeq2SeqLM,
7
+ AutoModelForCausalLM
8
+ )
9
+ import os
10
+
11
+ torch.set_num_threads(max(1, os.cpu_count() // 2))
12
+ DEVICE = "cpu"
13
+
14
+ MAX_CODE_CHARS = 8000
15
+
16
+ MODEL_CONFIGS = {
17
+ "Salesforce/codet5-base": "seq2seq", # CodeT5
18
+ #"EleutherAI/gpt-neo-1.3B": "causal", # GPT-Neo (disabled due to free hosting)
19
+ "microsoft/CodeGPT-small-py": "causal", # CodeGPT-small (Python)
20
+ }
21
+
22
+ tokenizers = {}
23
+ models = {}
24
+
25
+ def load_models():
26
+ for name, mtype in MODEL_CONFIGS.items():
27
+ print(f"Loading {name} ...")
28
+
29
+ tokenizers[name] = AutoTokenizer.from_pretrained(name, use_fast=False)
30
+
31
+ if mtype == "seq2seq":
32
+ model = AutoModelForSeq2SeqLM.from_pretrained(name)
33
+ else:
34
+ model = AutoModelForCausalLM.from_pretrained(name)
35
+
36
+ model.to(DEVICE)
37
+ model.eval()
38
+ models[name] = model
39
+
40
+ print("✅ All models loaded")
41
+
42
+ print("🔹 Loading models...")
43
+ _werkzeug_parent = (
44
+ os.environ.get("WERKZEUG_RUN_MAIN") is None
45
+ and os.environ.get("FLASK_DEBUG", "0") in ("1", "true", "True")
46
+ )
47
+ if not _werkzeug_parent:
48
+ load_models()
49
+
50
+ SECURE_REPLACEMENTS = {
51
+ # Weak hashing
52
+ "hashlib.md5": ("hashlib.sha256", "MD5 is cryptographically broken; replaced with SHA-256."),
53
+ "hashlib.sha1": ("hashlib.sha256", "SHA-1 is deprecated for security use; replaced with SHA-256."),
54
+ # Dangerous execution
55
+ "eval(": ("ast.literal_eval(", "eval() executes arbitrary code; use ast.literal_eval for safe parsing."),
56
+ "exec(": ("# exec() removed —", "exec() executes arbitrary code strings; remove or sandbox."),
57
+ # Insecure deserialization
58
+ "pickle.load(": ("# pickle.load UNSAFE —","pickle.load deserialises arbitrary objects; use json or safer alternatives."),
59
+ "pickle.loads(": ("# pickle.loads UNSAFE —","pickle.loads is an RCE risk; use json.loads instead."),
60
+ "yaml.load(": ("yaml.safe_load(", "yaml.load with arbitrary loader executes code; use yaml.safe_load."),
61
+ # Command injection
62
+ "os.system(": ("subprocess.run([", "os.system passes args to the shell; use subprocess.run with a list to avoid injection."),
63
+ "shell=True": ("shell=False", "shell=True enables command injection; pass args as a list with shell=False."),
64
+ # Insecure temp files
65
+ "tempfile.mktemp(": ("tempfile.mkstemp(", "mktemp() has a race condition; use mkstemp() which atomically creates the file."),
66
+ # Weak randomness
67
+ "random.random()": ("secrets.token_bytes(16)","random is not cryptographically secure; use the secrets module for security-sensitive values."),
68
+ "random.randint(": ("secrets.randbelow(", "random.randint is not cryptographically secure; use secrets.randbelow for security tokens."),
69
+ # Insecure TLS
70
+ "verify=False": ("verify=True", "Disabling TLS verification allows MITM attacks; always verify certificates."),
71
+ "ssl.CERT_NONE": ("ssl.CERT_REQUIRED", "ssl.CERT_NONE disables certificate validation entirely; use ssl.CERT_REQUIRED."),
72
+ # Debug/info leakage
73
+ "DEBUG = True": ("DEBUG = False", "Debug mode exposes stack traces and internal config; disable in production."),
74
+ "app.run(debug=True": ("app.run(debug=False", "Flask debug=True enables the Werkzeug debugger which allows arbitrary code execution."),
75
+ }
76
+
77
+ PYTHON_EXTRA_REPLACEMENTS = {
78
+ "% username": ("# Use parameterised query","String-formatted SQL allows injection; use parameterised queries with ? placeholders."),
79
+ "format(username": ("# Use parameterised query","String-formatted SQL allows injection; use parameterised queries."),
80
+ "http://": ("https://", "Unencrypted HTTP transmits data in plaintext; upgrade to HTTPS."),
81
+ }
82
+
83
+ JS_SECURE_REPLACEMENTS = {
84
+ "innerHTML =": ("textContent =", "innerHTML enables XSS; use textContent to safely set plain text."),
85
+ "innerHTML+=": ("textContent+=", "innerHTML enables XSS; use textContent instead."),
86
+ "document.write(": ("// document.write removed —","document.write allows XSS injection; use DOM APIs instead."),
87
+ "eval(": ("JSON.parse(", "eval() executes arbitrary JavaScript; use JSON.parse for data or a safe alternative."),
88
+ "Math.random()": ("crypto.getRandomValues(", "Math.random is not cryptographically secure; use crypto.getRandomValues."),
89
+ "http://": ("https://", "Unencrypted HTTP in JS code; upgrade to HTTPS."),
90
+ "dangerouslySetInnerHTML": ("// dangerouslySetInnerHTML review needed","dangerouslySetInnerHTML bypasses React's XSS protection; sanitise input with DOMPurify first."),
91
+ "localStorage.setItem": ("// Consider sessionStorage ","localStorage persists indefinitely; prefer sessionStorage for sensitive session data."),
92
+ }
93
+
94
+ SECRET_PATTERNS = [
95
+ (
96
+ r'(?i)(?:password|passwd|pwd)\s*=\s*["\'][^"\']{8,}["\']',
97
+ "Hardcoded password detected — move to an environment variable."
98
+ ),
99
+ (
100
+ r'(?i)(?:api_key|apikey|secret_key|secret|auth_token)\s*=\s*["\'][a-zA-Z0-9+/=_\-]{16,}["\']',
101
+ "Hardcoded API key or secret detected — move to an environment variable."
102
+ ),
103
+ (
104
+ r'(?:AKIA|ASIA)[A-Z0-9]{16}',
105
+ "AWS Access Key ID pattern detected — never hardcode AWS credentials."
106
+ ),
107
+ (
108
+ r'(?i)private_key\s*=\s*["\'][^"\']{10,}["\']',
109
+ "Hardcoded private key detected — load from a secure vault or environment variable."
110
+ ),
111
+ ]
112
+
113
+ def scan_secrets(code: str) -> list:
114
+ """Detect hardcoded secrets via regex and return explanation entries."""
115
+ findings = []
116
+ for pattern, reason in SECRET_PATTERNS:
117
+ for match in re.finditer(pattern, code):
118
+ snippet = match.group()[:60]
119
+ findings.append({
120
+ "change": f"Hardcoded secret: {snippet}{'...' if len(match.group()) > 60 else ''}",
121
+ "reason": reason
122
+ })
123
+ return findings
124
+
125
+ def rule_based_patch(code: str, language: str = "python"):
126
+ explanations = []
127
+ patched = code
128
+
129
+ for bad, (good, reason) in SECURE_REPLACEMENTS.items():
130
+ if bad in patched:
131
+ patched = patched.replace(bad, good)
132
+ explanations.append({
133
+ "change": f"{bad} → {good}",
134
+ "reason": reason
135
+ })
136
+
137
+ lang_rules = JS_SECURE_REPLACEMENTS if language == "javascript" else PYTHON_EXTRA_REPLACEMENTS
138
+ for bad, (good, reason) in lang_rules.items():
139
+ if bad in patched:
140
+ patched = patched.replace(bad, good)
141
+ explanations.append({
142
+ "change": f"{bad} → {good}",
143
+ "reason": reason
144
+ })
145
+
146
+ secret_findings = scan_secrets(code)
147
+ explanations.extend(secret_findings)
148
+
149
+ return patched, explanations
150
+
151
+ def preserve_structure(original: str, enhanced: str):
152
+ final_code = enhanced
153
+
154
+ original_imports = [
155
+ l for l in original.splitlines()
156
+ if l.strip().startswith(("import ", "from "))
157
+ ]
158
+
159
+ for imp in original_imports:
160
+ if imp not in final_code:
161
+ final_code = imp + "\n" + final_code
162
+
163
+ original_defs = [
164
+ l for l in original.splitlines()
165
+ if l.strip().startswith("def ")
166
+ ]
167
+
168
+ for d in original_defs:
169
+ if d.split("(")[0] not in final_code:
170
+ final_code = (
171
+ d +
172
+ "\n # [!] Function body missing, please review\n" +
173
+ final_code
174
+ )
175
+
176
+ return final_code
177
+
178
+ def create_diff(original: str, enhanced: str):
179
+ diff_lines = difflib.unified_diff(
180
+ original.splitlines(),
181
+ enhanced.splitlines(),
182
+ lineterm=""
183
+ )
184
+
185
+ formatted = []
186
+
187
+ for line in diff_lines:
188
+ if line.startswith("+") and not line.startswith("+++"):
189
+ formatted.append({
190
+ "type": "add",
191
+ "content": line[1:]
192
+ })
193
+ elif line.startswith("-") and not line.startswith("---"):
194
+ formatted.append({
195
+ "type": "remove",
196
+ "content": line[1:]
197
+ })
198
+ elif not line.startswith("@@"):
199
+ formatted.append({
200
+ "type": "context",
201
+ "content": line
202
+ })
203
+
204
+ return formatted
205
+
206
+ def postprocess_code(code: str):
207
+ code = re.sub(r'^"""|"""$', '', code.strip())
208
+ lines = code.splitlines()
209
+ return "\n".join(
210
+ l.replace("\t", " ").rstrip()
211
+ for l in lines
212
+ )
213
+
214
+ def score_candidate(candidate_code: str, original_code: str) -> int:
215
+ """
216
+ Score a candidate by how many known bad patterns it fixed
217
+ minus any new bad patterns it introduced.
218
+ Failed/crashed candidates are heavily penalised.
219
+ """
220
+ if "# [!] Failed" in candidate_code[:80]:
221
+ return -9999
222
+
223
+ all_bad = (
224
+ list(SECURE_REPLACEMENTS.keys()) +
225
+ list(PYTHON_EXTRA_REPLACEMENTS.keys()) +
226
+ list(JS_SECURE_REPLACEMENTS.keys())
227
+ )
228
+ fixed = sum(1 for p in all_bad if p in original_code and p not in candidate_code)
229
+ new_issues = sum(1 for p in all_bad if p not in original_code and p in candidate_code)
230
+ return fixed - new_issues
231
+
232
+ def run_model(model_name, code, language):
233
+ tokenizer = tokenizers[model_name]
234
+ model = models[model_name]
235
+ mtype = MODEL_CONFIGS[model_name]
236
+
237
+ prompt = f"Fix security issues in this {language} code:\n{code}"
238
+
239
+ if mtype == "seq2seq":
240
+ inputs = tokenizer(
241
+ prompt,
242
+ return_tensors="pt",
243
+ truncation=True,
244
+ max_length=512
245
+ ).to(DEVICE)
246
+
247
+ outputs = model.generate(
248
+ **inputs,
249
+ max_new_tokens=512,
250
+ num_beams=4
251
+ )
252
+ else:
253
+ inputs = tokenizer(
254
+ prompt,
255
+ return_tensors="pt",
256
+ truncation=True,
257
+ max_length=512
258
+ ).to(DEVICE)
259
+
260
+ outputs = model.generate(
261
+ **inputs,
262
+ max_new_tokens=256,
263
+ temperature=0.3,
264
+ top_p=0.95,
265
+ do_sample=True
266
+ )
267
+
268
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
269
+
270
+ def enhance_code(code: str, language: str):
271
+
272
+ if len(code) > MAX_CODE_CHARS:
273
+ return {
274
+ "enhanced_code": code,
275
+ "diff": [],
276
+ "candidates": [],
277
+ "explanations": [{
278
+ "change": "Input too large",
279
+ "reason": f"Code exceeds {MAX_CODE_CHARS} character limit. Please split into smaller files."
280
+ }]
281
+ }
282
+
283
+ with torch.no_grad():
284
+ try:
285
+ patched_code, rule_explanations = rule_based_patch(code, language)
286
+
287
+ candidates = []
288
+
289
+ for m in MODEL_CONFIGS.keys():
290
+ try:
291
+ enhanced = run_model(m, patched_code, language)
292
+ enhanced = postprocess_code(enhanced)
293
+ enhanced = preserve_structure(code, enhanced)
294
+
295
+ candidates.append({
296
+ "model": m,
297
+ "code": enhanced
298
+ })
299
+
300
+ except Exception as e:
301
+ candidates.append({
302
+ "model": m,
303
+ "code": f"# [!] Failed: {str(e)}"
304
+ })
305
+
306
+ valid_candidates = [c for c in candidates if "# [!] Failed" not in c["code"][:80]]
307
+ if valid_candidates:
308
+ best = max(valid_candidates, key=lambda c: score_candidate(c["code"], code))
309
+ else:
310
+ best = {"model": "rule-based", "code": patched_code}
311
+
312
+ diff = create_diff(code, best["code"])
313
+
314
+ explanations = rule_explanations + [{
315
+ "change": "Model ensemble",
316
+ "reason": "Best candidate selected from multiple models based on security improvement score"
317
+ }]
318
+
319
+ return {
320
+ "enhanced_code": best["code"],
321
+ "diff": diff,
322
+ "candidates": candidates[:3],
323
+ "explanations": explanations
324
+ }
325
+
326
+ except Exception as e:
327
+ fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
328
+
329
+ return {
330
+ "enhanced_code": fallback,
331
+ "diff": create_diff(code, fallback),
332
+ "candidates": [],
333
+ "explanations": [{
334
+ "change": "Error",
335
+ "reason": str(e)
336
+ }]
337
+ }
requirements.txt CHANGED
@@ -1,34 +1,28 @@
1
-
2
- # Flask dependencies
3
- flask==2.3.3
4
- flask-cors==4.0.0
5
- flask-jwt-extended==4.5.3
6
- flask-limiter==3.5.0
7
- flask-compress==1.13
8
- flask-caching==2.1.0
9
- python-dotenv==1.0.0
10
-
11
- # Database
12
- pymongo==4.5.0
13
-
14
- # Validation
15
- pydantic>=2.9.2
16
-
17
-
18
- # AI/ML dependencies
19
- torch>=2.0.0
20
- transformers>=4.30.0
21
-
22
- # Security scanning tools
23
- bandit>=1.7.5
24
- semgrep>=1.45.0
25
-
26
- # Optional: for better performance
27
- accelerate>=0.20.0
28
- safetensors>=0.3.0
29
-
30
- bcrypt>=4.0.1
31
- python-dotenv
32
- gunicorn
33
-
34
- requests==2.31.0
 
1
+ flask==2.3.3
2
+ flask-cors==4.0.0
3
+ flask-jwt-extended==4.5.3
4
+ flask-limiter==3.5.0
5
+ flask-compress==1.13
6
+ flask-caching==2.1.0
7
+ python-dotenv==1.0.0
8
+
9
+ pymongo==4.5.0
10
+
11
+ pydantic>=2.9.2
12
+
13
+
14
+ torch>=2.0.0
15
+ transformers==4.40.2
16
+ tokenizers==0.19.1
17
+
18
+ bandit>=1.7.5
19
+ semgrep>=1.45.0
20
+
21
+ accelerate>=0.20.0
22
+ safetensors>=0.3.0
23
+
24
+ bcrypt>=4.0.1
25
+ python-dotenv
26
+ gunicorn
27
+
28
+ requests==2.31.0
 
 
 
 
 
 
schemas.py CHANGED
@@ -1,11 +1,43 @@
1
- # backend/schemas.py
2
- from pydantic import BaseModel
3
  from typing import List, Optional
4
 
5
  class FileModel(BaseModel):
6
  filename: str
7
  content: str
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class ScanRequest(BaseModel):
10
  files: List[FileModel]
11
  language: str = "python"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, field_validator
 
2
  from typing import List, Optional
3
 
4
  class FileModel(BaseModel):
5
  filename: str
6
  content: str
7
 
8
+ @field_validator('filename')
9
+ @classmethod
10
+ def validate_filename(cls, v):
11
+ if not v or not v.strip():
12
+ raise ValueError('Filename cannot be empty')
13
+ if len(v) > 255:
14
+ raise ValueError('Filename must be 255 characters or less')
15
+ return v
16
+
17
+ @field_validator('content')
18
+ @classmethod
19
+ def validate_content(cls, v):
20
+ if len(v) > 100_000:
21
+ raise ValueError('File content exceeds 100 KB limit. Please split into smaller files.')
22
+ return v
23
+
24
  class ScanRequest(BaseModel):
25
  files: List[FileModel]
26
  language: str = "python"
27
+
28
+ @field_validator('files')
29
+ @classmethod
30
+ def validate_files(cls, v):
31
+ if len(v) == 0:
32
+ raise ValueError('At least one file is required')
33
+ if len(v) > 10:
34
+ raise ValueError('Maximum 10 files allowed per scan request')
35
+ return v
36
+
37
+ @field_validator('language')
38
+ @classmethod
39
+ def validate_language(cls, v):
40
+ supported = ("python", "javascript")
41
+ if v.lower() not in supported:
42
+ raise ValueError(f'Language must be one of: {", ".join(supported)}')
43
+ return v.lower()