DarthBihan committed on
Commit
61db20b
·
verified ·
1 Parent(s): 3396f0a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +264 -133
model.py CHANGED
@@ -1,133 +1,264 @@
1
- from transformers import (
2
- AutoTokenizer,
3
- AutoModelForSeq2SeqLM,
4
- AutoModelForCausalLM
5
- )
6
- import difflib, re, torch
7
-
8
- # Models and their types
9
- MODEL_CONFIGS = {
10
- "Salesforce/codet5-base": "seq2seq", # CodeT5
11
- "EleutherAI/gpt-neo-1.3B": "causal", # GPT-Neo
12
- "microsoft/CodeGPT-small-py": "causal", # CodeGPT-small (Python)
13
- }
14
-
15
- # Load tokenizers and models
16
- tokenizers, models = {}, {}
17
-
18
- for name, mtype in MODEL_CONFIGS.items():
19
- tokenizers[name] = AutoTokenizer.from_pretrained(name)
20
- if mtype == "seq2seq":
21
- models[name] = AutoModelForSeq2SeqLM.from_pretrained(name)
22
- else:
23
- models[name] = AutoModelForCausalLM.from_pretrained(name)
24
-
25
- # Rule-based fixes
26
- SECURE_REPLACEMENTS = {
27
- "hashlib.md5": ("hashlib.sha256", "MD5 is weak, replaced with SHA-256."),
28
- "hashlib.sha1": ("hashlib.sha256", "SHA1 is weak, replaced with SHA-256."),
29
- "eval(": ("ast.literal_eval(", "Unsafe eval removed, replaced with safe literal_eval."),
30
- "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe, consider json/safe loaders."),
31
- }
32
-
33
- def rule_based_patch(code: str):
34
- explanations = []
35
- patched = code
36
- for bad, (good, reason) in SECURE_REPLACEMENTS.items():
37
- if bad in patched:
38
- patched = patched.replace(bad, good)
39
- explanations.append({"change": f"{bad} → {good}", "reason": reason})
40
- return patched, explanations
41
-
42
- def preserve_structure(original: str, enhanced: str):
43
- """Ensure imports and function signatures remain if model drops them."""
44
- final_code = enhanced
45
- original_imports = [l for l in original.splitlines() if l.strip().startswith("import")]
46
- for imp in original_imports:
47
- if imp not in final_code:
48
- final_code = imp + "\n" + final_code
49
- original_defs = [l for l in original.splitlines() if l.strip().startswith("def ")]
50
- for d in original_defs:
51
- if d.split("(")[0] not in final_code:
52
- final_code = d + "\n # [!] Function body missing, please review\n" + final_code
53
- return final_code
54
-
55
- def create_diff(original: str, enhanced: str):
56
- """Return structured diff for frontend rendering."""
57
- diff_lines = difflib.unified_diff(
58
- original.splitlines(), enhanced.splitlines(),
59
- fromfile="Original", tofile="Enhanced", lineterm=""
60
- )
61
- formatted = []
62
- for line in diff_lines:
63
- if line.startswith("+") and not line.startswith("+++"):
64
- formatted.append({"type": "add", "content": line[1:]})
65
- elif line.startswith("-") and not line.startswith("---"):
66
- formatted.append({"type": "remove", "content": line[1:]})
67
- elif not line.startswith("@@"):
68
- formatted.append({"type": "context", "content": line})
69
- return formatted
70
-
71
- def postprocess_code(code: str):
72
- code = re.sub(r'^"""|"""$', '', code.strip())
73
- lines = code.splitlines()
74
- return "\n".join([l.replace("\t", " ").rstrip() for l in lines])
75
-
76
- def run_model(model_name, code, language):
77
- tokenizer = tokenizers[model_name]
78
- model = models[model_name]
79
- mtype = MODEL_CONFIGS[model_name]
80
-
81
- prompt = f"fix {language} code: {code}"
82
-
83
- if mtype == "seq2seq":
84
- inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
85
- outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
86
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
87
- else:
88
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
89
- outputs = model.generate(
90
- **inputs,
91
- max_new_tokens=256,
92
- temperature=0.3,
93
- top_p=0.95,
94
- do_sample=False
95
- )
96
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
97
-
98
- def enhance_code(code: str, language: str):
99
- try:
100
- patched_code, rule_explanations = rule_based_patch(code)
101
-
102
- candidates = []
103
- for m in MODEL_CONFIGS.keys():
104
- try:
105
- enhanced = run_model(m, patched_code, language)
106
- enhanced = postprocess_code(enhanced)
107
- enhanced = preserve_structure(code, enhanced)
108
- candidates.append({"model": m, "code": enhanced})
109
- except Exception as e:
110
- candidates.append({"model": m, "code": f"# [!] Failed: {str(e)}"})
111
-
112
- best = max(candidates, key=lambda c: len(c["code"]))
113
- diff = create_diff(code, best["code"])
114
-
115
- explanations = rule_explanations + [
116
- {"change": "Model improvements", "reason": "Best candidate chosen among ensemble"}
117
- ]
118
-
119
- return {
120
- "enhanced_code": best["code"],
121
- "diff": diff,
122
- "candidates": candidates[:3],
123
- "explanations": explanations
124
- }
125
-
126
- except Exception as e:
127
- fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
128
- return {
129
- "enhanced_code": fallback,
130
- "diff": create_diff(code, fallback),
131
- "candidates": [],
132
- "explanations": [{"change": "Error", "reason": str(e)}]
133
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import difflib
3
+ import re
4
+ from transformers import (
5
+ AutoTokenizer,
6
+ AutoModelForSeq2SeqLM,
7
+ AutoModelForCausalLM
8
+ )
9
+
10
+ # ----------------------------
11
+ # Performance Settings
12
+ # ----------------------------
13
+
14
# Cap PyTorch's intra-op thread pool and keep every model on the CPU.
torch.set_num_threads(2)
DEVICE = "cpu"

# ----------------------------
# Models and their types
# ----------------------------

# Maps a Hugging Face model id to its architecture family:
# "seq2seq" is loaded with AutoModelForSeq2SeqLM, anything else with
# AutoModelForCausalLM.
MODEL_CONFIGS = {
    "Salesforce/codet5-base": "seq2seq",  # CodeT5
    # GPT-Neo is disabled due to free hosting for now; enable on local hosting:
    # "EleutherAI/gpt-neo-1.3B": "causal",
    "microsoft/CodeGPT-small-py": "causal",  # CodeGPT-small (Python)
}

# ----------------------------
# Load tokenizers and models
# ----------------------------

# Both registries are keyed by the model id used in MODEL_CONFIGS.
tokenizers = {}
models = {}

print("🔹 Loading models...")

for name, mtype in MODEL_CONFIGS.items():
    print(f"Loading {name} ...")

    tokenizers[name] = AutoTokenizer.from_pretrained(name)

    # Pick the loader class that matches the declared architecture family.
    loader = AutoModelForSeq2SeqLM if mtype == "seq2seq" else AutoModelForCausalLM
    model = loader.from_pretrained(name)

    # Place the model on the target device and switch it to evaluation mode.
    model.to(DEVICE)
    model.eval()
    models[name] = model

print("✅ All models loaded")
51
+
52
+ # ----------------------------
53
+ # Rule-based fixes
54
+ # ----------------------------
55
+
56
# Maps an insecure snippet to (replacement, human-readable reason).
# A replacement that starts with "#" means "comment the offending line out".
SECURE_REPLACEMENTS = {
    "hashlib.md5": ("hashlib.sha256", "MD5 is weak, replaced with SHA-256."),
    "hashlib.sha1": ("hashlib.sha256", "SHA1 is weak, replaced with SHA-256."),
    "eval(": ("ast.literal_eval(", "Unsafe eval removed, replaced with safe literal_eval."),
    "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe, consider json/safe loaders."),
}


