File size: 5,223 Bytes
c109b62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM
)
import difflib, re, torch
# Registry of checkpoints used by the enhancer: model id -> architecture kind.
# The kind decides which Auto* class loads the checkpoint below and which
# generation path run_model() takes.
MODEL_CONFIGS = {
    "Salesforce/codet5-base": "seq2seq",  # CodeT5
    "EleutherAI/gpt-neo-1.3B": "causal",  # GPT-Neo
    "microsoft/CodeGPT-small-py": "causal",  # CodeGPT-small (Python)
}

# Eagerly load one tokenizer and one model per registry entry at import time.
# Both dicts are keyed by the same model id as MODEL_CONFIGS.
tokenizers = {}
models = {}
for name, mtype in MODEL_CONFIGS.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(name)
    # Any kind other than "seq2seq" is loaded as a causal language model.
    models[name] = (
        AutoModelForSeq2SeqLM if mtype == "seq2seq" else AutoModelForCausalLM
    ).from_pretrained(name)
# Rule-based fixes: insecure snippet -> (replacement text, explanation).
# NOTE(review): the "pickle.load(" rule only swaps the call prefix for a comment
# marker, so the call's arguments and closing parenthesis end up inside the
# comment and the statement is left incomplete. The "eval(" rule emits
# ast.literal_eval without adding "import ast". Both need human follow-up.
SECURE_REPLACEMENTS = {
    "hashlib.md5": ("hashlib.sha256", "MD5 is weak, replaced with SHA-256."),
    "hashlib.sha1": ("hashlib.sha256", "SHA1 is weak, replaced with SHA-256."),
    "eval(": ("ast.literal_eval(", "Unsafe eval removed, replaced with safe literal_eval."),
    "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe, consider json/safe loaders."),
}


def _rule_pattern(bad: str):
    """Compile a regex that matches *bad* only as a standalone name.

    When the snippet starts (or ends) with an identifier character, the match
    must not be glued to a longer identifier or attribute chain. This stops
    "eval(" from matching inside "literal_eval(" or "model.eval(".
    """
    starts_with_name = bad[:1].isalnum() or bad[:1] == "_"
    ends_with_name = bad[-1:].isalnum() or bad[-1:] == "_"
    prefix = r"(?<![\w.])" if starts_with_name else ""
    suffix = r"(?!\w)" if ends_with_name else ""
    return re.compile(prefix + re.escape(bad) + suffix)


def rule_based_patch(code: str):
    """Apply every rule in SECURE_REPLACEMENTS to *code*.

    Rules are applied in table order, each one to the output of the previous.
    A rule fires only where its key appears as a standalone name, so running
    the patch on already-patched code leaves the eval rewrite untouched.

    Args:
        code: Source text to patch.

    Returns:
        A ``(patched, explanations)`` tuple. ``explanations`` holds one
        ``{"change": ..., "reason": ...}`` dict per rule that fired.
    """
    explanations = []
    patched = code
    for bad, (good, reason) in SECURE_REPLACEMENTS.items():
        # A callable replacement inserts `good` literally; a plain string
        # would have backslashes and group references interpreted by re.
        patched, hits = _rule_pattern(bad).subn(lambda _m, text=good: text, patched)
        if hits:
            explanations.append({"change": f"{bad} → {good}", "reason": reason})
    return patched, explanations
# A single-line "import x" or "from x import y" statement (already stripped).
_IMPORT_LINE_RE = re.compile(r"(?:import\s+\S|from\s+\S+\s+import\s+\S)")
# A "def name" / "async def name" header at any indentation; captures the name.
_DEF_LINE_RE = re.compile(r"\s*(?:async\s+)?def\s+([^\W\d]\w*)")


def preserve_structure(original: str, enhanced: str):
    """Re-add imports and function headers that *enhanced* lost.

    Every single-line import statement of *original* whose text does not occur
    in *enhanced* is put back, and every function defined in *original* that
    has no ``def <name>(`` in *enhanced* gets a stub header followed by a
    review marker. Restored items keep their order from *original*.

    Args:
        original: Source text before model rewriting.
        enhanced: Model output to repair.

    Returns:
        *enhanced*, prefixed by the function stubs and then the restored
        imports. When nothing is missing, *enhanced* is returned unchanged.
    """
    source_lines = original.splitlines()

    restored_imports = []
    for line in source_lines:
        stmt = line.strip()
        if not _IMPORT_LINE_RE.match(stmt):
            continue
        # Only the first line of a multi-line import is visible here, and
        # restoring it alone would leave an unbalanced statement.
        if stmt.endswith("\\") or stmt.count("(") != stmt.count(")"):
            continue
        # Restored imports go to the top of the file, so the stripped text
        # is used: an indented import would not be valid at column 0.
        if stmt not in enhanced and stmt not in restored_imports:
            restored_imports.append(stmt)

    stubs = []
    seen_names = set()
    for line in source_lines:
        header = _DEF_LINE_RE.match(line)
        if header is None:
            continue
        func_name = header.group(1)
        if func_name in seen_names:
            continue
        seen_names.add(func_name)
        # Match on the name itself so that "foo" is not satisfied by
        # "foobar" and a change of indentation does not count as a loss.
        still_defined = re.search(
            r"\bdef\s+" + re.escape(func_name) + r"\s*\(", enhanced
        )
        if still_defined is None:
            stubs.append(line + "\n # [!] Function body missing, please review\n")

    prefix = "".join(stubs) + "".join(stmt + "\n" for stmt in restored_imports)
    return prefix + enhanced
