import difflib
import re

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
)

# Candidate checkpoints and their architecture type (encoder-decoder vs. decoder-only).
MODEL_CONFIGS = {
    "Salesforce/codet5-base": "seq2seq",
    "EleutherAI/gpt-neo-1.3B": "causal",
    "microsoft/CodeGPT-small-py": "causal",
}

# Load a tokenizer and model for every configured checkpoint up front.
tokenizers, models = {}, {}

for name, mtype in MODEL_CONFIGS.items():
    tokenizers[name] = AutoTokenizer.from_pretrained(name)
    if mtype == "seq2seq":
        models[name] = AutoModelForSeq2SeqLM.from_pretrained(name)
    else:
        models[name] = AutoModelForCausalLM.from_pretrained(name)

# Known-insecure patterns and their replacements. Matching in rule_based_patch()
# is plain substring replacement, so the entries are deliberately conservative.
# Note: code rewritten to use ast.literal_eval() must also import ast.
SECURE_REPLACEMENTS = {
    "hashlib.md5": ("hashlib.sha256", "MD5 is weak; replaced with SHA-256."),
    "hashlib.sha1": ("hashlib.sha256", "SHA-1 is weak; replaced with SHA-256."),
    "eval(": ("ast.literal_eval(", "Unsafe eval removed; replaced with safe literal_eval."),
    "pickle.load(": ("# pickle.load removed", "pickle.load is unsafe; consider json or other safe loaders."),
}


def rule_based_patch(code: str):
    """Apply deterministic string-level security fixes before any model runs."""
    explanations = []
    patched = code
    for bad, (good, reason) in SECURE_REPLACEMENTS.items():
        if bad in patched:
            patched = patched.replace(bad, good)
            explanations.append({"change": f"{bad} → {good}", "reason": reason})
    return patched, explanations
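
# Illustrative example (not executed): rule_based_patch("h = hashlib.md5(data).hexdigest()")
# returns ("h = hashlib.sha256(data).hexdigest()", [...]) with a single explanation
# entry describing the hashlib.md5 → hashlib.sha256 substitution.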


def preserve_structure(original: str, enhanced: str):
    """Ensure imports and function signatures remain if the model drops them."""
    final_code = enhanced
    # Keep both plain "import x" and "from x import y" lines.
    original_imports = [
        l for l in original.splitlines()
        if l.strip().startswith(("import ", "from "))
    ]
    for imp in original_imports:
        if imp not in final_code:
            final_code = imp + "\n" + final_code
    original_defs = [l for l in original.splitlines() if l.strip().startswith("def ")]
    for d in original_defs:
        if d.split("(")[0] not in final_code:
            # Re-insert the signature as a stub; "pass" keeps the result parseable.
            final_code = d + "\n    # [!] Function body missing, please review\n    pass\n" + final_code
    return final_code


def create_diff(original: str, enhanced: str):
    """Return structured diff for frontend rendering."""
    diff_lines = difflib.unified_diff(
        original.splitlines(), enhanced.splitlines(),
        fromfile="Original", tofile="Enhanced", lineterm=""
    )
    formatted = []
    for line in diff_lines:
        if line.startswith("+") and not line.startswith("+++"):
            formatted.append({"type": "add", "content": line[1:]})
        elif line.startswith("-") and not line.startswith("---"):
            formatted.append({"type": "remove", "content": line[1:]})
        elif not line.startswith("@@"):
            formatted.append({"type": "context", "content": line})
    return formatted


def postprocess_code(code: str):
    """Strip stray docstring quotes and normalize whitespace in model output."""
    code = re.sub(r'^"""|"""$', '', code.strip())
    lines = code.splitlines()
    return "\n".join([l.replace("\t", "    ").rstrip() for l in lines])


def run_model(model_name, code, language):
    """Run a single model over the repair prompt and return its decoded output."""
    tokenizer = tokenizers[model_name]
    model = models[model_name]
    mtype = MODEL_CONFIGS[model_name]

    prompt = f"fix {language} code: {code}"

    if mtype == "seq2seq":
        inputs = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    else:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.3,
            top_p=0.95,
            do_sample=False,  # greedy decoding; temperature/top_p above are ignored unless do_sample=True
            pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning for GPT-style models
        )
        # Causal models echo the prompt, so decode only the newly generated tokens.
        return tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )


def enhance_code(code: str, language: str):
    try:
        # Deterministic security fixes first, then let each model refine the result.
        patched_code, rule_explanations = rule_based_patch(code)

        candidates = []
        for m in MODEL_CONFIGS.keys():
            try:
                enhanced = run_model(m, patched_code, language)
                enhanced = postprocess_code(enhanced)
                enhanced = preserve_structure(code, enhanced)
                candidates.append({"model": m, "code": enhanced})
            except Exception as e:
                candidates.append({"model": m, "code": f"# [!] Failed: {str(e)}"})

        # Crude heuristic: treat the longest candidate as the most complete rewrite.
        best = max(candidates, key=lambda c: len(c["code"]))
        diff = create_diff(code, best["code"])

        explanations = rule_explanations + [
            {"change": "Model improvements", "reason": "Best candidate chosen among ensemble"}
        ]

        return {
            "enhanced_code": best["code"],
            "diff": diff,
            "candidates": candidates[:3],
            "explanations": explanations
        }

    except Exception as e:
        # Never fail hard: return the original code annotated with the error.
        fallback = code + f"\n# [!] Enhancer crashed: {str(e)}"
        return {
            "enhanced_code": fallback,
            "diff": create_diff(code, fallback),
            "candidates": [],
            "explanations": [{"change": "Error", "reason": str(e)}]
        }
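

# Illustrative usage only: a minimal manual check of the pipeline. It assumes the
# module is run directly and that the checkpoints above have already been downloaded;
# the sample snippet below is hypothetical.
if __name__ == "__main__":
    sample = (
        "import hashlib\n"
        "def fingerprint(data):\n"
        "    return hashlib.md5(data).hexdigest()\n"
    )
    result = enhance_code(sample, "python")
    print(result["enhanced_code"])
    for change in result["explanations"]:
        print(f"- {change['change']}: {change['reason']}")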