def _apply_rule(code, bad, good):
    """Return *code* with one SECURE_REPLACEMENTS rule applied."""
    # Reject matches preceded by an identifier character or a dot so that
    # "eval(" does not fire inside "model.eval(" or "ast.literal_eval(".
    pattern = re.compile(r"(?<![\w.])" + re.escape(bad))

    if not good.lstrip().startswith("#"):
        # A callable replacement keeps backslashes in `good` literal.
        return pattern.sub(lambda _match: good, code)

    # A comment spliced into the middle of a line swallows the rest of the
    # statement and leaves invalid syntax, so take over the whole line instead.
    rewritten = []
    for line in code.split("\n"):
        stripped = line.lstrip()
        if stripped.startswith("#") or not pattern.search(line):
            # Already a comment (possibly from a previous run) or unaffected.
            rewritten.append(line)
            continue
        indent = line[: len(line) - len(stripped)]
        rewritten.append(f"{indent}{good}: {stripped}")
    return "\n".join(rewritten)


def rule_based_patch(code: str):
    """Apply every rule in SECURE_REPLACEMENTS to *code*.

    Args:
        code: Source text to patch.

    Returns:
        A ``(patched, explanations)`` tuple. ``patched`` is the rewritten
        source; ``explanations`` holds one ``{"change", "reason"}`` dict for
        each rule that actually modified the text, in rule order.

    Notes:
        Replacing ``eval(`` with ``ast.literal_eval(`` does not add an
        ``import ast`` line; the patched code must import it itself.
        A ``pickle.load(`` call spread over several lines only has its first
        line commented out.
    """
    explanations = []
    patched = code

    for bad, (good, reason) in SECURE_REPLACEMENTS.items():
        updated = _apply_rule(patched, bad, good)
        if updated == patched:
            continue
        patched = updated
        explanations.append({
            "change": f"{bad} → {good}",
            "reason": reason,
        })

    return patched, explanations