def create_diff(original: str, enhanced: str):
    """Return a structured line diff for frontend rendering.

    The two texts are compared line by line with ``difflib.unified_diff``.
    File headers and ``@@`` hunk headers are dropped; every remaining diff
    line becomes one entry whose content has the one-character diff prefix
    removed.

    Args:
        original: Text before the change.
        enhanced: Text after the change.

    Returns:
        A list of ``{"type": "add" | "remove" | "context", "content": str}``
        dicts in diff order; empty when the texts have identical lines.
    """
    diff_lines = difflib.unified_diff(
        original.splitlines(), enhanced.splitlines(),
        fromfile="Original", tofile="Enhanced", lineterm="",
    )
    kind_by_prefix = {"+": "add", "-": "remove", " ": "context"}
    formatted = []
    inside_hunk = False
    for line in diff_lines:
        if line.startswith("@@"):
            # Body lines always carry a "+", "-" or " " prefix, so a line
            # that starts with "@@" can only be a hunk header.
            inside_hunk = True
            continue
        if not inside_hunk:
            # The "---"/"+++" file headers come before the first hunk.
            # Skipping by position rather than by text keeps real lines
            # such as "++i" or "-- comment" from being mistaken for them.
            continue
        kind = kind_by_prefix.get(line[:1])
        if kind is None:
            # Not expected from unified_diff; kept verbatim, not dropped.
            formatted.append({"type": "context", "content": line})
        else:
            formatted.append({"type": kind, "content": line[1:]})
    return formatted
def postprocess_code(code: str):
    """Tidy raw model output.

    The text is stripped of surrounding whitespace, a triple double-quote at
    its very start and/or very end is removed, and then every line has each
    tab replaced by a single space and its trailing whitespace trimmed.

    Args:
        code: Raw text produced by a model.

    Returns:
        The cleaned lines joined with newlines.
    """
    unquoted = re.sub(r'^"""|"""$', '', code.strip())
    tidy_rows = (row.replace("\t", " ").rstrip() for row in unquoted.splitlines())
    return "\n".join(tidy_rows)
def run_model(model_name, code, language):
    """Ask one registered model to rewrite *code* and return its text output.

    The prompt is ``"fix {language} code: {code}"``, truncated to 512 tokens.
    Models registered as "seq2seq" decode with 4-beam search; every other
    kind is treated as a causal LM and decoded greedily for up to 256 new
    tokens.

    Args:
        model_name: Key into MODEL_CONFIGS / tokenizers / models.
        code: Source text to place in the prompt.
        language: Language label placed in the prompt.

    Returns:
        The decoded generation with special tokens removed. For causal
        models only the newly generated text is returned, not the prompt.

    Raises:
        KeyError: If *model_name* is not a registered model.
    """
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    mtype = MODEL_CONFIGS[model_name]
    prompt = f"fix {language} code: {code}"

    if mtype == "seq2seq":
        inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Greedy decoding. Sampling knobs (temperature/top_p) are not passed
    # because they have no effect when do_sample is False.
    outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # For decoder-only models generate() returns the prompt tokens followed
    # by the continuation; drop the prompt so the instruction text and the
    # input code are not echoed back as "enhanced" code.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def enhance_code(code: str, language: str):
    """Run the rule-based patch and every registered model over *code*.

    Each model's output is post-processed and structure-checked. The longest
    non-blank successful candidate becomes the result. Failed runs are still
    reported in ``candidates`` as ``# [!] Failed: ...`` placeholders, but
    they never compete for selection. When no model yields usable output the
    rule-patched code is returned instead.

    Args:
        code: Source text to enhance.
        language: Language label forwarded to the model prompt.

    Returns:
        A dict with keys ``enhanced_code`` (str), ``diff`` (structured diff
        against *code*), ``candidates`` (at most three
        ``{"model", "code"}`` dicts) and ``explanations`` (list of
        ``{"change", "reason"}`` dicts). If the pipeline itself raises, the
        original code with an appended crash note is returned in the same
        shape, with empty ``candidates``.
    """
    try:
        patched_code, rule_explanations = rule_based_patch(code)
        candidates = []
        usable = []
        for m in MODEL_CONFIGS:
            try:
                enhanced = run_model(m, patched_code, language)
                enhanced = postprocess_code(enhanced)
                enhanced = preserve_structure(code, enhanced)
            except Exception as e:
                # Best effort: one failing model must not sink the others.
                candidates.append({"model": m, "code": f"# [!] Failed: {str(e)}"})
            else:
                entry = {"model": m, "code": enhanced}
                candidates.append(entry)
                # Blank output would erase the user's code if selected.
                if enhanced.strip():
                    usable.append(entry)

        if usable:
            best_code = max(usable, key=lambda c: len(c["code"]))["code"]
            explanations = rule_explanations + [
                {"change": "Model improvements", "reason": "Best candidate chosen among ensemble"}
            ]
        else:
            # Keep the rule-based fixes rather than returning an error
            # placeholder as the "enhanced" code.
            best_code = patched_code
            explanations = rule_explanations + [
                {"change": "Error", "reason": "No model produced usable output; returning rule-based patch only"}
            ]

        return {
            "enhanced_code": best_code,
            "diff": create_diff(code, best_code),
            "candidates": candidates[:3],
            "explanations": explanations,
        }
    except Exception as e:
        fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
        return {
            "enhanced_code": fallback,
            "diff": create_diff(code, fallback),
            "candidates": [],
            "explanations": [{"change": "Error", "reason": str(e)}],
        }
|