76
+
77
+ # ----------------------------
78
+ # Structure preservation
79
+ # ----------------------------
80
+
81
_DEF_NAME = re.compile(r"def\s+(\w+)")
_STUB_BODY = "    # [!] Function body missing, please review\n    pass"


def _stub_for(line, name):
    """Build a syntactically valid placeholder for a dropped function."""
    # First try the original signature plus a stub body, then the line on its
    # own (a one-line ``def f(): return x`` is already complete).
    for candidate in (f"{line}\n{_STUB_BODY}", line):
        try:
            compile(candidate, "<stub>", "exec")
        except (SyntaxError, ValueError):
            continue
        return candidate
    # The signature continues on later lines; fall back to a generic one.
    return f"def {name}(*args, **kwargs):\n{_STUB_BODY}"


def preserve_structure(original: str, enhanced: str):
    """Re-add imports and function definitions that *enhanced* dropped.

    Args:
        original: Source text before the model rewrote it.
        enhanced: Model output to be repaired.

    Returns:
        ``enhanced`` prefixed with, in this order, every import line of
        ``original`` that is missing from it and a placeholder for every
        ``def`` of ``original`` that is missing from it. Restored lines keep
        their original relative order and are emitted without indentation.
        When nothing is missing, ``enhanced`` is returned unchanged.
    """
    original_lines = [line.strip() for line in original.splitlines()]
    # Compare whole lines: a substring test would treat "import re" as
    # present whenever "import requests" is.
    present_lines = {line.strip() for line in enhanced.splitlines()}

    missing_imports = []
    for line in original_lines:
        if line.startswith(("import ", "from ")) and line not in present_lines:
            missing_imports.append(line)
            present_lines.add(line)  # avoid restoring a duplicate twice

    missing_stubs = []
    restored_names = set()
    for line in original_lines:
        match = _DEF_NAME.match(line)
        if match is None:
            continue
        name = match.group(1)
        if name in restored_names:
            continue
        if re.search(rf"\bdef\s+{re.escape(name)}\s*\(", enhanced):
            continue
        restored_names.add(name)
        missing_stubs.append(_stub_for(line, name))

    return "\n".join(missing_imports + missing_stubs + [enhanced])
107
+
108
+ # ----------------------------
109
+ # Diff creation
110
+ # ----------------------------
111
+
112
def create_diff(original: str, enhanced: str):
    """Return a line-by-line unified diff as a list of dicts.

    Args:
        original: Text before the change.
        enhanced: Text after the change.

    Returns:
        A list of ``{"type", "content"}`` dicts, where ``type`` is ``"add"``,
        ``"remove"`` or ``"context"`` and ``content`` is the line without its
        one-character diff prefix. File headers and ``@@`` hunk headers are
        omitted. The list is empty when both texts are identical.
    """
    diff_lines = difflib.unified_diff(
        original.splitlines(),
        enhanced.splitlines(),
        lineterm="",
    )

    kinds = {"+": "add", "-": "remove", " ": "context"}
    formatted = []

    for index, line in enumerate(diff_lines):
        # unified_diff emits the "---" / "+++" file headers exactly once, as
        # its first two lines. Skipping them by position (not by prefix)
        # keeps real changes such as "++i" or "--flag" classified correctly.
        if index < 2:
            continue

        kind = kinds.get(line[:1])
        if kind is None:
            # "@@ -a,b +c,d @@" hunk header.
            continue

        formatted.append({"type": kind, "content": line[1:]})

    return formatted
144
+
145
+ # ----------------------------
146
+ # Postprocess output
147
+ # ----------------------------
148
+
149
def postprocess_code(code: str):
    """Tidy raw model output.

    Surrounding whitespace is trimmed, a triple double-quote at the very
    start or very end is removed, and each line has its tabs replaced and
    its trailing whitespace stripped.
    """
    unquoted = re.sub(r'^"""|"""$', '', code.strip())

    cleaned = []
    for raw_line in unquoted.splitlines():
        cleaned.append(raw_line.replace("\t", " ").rstrip())

    return "\n".join(cleaned)
156
+
157
+ # ----------------------------
158
+ # Run one model
159
+ # ----------------------------
160
+
161
def run_model(model_name, code, language):
    """Ask one loaded model to rewrite *code* and return the decoded text.

    Args:
        model_name: Key into ``tokenizers``, ``models`` and ``MODEL_CONFIGS``.
        code: Source text inserted into the prompt.
        language: Language name inserted into the prompt.

    Returns:
        The decoded generation with special tokens removed. For causal
        models only the newly generated tokens are decoded, so the prompt is
        not echoed back as part of the result.

    Raises:
        KeyError: If ``model_name`` is not a loaded model.
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    mtype = MODEL_CONFIGS[model_name]

    prompt = f"Fix security issues in this {language} code:\n{code}"

    # Both model families take the same encoding; the prompt is capped at
    # 512 tokens.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(DEVICE)

    if mtype == "seq2seq":
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=4,
        )
        generated = outputs[0]
    else:
        # Supply the padding id explicitly when the tokenizer defines none,
        # rather than leaving generate() to fall back on its own.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Greedy decoding: temperature / top_p only apply when sampling, so
        # they are not passed alongside do_sample=False.
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=pad_token_id,
        )

        # A causal model returns prompt + continuation; keep the continuation.
        prompt_length = inputs["input_ids"].shape[1]
        generated = outputs[0][prompt_length:]

    return tokenizer.decode(generated, skip_special_tokens=True)
202
+
203
+ # ----------------------------
204
+ # Main enhancer
205
+ # ----------------------------
206
+
207
def enhance_code(code: str, language: str):
    """Run the rule-based patcher and every configured model over *code*.

    Args:
        code: Source text to improve.
        language: Language name forwarded to the model prompt.

    Returns:
        A dict with these keys:

        * ``enhanced_code`` - the longest successful model output, or the
          rule-patched code when no model produced a result.
        * ``diff`` - ``create_diff(code, enhanced_code)``.
        * ``candidates`` - up to three ``{"model", "code"}`` entries; a
          failed model is listed with a ``# [!] Failed: ...`` placeholder.
        * ``explanations`` - rule explanations plus one ensemble entry.

        If anything outside the per-model loop raises, the original code
        with a trailing ``# [!] Enhancer crashed`` comment is returned.
    """
    try:
        # 1. Rule-based fixes
        patched_code, rule_explanations = rule_based_patch(code)

        # 2. Model ensemble
        candidates = []
        successes = []

        # Inference only: no gradients are needed.
        with torch.no_grad():
            for model_name in MODEL_CONFIGS:
                try:
                    enhanced = run_model(model_name, patched_code, language)
                    enhanced = postprocess_code(enhanced)
                    enhanced = preserve_structure(code, enhanced)
                except Exception as exc:
                    # One failing model must not abort the ensemble; report
                    # it to the caller as a placeholder candidate.
                    candidates.append({
                        "model": model_name,
                        "code": f"# [!] Failed: {str(exc)}",
                    })
                else:
                    candidate = {"model": model_name, "code": enhanced}
                    candidates.append(candidate)
                    successes.append(candidate)

        # 3. Choose the longest *successful* output. Error placeholders never
        # compete, so a long error message cannot replace real code.
        if successes:
            best_code = max(successes, key=lambda c: len(c["code"]))["code"]
            ensemble_note = {
                "change": "Model ensemble",
                "reason": "Best candidate selected from multiple models",
            }
        else:
            # No usable model output: keep the rule-based fixes.
            best_code = patched_code
            ensemble_note = {
                "change": "Model ensemble",
                "reason": "No model produced a result; returning rule-based fixes only",
            }

        return {
            "enhanced_code": best_code,
            "diff": create_diff(code, best_code),
            "candidates": candidates[:3],
            "explanations": rule_explanations + [ensemble_note],
        }

    except Exception as exc:
        fallback = code + f"\n# [!] Enhancer crashed: {str(exc)}"

        return {
            "enhanced_code": fallback,
            "diff": create_diff(code, fallback),
            "candidates": [],
            "explanations": [{
                "change": "Error",
                "reason": str(exc),
            }],
        }